Merge branch 'refs/heads/develop' into feature/agent-tool_xjn

# Conflicts:
#	api/app/core/agent/langchain_agent.py
#	api/app/core/tools/mcp/client.py
This commit is contained in:
Timebomb2018
2026-04-01 15:27:34 +08:00
219 changed files with 4861 additions and 2599 deletions

View File

@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from app.core.agent.langchain_agent import LangChainAgent
from app.core.logging_config import get_business_logger
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig, ModelType
from app.models import WorkflowConfig
@@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import AgentRunService
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.multimodal_service import MultimodalService
from app.services.workflow_service import WorkflowService
from app.schemas import FileType
logger = get_business_logger()
@@ -43,18 +44,17 @@ class AppChatService:
message: str,
conversation_id: uuid.UUID,
config: AgentConfig,
user_id: Optional[str] = None,
files: list[FileInput],
user_id: str,
variables: Optional[Dict[str, Any]] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None,
files: Optional[List[FileInput]] = None
workspace_id: Optional[str] = None
) -> Dict[str, Any]:
"""聊天(非流式)"""
start_time = time.time()
config_id = None
# 应用 features 配置
features_config: dict = config.features or {}
@@ -93,7 +93,8 @@ class AppChatService:
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval,
user_id)
tools.extend(kb_tools)
memory_flag = False
if memory:
@@ -140,13 +141,13 @@ class AppChatService:
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
is_new_conversation = len(history) == 0
if is_new_conversation:
opening = self.agent_service._get_opening_statement(features_config, True, variables)
opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
if opening:
self.conversation_service.add_message(
conversation_id=conversation_id,
role="assistant",
content=opening,
meta_data={}
meta_data={"suggested_questions": suggested_questions}
)
# 重新加载历史(包含刚写入的开场白)
history = await self.conversation_service.get_conversation_history(
@@ -168,11 +169,6 @@ class AppChatService:
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag,
files=processed_files # 传递处理后的文件
)
@@ -229,6 +225,21 @@ class AppChatService:
# 保存消息
if audio_url:
assistant_meta["audio_url"] = audio_url
if memory_flag:
connected_config = get_end_user_connected_config(user_id, self.db)
memory_config_id: str = connected_config.get("memory_config_id")
messages = [
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
{"role": "assistant", "content": result["content"]}
]
if memory_config_id:
await write_long_term(
storage_type,
user_id,
messages,
user_rag_memory_id,
memory_config_id
)
self.conversation_service.add_message(
conversation_id=conversation_id,
role="user",
@@ -264,20 +275,19 @@ class AppChatService:
message: str,
conversation_id: uuid.UUID,
config: AgentConfig,
files: list[FileInput],
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None,
files: Optional[List[FileInput]] = None
workspace_id: Optional[str] = None
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
try:
start_time = time.time()
config_id = None
message_id = uuid.uuid4()
# 应用 features 配置
@@ -319,7 +329,8 @@ class AppChatService:
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
config.knowledge_retrieval, user_id)
tools.extend(kb_tools)
# 添加长期记忆工具
memory_flag = False
@@ -367,13 +378,13 @@ class AppChatService:
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
is_new_conversation = len(history) == 0
if is_new_conversation:
opening = self.agent_service._get_opening_statement(features_config, True, variables)
opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
if opening:
self.conversation_service.add_message(
conversation_id=conversation_id,
role="assistant",
content=opening,
meta_data={}
meta_data={"suggested_questions": suggested_questions}
)
# 重新加载历史(包含刚写入的开场白)
history = await self.conversation_service.get_conversation_history(
@@ -411,11 +422,6 @@ class AppChatService:
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag,
files=processed_files
):
if isinstance(chunk, int):
@@ -459,7 +465,7 @@ class AppChatService:
# 保存消息
human_meta = {
"files":[],
"files": [],
"history_files": {}
}
assistant_meta = {
@@ -484,6 +490,22 @@ class AppChatService:
if stream_audio_url:
assistant_meta["audio_url"] = stream_audio_url
if memory_flag:
connected_config = get_end_user_connected_config(user_id, self.db)
memory_config_id: str = connected_config.get("memory_config_id")
messages = [
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
{"role": "assistant", "content": full_content}
]
if memory_config_id:
await write_long_term(
storage_type,
user_id,
messages,
user_rag_memory_id,
memory_config_id
)
self.conversation_service.add_message(
conversation_id=conversation_id,
role="user",
@@ -618,7 +640,6 @@ class AppChatService:
# 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config)
# 3. 流式执行任务
async for event in orchestrator.execute_stream(
message=message,

View File

@@ -0,0 +1,128 @@
"""应用日志服务层"""
import uuid
from typing import Optional, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from app.core.logging_config import get_business_logger
from app.models.conversation_model import Conversation, Message
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
logger = get_business_logger()
class AppLogService:
"""应用日志服务"""
def __init__(self, db: Session):
self.db = db
self.conversation_repository = ConversationRepository(db)
self.message_repository = MessageRepository(db)
def list_conversations(
self,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
page: int = 1,
pagesize: int = 20,
is_draft: Optional[bool] = None,
) -> Tuple[list[Conversation], int]:
"""
查询应用日志会话列表
Args:
app_id: 应用 ID
workspace_id: 工作空间 ID
page: 页码(从 1 开始)
pagesize: 每页数量
is_draft: 是否草稿会话None 表示不过滤)
Returns:
Tuple[list[Conversation], int]: (会话列表,总数)
"""
logger.info(
"查询应用日志会话列表",
extra={
"app_id": str(app_id),
"workspace_id": str(workspace_id),
"page": page,
"pagesize": pagesize,
"is_draft": is_draft
}
)
# 使用 Repository 查询
conversations, total = self.conversation_repository.list_app_conversations(
app_id=app_id,
workspace_id=workspace_id,
is_draft=is_draft,
page=page,
pagesize=pagesize
)
logger.info(
"查询应用日志会话列表成功",
extra={
"app_id": str(app_id),
"total": total,
"returned": len(conversations)
}
)
return conversations, total
def get_conversation_detail(
self,
app_id: uuid.UUID,
conversation_id: uuid.UUID,
workspace_id: uuid.UUID
) -> Conversation:
"""
查询会话详情(包含消息)
Args:
app_id: 应用 ID
conversation_id: 会话 ID
workspace_id: 工作空间 ID
Returns:
Conversation: 包含消息的会话对象
Raises:
ResourceNotFoundException: 当会话不存在时
"""
logger.info(
"查询应用日志会话详情",
extra={
"app_id": str(app_id),
"conversation_id": str(conversation_id),
"workspace_id": str(workspace_id)
}
)
# 查询会话
conversation = self.conversation_repository.get_conversation_for_app_log(
conversation_id=conversation_id,
app_id=app_id,
workspace_id=workspace_id
)
# 查询消息(按时间正序)
messages = self.message_repository.get_messages_by_conversation(
conversation_id=conversation_id
)
# 将消息附加到会话对象
conversation.messages = messages
logger.info(
"查询应用日志会话详情成功",
extra={
"app_id": str(app_id),
"conversation_id": str(conversation_id),
"message_count": len(messages)
}
)
return conversation

View File

@@ -1084,7 +1084,6 @@ class AppService:
if not exists:
cleaned["memory_config_id"] = None
cleaned.pop("memory_content", None)
cleaned["enabled"] = False
return cleaned
exists = self.db.query(
@@ -1096,7 +1095,6 @@ class AppService:
if not exists:
cleaned["memory_config_id"] = None
cleaned.pop("memory_content", None)
cleaned["enabled"] = False
return cleaned
@@ -1684,15 +1682,15 @@ class AppService:
return config.config_id
def _update_endusers_memory_config_by_workspace(
def _update_endusers_memory_config_by_app(
self,
workspace_id: uuid.UUID,
app_id: uuid.UUID,
memory_config_id: uuid.UUID
) -> int:
"""批量更新应用下所有终端用户的 memory_config_id
Args:
workspace_id: 工作空间ID
app_id: 应用ID
memory_config_id: 新的记忆配置ID
Returns:
@@ -1701,8 +1699,8 @@ class AppService:
from app.repositories.end_user_repository import EndUserRepository
repo = EndUserRepository(self.db)
updated_count = repo.batch_update_memory_config_id_by_workspace(
workspace_id=workspace_id,
updated_count = repo.batch_update_memory_config_id_by_app(
app_id=app_id,
memory_config_id=memory_config_id
)
@@ -1753,12 +1751,16 @@ class AppService:
miss_params = []
if agent_cfg.default_model_config_id is None:
miss_params.append("model config")
miss_params.append("模型配置")
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
miss_params.append("memory config")
miss_params.append("记忆配置")
if miss_params:
raise BusinessException(f"{', '.join(miss_params)} is required")
raise BusinessException(
f"应用发布失败:检测到以下必要配置尚未完成:{', '.join(miss_params)}。请返回应用编辑页面完成相关配置后再尝试发布。",
BizCode.CONFIG_MISSING,
context={"missing_params": miss_params},
)
config = {
"system_prompt": agent_cfg.system_prompt,
@@ -1877,8 +1879,8 @@ class AppService:
if memory_config_id:
app = self.db.query(App).filter(App.id == app_id).first()
if app:
updated_count = self._update_endusers_memory_config_by_workspace(
app.workspace_id, memory_config_id
updated_count = self._update_endusers_memory_config_by_app(
app_id, memory_config_id
)
logger.info(
f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, "
@@ -2014,7 +2016,7 @@ class AppService:
if memory_config_id:
updated_count = self._update_endusers_memory_config_by_workspace(app.workspace_id, memory_config_id)
updated_count = self._update_endusers_memory_config_by_app(app_id, memory_config_id)
logger.info(
f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, "
f"memory_config_id={memory_config_id}, updated_count={updated_count}"

View File

@@ -214,7 +214,7 @@ class ConversationService:
conversation.message_count += 1
if conversation.message_count == 1 and role == "user":
if conversation.message_count <= 2 and role == "user":
conversation.title = (
content[:50] + ("..." if len(content) > 50 else "")
)

View File

@@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig, ModelType
from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput, Citation
from app.schemas.model_schema import ModelInfo
@@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger
from app.services.model_service import ModelApiKeyService
from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService
from app.schemas import FileType
logger = get_business_logger()
@@ -449,15 +448,16 @@ class AgentRunService:
features_config: Dict[str, Any],
is_new_conversation: bool,
variables: Optional[Dict[str, Any]] = None
) -> Optional[str]:
) -> tuple[Any, Any]:
"""首轮对话时返回开场白文本(支持变量替换),否则返回 None"""
if not is_new_conversation:
return None
return None, None
opening = features_config.get("opening_statement", {})
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
return None
return None, None
statement = opening["statement"]
suggested_questions = opening["suggested_questions"]
# 如果有变量,进行替换(仅支持 {{var_name}} 格式)
if variables:
@@ -465,7 +465,7 @@ class AgentRunService:
placeholder = f"{{{{{var_name}}}}}"
statement = statement.replace(placeholder, str(var_value))
return statement
return statement, suggested_questions
@staticmethod
def _filter_citations(
@@ -599,13 +599,16 @@ class AgentRunService:
# 5. 处理会话ID创建或验证新会话时写入开场白
is_new_conversation = not conversation_id
opening = self._get_opening_statement(features_config, is_new_conversation, variables)
opening, suggested_questions = None, None
if not sub_agent:
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
conversation_id = await self._ensure_conversation(
conversation_id=conversation_id,
app_id=agent_config.app_id,
workspace_id=workspace_id,
user_id=user_id,
opening_statement=opening
opening_statement=opening,
suggested_questions=suggested_questions
)
model_info = ModelInfo(
@@ -657,11 +660,6 @@ class AgentRunService:
message=message,
history=history,
context=context,
end_user_id=user_id,
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag,
files=processed_files # 传递处理后的文件
)
@@ -845,14 +843,17 @@ class AgentRunService:
# 5. 处理会话ID创建或验证新会话时写入开场白
is_new_conversation = not conversation_id
opening = self._get_opening_statement(features_config, is_new_conversation, variables)
opening, suggested_questions = None, None
if not sub_agent:
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
conversation_id = await self._ensure_conversation(
conversation_id=conversation_id,
app_id=agent_config.app_id,
workspace_id=workspace_id,
user_id=user_id,
sub_agent=sub_agent,
opening_statement=opening
opening_statement=opening,
suggested_questions=suggested_questions
)
model_info = ModelInfo(
@@ -911,11 +912,6 @@ class AgentRunService:
message=message,
history=history,
context=context,
end_user_id=user_id,
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag,
files=processed_files
):
if isinstance(chunk, int):
@@ -1061,7 +1057,8 @@ class AgentRunService:
workspace_id: uuid.UUID,
user_id: Optional[str],
sub_agent: bool = False,
opening_statement: Optional[str] = None
opening_statement: Optional[str] = None,
suggested_questions: Optional[List[str]] = None
) -> str:
"""确保会话存在(创建或验证)
@@ -1072,6 +1069,7 @@ class AgentRunService:
user_id: 用户ID
sub_agent: 是否为子代理
opening_statement: 开场白(新会话时作为第一条消息写入)
suggested_questions: 预设问题列表
Returns:
str: 会话ID
@@ -1115,7 +1113,7 @@ class AgentRunService:
conversation_id=uuid.UUID(new_conv_id),
role="assistant",
content=opening_statement,
meta_data={}
meta_data={"suggested_questions": suggested_questions}
)
logger.debug(f"已保存开场白到会话 {new_conv_id}")

View File

@@ -37,6 +37,7 @@ from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write as write_neo4j
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.audit_logger import audit_logger
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -49,10 +50,6 @@ from app.services.memory_konwledges_server import (
)
from app.services.memory_perceptual_service import MemoryPerceptualService
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
logger = get_logger(__name__)
config_logger = get_config_logger()
@@ -68,24 +65,22 @@ class MemoryAgentService:
if str(messages) == 'success':
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作
if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=True,
duration=duration, details={"message_length": len(message)})
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=True,
duration=duration, details={"message_length": len(message)})
return context
else:
logger.warning(f"Write operation failed for group {end_user_id}")
# 记录失败的操作
if audit_logger:
audit_logger.log_operation(
operation="WRITE",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=f"写入失败: {messages[:100]}"
)
audit_logger.log_operation(
operation="WRITE",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=f"写入失败: {messages[:100]}"
)
raise ValueError(f"写入失败: {messages}")
@@ -338,10 +333,9 @@ class MemoryAgentService:
logger.error(error_msg)
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
@@ -401,10 +395,10 @@ class MemoryAgentService:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
async def read_memory(
@@ -469,10 +463,9 @@ class MemoryAgentService:
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
# 导入审计日志记录器
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
config_load_start = time.time()
try:
@@ -492,16 +485,15 @@ class MemoryAgentService:
logger.error(error_msg)
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=error_msg
)
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=error_msg
)
raise ValueError(error_msg)
@@ -633,15 +625,15 @@ class MemoryAgentService:
total_time = time.time() - start_time
logger.info(
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
end_user_id=end_user_id,
success=True,
duration=duration
)
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
end_user_id=end_user_id,
success=True,
duration=duration
)
return {
"answer": summary,
@@ -651,16 +643,16 @@ class MemoryAgentService:
# Ensure proper error handling and logging
error_msg = f"Read operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=error_msg
)
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=error_msg
)
raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:

View File

@@ -1,11 +1,12 @@
from sqlalchemy.orm import Session
from typing import List, Optional
from sqlalchemy import desc, nullslast, or_, and_, cast, String
from typing import List, Optional, Dict, Any
import uuid
from fastapi import HTTPException
from app.models.user_model import User
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.models.end_user_model import EndUser, EndUser as EndUserModel
from app.models.memory_increment_model import MemoryIncrement
from app.repositories import (
@@ -49,44 +50,40 @@ def get_current_workspace_type(
def get_workspace_end_users(
db: Session,
workspace_id: uuid.UUID,
db: Session,
workspace_id: uuid.UUID,
current_user: User
) -> List[EndUser]:
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
返回结果按 created_at 从新到旧排序NULL 值排在最后)
"""
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
try:
# 查询应用ORM
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
if not apps_orm:
business_logger.info("工作空间下没有应用")
return []
# 提取所有 app_id
# app_ids = [app.id for app in apps_orm]
# 批量查询所有 end_users一次查询而非循环查询
# 按 created_at 降序排序NULL 值排在最后id 作为次级排序键保证确定性
from app.models.end_user_model import EndUser as EndUserModel
from sqlalchemy import desc, nullslast
end_users_orm = db.query(EndUserModel).filter(
EndUserModel.workspace_id == workspace_id
).order_by(
nullslast(desc(EndUserModel.created_at)),
desc(EndUserModel.id)
).all()
# 转换为 Pydantic 模型(只在需要时转换)
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return end_users
except HTTPException:
raise
except Exception as e:
@@ -94,6 +91,85 @@ def get_workspace_end_users(
raise
def get_workspace_end_users_paginated(
db: Session,
workspace_id: uuid.UUID,
current_user: User,
page: int,
pagesize: int,
keyword: Optional[str] = None
) -> Dict[str, Any]:
"""获取工作空间的宿主列表(分页版本,支持模糊搜索)
返回结果按 created_at 从新到旧排序NULL 值排在最后)
支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段
Args:
db: 数据库会话
workspace_id: 工作空间ID
current_user: 当前用户
page: 页码从1开始
pagesize: 每页数量
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id
Returns:
dict: 包含 items宿主列表和 total总记录数的字典
"""
business_logger.info(f"获取工作空间宿主列表(分页): workspace_id={workspace_id}, keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}")
try:
# 构建基础查询
base_query = db.query(EndUserModel).filter(
EndUserModel.workspace_id == workspace_id
)
# 构建搜索条件过滤空字符串和None
keyword = keyword.strip() if keyword else None
if keyword:
keyword_pattern = f"%{keyword}%"
# other_name 匹配始终生效id 匹配仅对 other_name 为空的记录生效
base_query = base_query.filter(
or_(
EndUserModel.other_name.ilike(keyword_pattern),
and_(
or_(
EndUserModel.other_name.is_(None),
EndUserModel.other_name == "",
),
cast(EndUserModel.id, String).ilike(keyword_pattern),
),
)
)
business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_nameother_name 为空时匹配 id")
# 获取总记录数
total = base_query.count()
if total == 0:
business_logger.info("工作空间下没有宿主")
return {"items": [], "total": 0}
# 分页查询
# 按 created_at 降序排序NULL 值排在最后id 作为次级排序键保证确定性
end_users_orm = base_query.order_by(
nullslast(desc(EndUserModel.created_at)),
desc(EndUserModel.id)
).offset((page - 1) * pagesize).limit(pagesize).all()
# 转换为 Pydantic 模型
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total}")
return {"items": end_users, "total": total}
except HTTPException:
raise
except Exception as e:
business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_workspace_memory_increment(
db: Session,
workspace_id: uuid.UUID,
@@ -638,7 +714,24 @@ def get_rag_content(
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
continue
# 4. 返回结果
# 4. 将所有 page_content 拼接后按角色分割为对话列表
merged_text = "\n".join(page_contents)
conversations = []
if merged_text.strip():
import re
# 在任意位置匹配 "user:" 或 "assistant:",不限于行首
parts = re.split(r'(user|assistant):', merged_text)
# parts 结构: ['', 'user', ' content...', 'assistant', ' content...', ...]
i = 1
while i < len(parts) - 1:
role = parts[i].strip()
content = parts[i + 1].strip()
# 将 content 中的 \n 还原为真实换行
content = content.replace("\\n", "\n")
if role in ("user", "assistant") and content:
conversations.append({"role": role, "content": content})
i += 2
result = {
"page": {
"page": page,
@@ -646,10 +739,10 @@ def get_rag_content(
"total": global_total,
"hasnext": offset_end < global_total,
},
"items": page_contents
"items": conversations
}
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)}")
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(conversations)}对话")
return result
except Exception as e:

View File

@@ -204,30 +204,35 @@ class MemoryForgetService:
end_user_id: str,
forgetting_threshold: float,
min_days_since_access: int,
limit: int = 20
) -> list[Dict[str, Any]]:
page: Optional[int] = None,
pagesize: Optional[int] = None
) -> Dict[str, Any]:
"""
获取待遗忘节点列表
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。
Args:
connector: Neo4j 连接器
end_user_id: 组ID
forgetting_threshold: 遗忘阈值
min_days_since_access: 最小未访问天数
limit: 返回节点数量限制
page: 页码可选从1开始
pagesize: 每页数量(可选)
Returns:
list: 待遗忘节点列表
dict: 包含待遗忘节点列表和分页信息的字典
- items: 待遗忘节点列表
- page: 分页信息(分页时)
"""
from datetime import timedelta
# 计算最小访问时间ISO 8601 格式字符串,使用 UTC 时区)
min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
query = """
# 基础查询(用于获取总数)
count_query = """
MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
AND n.end_user_id = $end_user_id
@@ -235,10 +240,22 @@ class MemoryForgetService:
AND n.activation_value < $threshold
AND n.last_access_time IS NOT NULL
AND datetime(n.last_access_time) < datetime($min_access_time_str)
RETURN
RETURN count(n) as total
"""
# 数据查询
data_query = """
MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
AND n.end_user_id = $end_user_id
AND n.activation_value IS NOT NULL
AND n.activation_value < $threshold
AND n.last_access_time IS NOT NULL
AND datetime(n.last_access_time) < datetime($min_access_time_str)
RETURN
elementId(n) as node_id,
labels(n)[0] as node_type,
CASE
CASE
WHEN n:Statement THEN n.statement
WHEN n:ExtractedEntity THEN n.name
WHEN n:MemorySummary THEN n.content
@@ -247,18 +264,32 @@ class MemoryForgetService:
n.activation_value as activation_value,
n.last_access_time as last_access_time
ORDER BY n.activation_value ASC
LIMIT $limit
"""
# 如果启用分页,添加 SKIP 和 LIMIT
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
data_query += " SKIP $skip LIMIT $limit"
params = {
'end_user_id': end_user_id,
'threshold': forgetting_threshold,
'min_access_time_str': min_access_time_str,
'limit': limit
'min_access_time_str': min_access_time_str
}
results = await connector.execute_query(query, **params)
# 获取总数(分页时需要)
total = 0
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
count_results = await connector.execute_query(count_query, **params)
if count_results:
total = count_results[0]['total']
# 添加分页参数
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
params['skip'] = (page - 1) * pagesize
params['limit'] = pagesize
results = await connector.execute_query(data_query, **params)
pending_nodes = []
for result in results:
# 将节点类型标签转换为小写
@@ -267,7 +298,7 @@ class MemoryForgetService:
node_type_label = 'entity'
elif node_type_label == 'memorysummary':
node_type_label = 'summary'
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
last_access_time = result['last_access_time']
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
@@ -278,7 +309,7 @@ class MemoryForgetService:
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
else:
last_access_timestamp = 0
pending_nodes.append({
'node_id': str(result['node_id']),
'node_type': node_type_label,
@@ -286,8 +317,20 @@ class MemoryForgetService:
'activation_value': result['activation_value'],
'last_access_time': last_access_timestamp
})
return pending_nodes
# 构建返回结果
result: Dict[str, Any] = {'items': pending_nodes}
# 如果启用分页,添加分页信息
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
result['page'] = {
'page': page,
'pagesize': pagesize,
'total': total,
'hasnext': (page * pagesize) < total
}
return result
async def trigger_forgetting_cycle(
self,
@@ -636,7 +679,7 @@ class MemoryForgetService:
api_logger.error(f"获取历史趋势数据失败: {str(e)}")
# 失败时返回空列表,不影响主流程
# 获取待遗忘节点列表前20个满足遗忘条件的节点
# 获取待遗忘节点列表
pending_nodes = []
try:
if end_user_id:
@@ -652,8 +695,7 @@ class MemoryForgetService:
connector=connector,
end_user_id=end_user_id,
forgetting_threshold=forgetting_threshold,
min_days_since_access=int(min_days),
limit=20
min_days_since_access=int(min_days)
)
api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点")
@@ -661,24 +703,79 @@ class MemoryForgetService:
except Exception as e:
api_logger.error(f"获取待遗忘节点失败: {str(e)}")
# 失败时返回空列表,不影响主流程
# 构建统计信息
# 构建统计信息(不包含 pending_nodes已分离到独立接口
stats = {
'activation_metrics': activation_metrics,
'node_distribution': node_distribution,
'recent_trends': recent_trends,
'pending_nodes': pending_nodes,
'timestamp': int(datetime.now().timestamp() * 1000)
}
api_logger.info(
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}"
f"trend_days={len(recent_trends)}"
)
return stats
async def get_pending_nodes(
self,
db: Session,
end_user_id: str,
config_id: Optional[UUID] = None,
page: int = 1,
pagesize: int = 10
) -> Dict[str, Any]:
"""
获取待遗忘节点列表(独立分页接口)
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
Args:
db: 数据库会话
end_user_id: 组ID必填
config_id: 配置ID可选用于获取遗忘阈值
page: 页码从1开始默认1
pagesize: 每页数量默认10
Returns:
dict: 包含待遗忘节点列表和分页信息的字典
- items: 待遗忘节点列表
- page: 分页信息
"""
# 获取遗忘引擎组件
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
connector = forgetting_scheduler.connector
forgetting_threshold = config['forgetting_threshold']
# 验证 min_days_since_access 配置值
min_days = config.get('min_days_since_access')
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
api_logger.warning(
f"min_days_since_access 配置无效: {min_days}, 使用默认值 7"
)
min_days = 7
# 调用内部方法获取分页数据
pending_nodes_result = await self._get_pending_forgetting_nodes(
connector=connector,
end_user_id=end_user_id,
forgetting_threshold=forgetting_threshold,
min_days_since_access=int(min_days),
page=page,
pagesize=pagesize
)
api_logger.info(
f"成功获取待遗忘节点列表: end_user_id={end_user_id}, "
f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}"
)
return pending_nodes_result
async def get_forgetting_curve(
self,
db: Session,

View File

@@ -243,28 +243,9 @@ class MemoryPerceptualService:
memory_config: MemoryConfig,
file: FileInput
):
memories = self.repository.get_by_url(file.url)
if memories:
business_logger.info(f"Perceptual memory already exists: {file.url}")
if end_user_id not in [memory.end_user_id for memory in memories]:
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0]
memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path,
file_name=memory_cache.file_name,
file_ext=memory_cache.file_ext,
summary=memory_cache.summary,
meta_data=memory_cache.meta_data
)
self.db.commit()
return memory
else:
for memory in memories:
if memory.end_user_id == uuid.UUID(end_user_id):
return memory
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
if model_config is None or llm is None:
return None
multimodel_service = MultimodalService(self.db, ModelInfo(
model_name=model_config.model_name,
provider=model_config.provider,
@@ -286,15 +267,20 @@ class MemoryPerceptualService:
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
except FileNotFoundError as e:
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
return None
messages = [
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
{"role": RoleType.USER.value, "content": [
{"type": "text", "text": "Summarize the following file"}, file_message
]}
]
result = await llm.ainvoke(messages)
try:
result = await llm.ainvoke(messages)
except Exception as e:
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
return None
content = result.content
final_output = ""
if isinstance(content, list):

View File

@@ -695,6 +695,37 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
return result
async def search_all_batch(end_user_ids: List[str]) -> Dict[str, int]:
"""批量查询多个用户的记忆数量简化版本只返回total
Args:
end_user_ids: 用户ID列表
Returns:
Dict[str, int]: 以user_id为key的记忆数量字典
格式: {"user_id": total_count}
"""
if not end_user_ids:
return {}
result = await _neo4j_connector.execute_query(
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
end_user_ids=end_user_ids,
)
# 转换结果为字典格式,字典格式在查询中无需遍历结果集,直接返回
data = {}
for row in result:
data[row["user_id"]] = row["total"]
# 为没有数据的用户填充默认值,转换字典格式还为无数据填充默认值
for user_id in end_user_ids:
if user_id not in data:
data[user_id] = 0
return data
async def analytics_hot_memory_tags(
db: Session,
current_user: User,

View File

@@ -69,7 +69,8 @@ class ModelConfigService:
return items
@staticmethod
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
def get_model_by_name(db: Session, name: str, provider: str | None = None,
tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""根据名称获取模型配置"""
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
if not model:
@@ -77,21 +78,22 @@ class ModelConfigService:
return model
@staticmethod
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[
ModelConfig]:
"""按名称模糊匹配获取模型配置列表"""
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
@staticmethod
async def validate_model_config(
db: Session,
*,
model_name: str,
provider: str,
api_key: str,
api_base: Optional[str] = None,
model_type: str = "llm",
test_message: str = "Hello",
is_omni: bool = False
db: Session,
*,
model_name: str,
provider: str,
api_key: str,
api_base: Optional[str] = None,
model_type: str = "llm",
test_message: str = "Hello",
is_omni: bool = False
) -> Dict[str, Any]:
"""验证模型配置是否有效
@@ -158,13 +160,13 @@ class ModelConfigService:
# 统一使用 RedBearEmbeddings自动支持火山引擎多模态
embedding = RedBearEmbeddings(model_config)
test_texts = [test_message, "测试文本"]
# 火山引擎使用 embed_batch其他使用 embed_documents
if provider.lower() == "volcano":
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
else:
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
elapsed_time = time.time() - start_time
return {
@@ -200,11 +202,11 @@ class ModelConfigService:
},
"error": None
}
elif model_type_lower == "image":
# 图片生成模型验证
from app.core.models.generation import RedBearImageGenerator
generator = RedBearImageGenerator(model_config)
result = await generator.agenerate(
prompt="a cute panda",
@@ -212,7 +214,7 @@ class ModelConfigService:
)
elapsed_time = time.time() - start_time
logger.info(f"成功生成图片,结果: {result}")
return {
"valid": True,
"message": "图片生成模型配置验证成功",
@@ -224,21 +226,21 @@ class ModelConfigService:
},
"error": None
}
elif model_type_lower == "video":
# 视频生成模型验证
from app.core.models.generation import RedBearVideoGenerator
generator = RedBearVideoGenerator(model_config)
result = await generator.agenerate(
prompt="a cute panda playing in bamboo forest",
duration=5
)
elapsed_time = time.time() - start_time
# 视频生成是异步任务返回任务ID
task_id = result.get("task_id") if isinstance(result, dict) else None
return {
"valid": True,
"message": "视频生成模型配置验证成功",
@@ -265,7 +267,6 @@ class ModelConfigService:
# 提取详细的错误信息
error_message = str(e)
error_type = type(e).__name__
print("=========error_message:",error_message.lower())
# 特殊处理常见的错误类型
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
# 区域/国家限制(适用于所有提供商)
@@ -354,14 +355,16 @@ class ModelConfigService:
return model
@staticmethod
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate,
tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""更新模型配置"""
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
@@ -370,25 +373,27 @@ class ModelConfigService:
return model
@staticmethod
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate,
tenant_id: uuid.UUID) -> ModelConfig:
"""创建组合模型"""
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
# 验证所有 API Key 存在且类型匹配
for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
# 检查 API Key 关联的模型配置类型
for model_config in api_key.model_configs:
# chat 和 llm 类型可以兼容
compatible_types = {ModelType.LLM, ModelType.CHAT}
config_type = model_config.type
request_type = model_data.type
if not (config_type == request_type or
if not (config_type == request_type or
(config_type in compatible_types and request_type in compatible_types)):
raise BusinessException(
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
@@ -399,7 +404,7 @@ class ModelConfigService:
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
# BizCode.INVALID_PARAMETER
# )
# 创建组合模型
model_config_data = {
"tenant_id": tenant_id,
@@ -418,49 +423,51 @@ class ModelConfigService:
model = ModelConfigRepository.create(db, model_config_data)
db.flush()
# 关联 API Keys
for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if api_key:
model.api_keys.append(api_key)
db.commit()
db.refresh(model)
return model
@staticmethod
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate,
tenant_id: uuid.UUID) -> ModelConfig:
"""更新组合模型"""
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
if not existing_model.is_composite:
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
# 验证所有 API Key 存在且类型匹配
for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
for model_config in api_key.model_configs:
compatible_types = {ModelType.LLM, ModelType.CHAT}
config_type = model_config.type
request_type = existing_model.type
if not (config_type == request_type or
if not (config_type == request_type or
(config_type in compatible_types and request_type in compatible_types)):
raise BusinessException(
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
BizCode.INVALID_PARAMETER
)
# 更新基本信息
existing_model.name = model_data.name
# existing_model.type = model_data.type
@@ -471,14 +478,14 @@ class ModelConfigService:
existing_model.is_public = model_data.is_public
if "load_balance_strategy" in model_data.model_fields_set:
existing_model.load_balance_strategy = model_data.load_balance_strategy
# 更新 API Keys 关联
existing_model.api_keys.clear()
for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if api_key:
existing_model.api_keys.append(api_key)
db.commit()
db.refresh(existing_model)
return existing_model
@@ -532,7 +539,7 @@ class ModelApiKeyService:
"""根据provider为多个ModelConfig创建API Key"""
created_keys = []
failed_models = [] # 记录验证失败的模型
for model_config_id in data.model_config_ids:
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config:
@@ -540,10 +547,10 @@ class ModelApiKeyService:
data.is_omni = model_config.is_omni
data.capability = model_config.capability
# 从ModelBase获取model_name
model_name = model_config.model_base.name if model_config.model_base else model_config.name
# 检查是否存在API Key包括软删除需要考虑tenant_id
existing_key = db.query(ModelApiKey).join(
ModelApiKey.model_configs
@@ -553,7 +560,7 @@ class ModelApiKeyService:
ModelApiKey.model_name == model_name,
ModelConfig.tenant_id == model_config.tenant_id
).first()
if existing_key:
# 如果已存在,重新激活并更新
if existing_key.is_active:
@@ -566,14 +573,14 @@ class ModelApiKeyService:
existing_key.model_name = model_name
existing_key.capability = data.capability
existing_key.is_omni = data.is_omni
# 检查是否已关联该模型配置
if model_config not in existing_key.model_configs:
existing_key.model_configs.append(model_config)
created_keys.append(existing_key)
continue
# 验证配置
validation_result = await ModelConfigService.validate_model_config(
db=db,
@@ -589,7 +596,7 @@ class ModelApiKeyService:
# 记录验证失败的模型,但不抛出异常
failed_models.append(model_name)
continue
# 创建API Key
api_key_data = ModelApiKeyCreate(
model_config_ids=[model_config_id],
@@ -606,12 +613,12 @@ class ModelApiKeyService:
)
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
created_keys.append(api_key_obj)
if created_keys:
db.commit()
for key in created_keys:
db.refresh(key)
return created_keys, failed_models
@staticmethod
@@ -626,7 +633,7 @@ class ModelApiKeyService:
api_key_data.is_omni = model_config.is_omni
if api_key_data.capability is None:
api_key_data.capability = model_config.capability
# 检查API Key是否已存在(包括软删除)需要考虑tenant_id
existing_key = db.query(ModelApiKey).join(
ModelApiKey.model_configs
@@ -650,15 +657,15 @@ class ModelApiKeyService:
existing_key.model_name = api_key_data.model_name
existing_key.capability = api_key_data.capability
existing_key.is_omni = api_key_data.is_omni
# 检查是否已关联该模型配置
if model_config not in existing_key.model_configs:
existing_key.model_configs.append(model_config)
db.commit()
db.refresh(existing_key)
return existing_key
# 验证配置
validation_result = await ModelConfigService.validate_model_config(
db=db,
@@ -691,7 +698,7 @@ class ModelApiKeyService:
# 获取关联的模型配置以获取模型类型
if existing_api_key.model_configs:
model_config = existing_api_key.model_configs[0]
validation_result = await ModelConfigService.validate_model_config(
db=db,
model_name=api_key_data.model_name or existing_api_key.model_name,
@@ -729,15 +736,15 @@ class ModelApiKeyService:
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config:
return None
api_keys = [key for key in model_config.api_keys if key.is_active]
if not api_keys:
return None
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
# 否则返回第一个
return api_keys[0]
@@ -760,20 +767,19 @@ class ModelApiKeyService:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
class ModelBaseService:
"""基础模型服务"""
@staticmethod
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
models = ModelBaseRepository.get_list(db, query)
provider_groups = {}
for m in models:
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
if tenant_id:
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
provider = m.provider
if provider not in provider_groups:
provider_groups[provider] = {
@@ -781,7 +787,7 @@ class ModelBaseService:
"models": []
}
provider_groups[provider]["models"].append(model_dict)
return list(provider_groups.values())
@staticmethod
@@ -823,10 +829,10 @@ class ModelBaseService:
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
if not model_base:
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
model_config_data = {
"model_id": model_base_id,
"tenant_id": tenant_id,

View File

@@ -12,6 +12,9 @@ import base64
import csv
import io
import json
import re
import olefile
import struct
import zipfile
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
@@ -438,13 +441,13 @@ class MultimodalService:
if file.transfer_method == TransferMethod.REMOTE_URL:
return True, {
"type": "text",
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
"text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
}
else:
# 本地文件,提取文本内容
server_url = settings.FILE_LOCAL_SERVER_URL
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
text = await self._extract_document_text(file)
text = await self.extract_document_text(file)
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id
).first()
@@ -542,7 +545,7 @@ class MultimodalService:
server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}"
async def _extract_document_text(self, file: FileInput) -> str:
async def extract_document_text(self, file: FileInput) -> str:
"""
提取文档文本内容
@@ -602,31 +605,75 @@ class MultimodalService:
try:
word_file = io.BytesIO(file_content)
doc = Document(word_file)
return '\n'.join(p.text for p in doc.paragraphs)
text_lines = []
for p in doc.paragraphs:
text = p.text.strip()
if text:
text_lines.append(text)
for table in doc.tables:
for row in table.rows:
for cell in row.cells:
text = cell.text.strip()
if text:
text_lines.append(text)
full_text = "\n".join(text_lines)
return full_text.strip() or "[docx 文件无文本内容]"
except Exception as e:
logger.error(f"提取 docx 文本失败: {e}")
logger.error(f"提取 docx 文本失败: {str(e)}", exc_info=True)
return f"[docx 提取失败: {str(e)}]"
# 旧版 .docOLE2 格式)
# 旧版 .docOLE2/CFB 格式),按 Word Binary Format 规范解析 piece table
try:
import olefile
ole = olefile.OleFileIO(io.BytesIO(file_content))
if not ole.exists('WordDocument'):
return "[doc 提取失败: 未找到 WordDocument 流]"
# 读取 WordDocument 流,提取可见 ASCII/Unicode 文本
stream = ole.openstream('WordDocument').read()
# Word Binary Format: 文本在流中以 UTF-16-LE 编码存储
# 简单提取:过滤出可打印字符段
try:
text = stream.decode('utf-16-le', errors='ignore')
except Exception:
text = stream.decode('latin-1', errors='ignore')
# 过滤控制字符,保留可打印内容
import re
text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
text = re.sub(r' +', ' ', text).strip()
word_stream = ole.openstream('WordDocument').read()
# FIB offset 0xA bit9 决定使用 0Table 还是 1Table
fib_flags = struct.unpack_from('<H', word_stream, 0xA)[0]
table_name = '1Table' if (fib_flags & 0x0200) else '0Table'
table_stream = ole.openstream(table_name).read()
# 从 FIB 读取 fcClx/lcbClx 定位 piece table
fc_clx, lcb_clx = struct.unpack_from("<II", word_stream, 0x1A2)
clx = table_stream[fc_clx: fc_clx + lcb_clx]
# 解析 CLX找到 PlcPcdpiece table
i, plc_pcd = 0, None
while i < len(clx):
clxt = clx[i]
if clxt == 0x01:
i += 3 + struct.unpack_from('<H', clx, i + 1)[0]
elif clxt == 0x02:
cb = struct.unpack_from('<I', clx, i + 1)[0]
plc_pcd = clx[i + 5: i + 5 + cb]
break
else:
break
if plc_pcd is None:
raise ValueError("PlcPcd not found")
# PlcPcd: (n+1) 个 CP4字节+ n 个 PCD8字节
n_pieces = (len(plc_pcd) - 4) // 12
cp_array = [struct.unpack_from('<I', plc_pcd, k * 4)[0] for k in range(n_pieces + 1)]
parts = []
for k in range(n_pieces):
fc_value = struct.unpack_from('<I', plc_pcd, (n_pieces + 1) * 4 + k * 8 + 2)[0]
is_ansi = bool(fc_value & 0x40000000)
fc = fc_value & 0x3FFFFFFF
char_count = cp_array[k + 1] - cp_array[k]
if is_ansi:
parts.append(word_stream[fc: fc + char_count].decode('cp1252', errors='replace'))
else:
parts.append(word_stream[fc: fc + char_count * 2].decode('utf-16-le', errors='replace'))
ole.close()
return text
result = re.sub(r'[\x00-\x1f\x7f]', '', ''.join(parts))
return result.strip()
except Exception as e:
logger.error(f"提取 doc 文本失败: {e}")
return f"[doc 提取失败: {str(e)}]"

View File

@@ -1,26 +1,24 @@
"""基于分享链接的聊天服务"""
import uuid
import time
import asyncio
import json
import time
import uuid
from typing import Optional, Dict, Any, AsyncGenerator
from deprecated import deprecated
from sqlalchemy.orm import Session
from app.repositories.model_repository import ModelApiKeyRepository
from app.services.memory_konwledges_server import write_rag
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.logging_config import get_business_logger
from app.models import MultiAgentConfig
from app.models import ReleaseShare, AppRelease, Conversation
from app.repositories import knowledge_repository
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService
from app.services.release_share_service import ReleaseShareService
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.services.multi_agent_service import MultiAgentService
from app.models import MultiAgentConfig
from app.repositories import knowledge_repository
import json
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.services.release_share_service import ReleaseShareService
logger = get_business_logger()
@@ -118,6 +116,7 @@ class SharedChatService:
return conversation
@deprecated("Use the chat method under app_chat_service instead.")
async def chat(
self,
share_token: str,
@@ -136,10 +135,7 @@ class SharedChatService:
config_id = actual_config_id
from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.model_parameter_merger import ModelParameterMerger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
start_time = time.time()
actual_config_id = None
@@ -273,11 +269,6 @@ class SharedChatService:
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
)
# 保存消息
@@ -324,6 +315,7 @@ class SharedChatService:
"elapsed_time": elapsed_time
}
@deprecated("Use the chat method under app_chat_service instead.")
async def chat_stream(
self,
share_token: str,
@@ -341,8 +333,6 @@ class SharedChatService:
from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
import json
start_time = time.time()
@@ -486,11 +476,6 @@ class SharedChatService:
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
):
if isinstance(chunk, int):
total_tokens = chunk
@@ -585,6 +570,7 @@ class SharedChatService:
return conversations, total
@deprecated("Use the chat method under app_chat_service instead.")
async def multi_agent_chat(
self,
share_token: str,
@@ -680,6 +666,7 @@ class SharedChatService:
"elapsed_time": elapsed_time
}
@deprecated("Use the chat method under app_chat_service instead.")
async def multi_agent_chat_stream(
self,
share_token: str,

View File

@@ -138,7 +138,7 @@ class TenantService:
except Exception as e:
business_logger.error(f"删除租户失败: {str(e)}")
raise BusinessException(f"删除租户失败: {str(e)}", code=BizCode.DB_ERROR)
raise BusinessException(f"删除租户失败{str(e)}", code=BizCode.DB_ERROR)
# 租户用户管理
def get_tenant_users(
@@ -147,6 +147,7 @@ class TenantService:
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
is_superuser: Optional[bool] = None,
search: Optional[str] = None
) -> List[UserModel]:
"""获取租户下的用户列表"""
@@ -155,6 +156,7 @@ class TenantService:
skip=skip,
limit=limit,
is_active=is_active,
is_superuser=is_superuser,
search=search
)
@@ -162,12 +164,14 @@ class TenantService:
self,
tenant_id: uuid.UUID,
is_active: Optional[bool] = None,
is_superuser: Optional[bool] = None,
search: Optional[str] = None
) -> int:
"""统计租户下的用户数量"""
return self.user_repo.count_users_by_tenant(
tenant_id=tenant_id,
is_active=is_active,
is_superuser=is_superuser,
search=search
)

View File

@@ -472,6 +472,21 @@ class UserMemoryService:
# 定义允许更新的字段白名单
allowed_fields = {'other_name', 'aliases', 'meta_data'}
# 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中
_user_placeholder_names = {'用户', '', 'User', 'I'}
# 过滤 other_name不允许设置为占位名称
if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names:
logger.warning(f"拒绝将占位名称 '{update_data['other_name']}' 设置为 other_name")
del update_data['other_name']
# 过滤 aliases移除占位名称和非字符串值
if 'aliases' in update_data and update_data['aliases']:
update_data['aliases'] = [
a for a in update_data['aliases']
if isinstance(a, str) and a.strip() and a.strip() not in _user_placeholder_names
]
# 检查是否更新了 aliases 字段
aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases

View File

@@ -561,6 +561,24 @@ class WorkflowService:
storage_type = 'neo4j'
return storage_type, user_rag_memory_id
def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None:
executions = self.execution_repo.get_by_conversation_id(
conversation_id=conversation_id,
status="completed",
limit_count=1
)
if executions:
last_state = executions[0].output_data
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
# input_data["conv"] = conv_vars
# input_data["conv_messages"] = last_state.get("messages") or []
conv_messages = last_state.get("messages") or []
return conv_vars, conv_messages
return None
# ==================== 工作流执行 ====================
async def run(
@@ -634,18 +652,11 @@ class WorkflowService:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
for exec_res in executions:
if exec_res.status == "completed":
last_state = exec_res.output_data
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
history = self._get_history_info(conversation_id_uuid)
if history:
conv_vars, conv_messages = history
input_data["conv"] = conv_vars
input_data["conv_messages"] = conv_messages
init_message_length = len(input_data.get("conv_messages", []))
result = await execute_workflow(
@@ -807,17 +818,11 @@ class WorkflowService:
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files
self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
for exec_res in executions:
if exec_res.status == "completed":
last_state = exec_res.output_data
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
history = self._get_history_info(conversation_id_uuid)
if history:
conv_vars, conv_messages = history
input_data["conv"] = conv_vars
input_data["conv_messages"] = conv_messages
init_message_length = len(input_data.get("conv_messages", []))
message_id = uuid.uuid4()
async for event in execute_workflow_stream(