Merge branch 'refs/heads/develop' into feature/agent-tool_xjn
# Conflicts: # api/app/core/agent/langchain_agent.py # api/app/core/tools/mcp/client.py
This commit is contained in:
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
||||
from app.models import WorkflowConfig
|
||||
@@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.schemas import FileType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -43,18 +44,17 @@ class AppChatService:
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: AgentConfig,
|
||||
user_id: Optional[str] = None,
|
||||
files: list[FileInput],
|
||||
user_id: str,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None
|
||||
workspace_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
|
||||
# 应用 features 配置
|
||||
features_config: dict = config.features or {}
|
||||
@@ -93,7 +93,8 @@ class AppChatService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval,
|
||||
user_id)
|
||||
tools.extend(kb_tools)
|
||||
memory_flag = False
|
||||
if memory:
|
||||
@@ -140,13 +141,13 @@ class AppChatService:
|
||||
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
|
||||
is_new_conversation = len(history) == 0
|
||||
if is_new_conversation:
|
||||
opening = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||
opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||
if opening:
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=opening,
|
||||
meta_data={}
|
||||
meta_data={"suggested_questions": suggested_questions}
|
||||
)
|
||||
# 重新加载历史(包含刚写入的开场白)
|
||||
history = await self.conversation_service.get_conversation_history(
|
||||
@@ -168,11 +169,6 @@ class AppChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
@@ -229,6 +225,21 @@ class AppChatService:
|
||||
# 保存消息
|
||||
if audio_url:
|
||||
assistant_meta["audio_url"] = audio_url
|
||||
if memory_flag:
|
||||
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||
memory_config_id: str = connected_config.get("memory_config_id")
|
||||
messages = [
|
||||
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
|
||||
{"role": "assistant", "content": result["content"]}
|
||||
]
|
||||
if memory_config_id:
|
||||
await write_long_term(
|
||||
storage_type,
|
||||
user_id,
|
||||
messages,
|
||||
user_rag_memory_id,
|
||||
memory_config_id
|
||||
)
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -264,20 +275,19 @@ class AppChatService:
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: AgentConfig,
|
||||
files: list[FileInput],
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None
|
||||
workspace_id: Optional[str] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
message_id = uuid.uuid4()
|
||||
|
||||
# 应用 features 配置
|
||||
@@ -319,7 +329,8 @@ class AppChatService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
|
||||
config.knowledge_retrieval, user_id)
|
||||
tools.extend(kb_tools)
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
@@ -367,13 +378,13 @@ class AppChatService:
|
||||
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
|
||||
is_new_conversation = len(history) == 0
|
||||
if is_new_conversation:
|
||||
opening = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||
opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||
if opening:
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=opening,
|
||||
meta_data={}
|
||||
meta_data={"suggested_questions": suggested_questions}
|
||||
)
|
||||
# 重新加载历史(包含刚写入的开场白)
|
||||
history = await self.conversation_service.get_conversation_history(
|
||||
@@ -411,11 +422,6 @@ class AppChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
@@ -459,7 +465,7 @@ class AppChatService:
|
||||
|
||||
# 保存消息
|
||||
human_meta = {
|
||||
"files":[],
|
||||
"files": [],
|
||||
"history_files": {}
|
||||
}
|
||||
assistant_meta = {
|
||||
@@ -484,6 +490,22 @@ class AppChatService:
|
||||
|
||||
if stream_audio_url:
|
||||
assistant_meta["audio_url"] = stream_audio_url
|
||||
|
||||
if memory_flag:
|
||||
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||
memory_config_id: str = connected_config.get("memory_config_id")
|
||||
messages = [
|
||||
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
|
||||
{"role": "assistant", "content": full_content}
|
||||
]
|
||||
if memory_config_id:
|
||||
await write_long_term(
|
||||
storage_type,
|
||||
user_id,
|
||||
messages,
|
||||
user_rag_memory_id,
|
||||
memory_config_id
|
||||
)
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -618,7 +640,6 @@ class AppChatService:
|
||||
# 2. 创建编排器
|
||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||
|
||||
|
||||
# 3. 流式执行任务
|
||||
async for event in orchestrator.execute_stream(
|
||||
message=message,
|
||||
|
||||
128
api/app/services/app_log_service.py
Normal file
128
api/app/services/app_log_service.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""应用日志服务层"""
|
||||
import uuid
|
||||
from typing import Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models.conversation_model import Conversation, Message
|
||||
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class AppLogService:
    """Service layer for browsing the conversation logs of an application."""

    def __init__(self, db: Session):
        """Keep the session and build the conversation/message repositories on it."""
        self.db = db
        self.conversation_repository = ConversationRepository(db)
        self.message_repository = MessageRepository(db)

    def list_conversations(
        self,
        app_id: uuid.UUID,
        workspace_id: uuid.UUID,
        page: int = 1,
        pagesize: int = 20,
        is_draft: Optional[bool] = None,
    ) -> Tuple[list[Conversation], int]:
        """
        Return one page of an application's logged conversations.

        The query itself is delegated to
        ``ConversationRepository.list_app_conversations``; this method only
        adds logging around it.

        Args:
            app_id: Application ID.
            workspace_id: Workspace ID.
            page: Page number (documented as starting at 1).
            pagesize: Number of conversations per page.
            is_draft: Draft filter; ``None`` is documented as "do not filter".

        Returns:
            Tuple[list[Conversation], int]: ``(conversations, total)`` exactly
            as produced by the repository.
        """
        request_fields = {
            "app_id": str(app_id),
            "workspace_id": str(workspace_id),
            "page": page,
            "pagesize": pagesize,
            "is_draft": is_draft,
        }
        logger.info("查询应用日志会话列表", extra=request_fields)

        # The repository owns filtering and pagination.
        repository = self.conversation_repository
        conversations, total = repository.list_app_conversations(
            app_id=app_id,
            workspace_id=workspace_id,
            is_draft=is_draft,
            page=page,
            pagesize=pagesize,
        )

        result_fields = {
            "app_id": str(app_id),
            "total": total,
            "returned": len(conversations),
        }
        logger.info("查询应用日志会话列表成功", extra=result_fields)

        return conversations, total

    def get_conversation_detail(
        self,
        app_id: uuid.UUID,
        conversation_id: uuid.UUID,
        workspace_id: uuid.UUID
    ) -> Conversation:
        """
        Return a single conversation with its messages attached.

        Args:
            app_id: Application ID.
            conversation_id: Conversation ID.
            workspace_id: Workspace ID.

        Returns:
            Conversation: The conversation object whose ``messages`` attribute
            has been set to the list loaded from the message repository.

        Raises:
            ResourceNotFoundException: Documented for a missing conversation.
                No such check is made in this method, so it presumably comes
                from ``get_conversation_for_app_log`` -- verify against the
                repository.
        """
        logger.info(
            "查询应用日志会话详情",
            extra={
                "app_id": str(app_id),
                "conversation_id": str(conversation_id),
                "workspace_id": str(workspace_id),
            },
        )

        # Look the conversation up first, scoped by app and workspace.
        conversation = self.conversation_repository.get_conversation_for_app_log(
            conversation_id=conversation_id,
            app_id=app_id,
            workspace_id=workspace_id,
        )

        # Load the messages (expected in chronological order -- ordering is
        # decided by the repository) and hang them on the conversation object.
        messages = self.message_repository.get_messages_by_conversation(
            conversation_id=conversation_id
        )
        conversation.messages = messages

        success_fields = {
            "app_id": str(app_id),
            "conversation_id": str(conversation_id),
            "message_count": len(messages),
        }
        logger.info("查询应用日志会话详情成功", extra=success_fields)

        return conversation
|
||||
@@ -1084,7 +1084,6 @@ class AppService:
|
||||
if not exists:
|
||||
cleaned["memory_config_id"] = None
|
||||
cleaned.pop("memory_content", None)
|
||||
cleaned["enabled"] = False
|
||||
return cleaned
|
||||
|
||||
exists = self.db.query(
|
||||
@@ -1096,7 +1095,6 @@ class AppService:
|
||||
if not exists:
|
||||
cleaned["memory_config_id"] = None
|
||||
cleaned.pop("memory_content", None)
|
||||
cleaned["enabled"] = False
|
||||
|
||||
return cleaned
|
||||
|
||||
@@ -1684,15 +1682,15 @@ class AppService:
|
||||
|
||||
return config.config_id
|
||||
|
||||
def _update_endusers_memory_config_by_workspace(
|
||||
def _update_endusers_memory_config_by_app(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
app_id: uuid.UUID,
|
||||
memory_config_id: uuid.UUID
|
||||
) -> int:
|
||||
"""批量更新应用下所有终端用户的 memory_config_id
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
app_id: 应用ID
|
||||
memory_config_id: 新的记忆配置ID
|
||||
|
||||
Returns:
|
||||
@@ -1701,8 +1699,8 @@ class AppService:
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
repo = EndUserRepository(self.db)
|
||||
updated_count = repo.batch_update_memory_config_id_by_workspace(
|
||||
workspace_id=workspace_id,
|
||||
updated_count = repo.batch_update_memory_config_id_by_app(
|
||||
app_id=app_id,
|
||||
memory_config_id=memory_config_id
|
||||
)
|
||||
|
||||
@@ -1753,12 +1751,16 @@ class AppService:
|
||||
|
||||
miss_params = []
|
||||
if agent_cfg.default_model_config_id is None:
|
||||
miss_params.append("model config")
|
||||
miss_params.append("模型配置")
|
||||
|
||||
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
|
||||
miss_params.append("memory config")
|
||||
miss_params.append("记忆配置")
|
||||
if miss_params:
|
||||
raise BusinessException(f"{', '.join(miss_params)} is required")
|
||||
raise BusinessException(
|
||||
f"应用发布失败:检测到以下必要配置尚未完成:{', '.join(miss_params)}。请返回应用编辑页面完成相关配置后再尝试发布。",
|
||||
BizCode.CONFIG_MISSING,
|
||||
context={"missing_params": miss_params},
|
||||
)
|
||||
|
||||
config = {
|
||||
"system_prompt": agent_cfg.system_prompt,
|
||||
@@ -1877,8 +1879,8 @@ class AppService:
|
||||
if memory_config_id:
|
||||
app = self.db.query(App).filter(App.id == app_id).first()
|
||||
if app:
|
||||
updated_count = self._update_endusers_memory_config_by_workspace(
|
||||
app.workspace_id, memory_config_id
|
||||
updated_count = self._update_endusers_memory_config_by_app(
|
||||
app_id, memory_config_id
|
||||
)
|
||||
logger.info(
|
||||
f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, "
|
||||
@@ -2014,7 +2016,7 @@ class AppService:
|
||||
|
||||
if memory_config_id:
|
||||
|
||||
updated_count = self._update_endusers_memory_config_by_workspace(app.workspace_id, memory_config_id)
|
||||
updated_count = self._update_endusers_memory_config_by_app(app_id, memory_config_id)
|
||||
logger.info(
|
||||
f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, "
|
||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||
|
||||
@@ -214,7 +214,7 @@ class ConversationService:
|
||||
|
||||
conversation.message_count += 1
|
||||
|
||||
if conversation.message_count == 1 and role == "user":
|
||||
if conversation.message_count <= 2 and role == "user":
|
||||
conversation.title = (
|
||||
content[:50] + ("..." if len(content) > 50 else "")
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.models import AgentConfig, ModelConfig, ModelType
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput, Citation
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
@@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.tool_service import ToolService
|
||||
from app.schemas import FileType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -449,15 +448,16 @@ class AgentRunService:
|
||||
features_config: Dict[str, Any],
|
||||
is_new_conversation: bool,
|
||||
variables: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[str]:
|
||||
) -> tuple[Any, Any]:
|
||||
"""首轮对话时返回开场白文本(支持变量替换),否则返回 None"""
|
||||
if not is_new_conversation:
|
||||
return None
|
||||
return None, None
|
||||
opening = features_config.get("opening_statement", {})
|
||||
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
|
||||
return None
|
||||
return None, None
|
||||
|
||||
statement = opening["statement"]
|
||||
suggested_questions = opening["suggested_questions"]
|
||||
|
||||
# 如果有变量,进行替换(仅支持 {{var_name}} 格式)
|
||||
if variables:
|
||||
@@ -465,7 +465,7 @@ class AgentRunService:
|
||||
placeholder = f"{{{{{var_name}}}}}"
|
||||
statement = statement.replace(placeholder, str(var_value))
|
||||
|
||||
return statement
|
||||
return statement, suggested_questions
|
||||
|
||||
@staticmethod
|
||||
def _filter_citations(
|
||||
@@ -599,13 +599,16 @@ class AgentRunService:
|
||||
|
||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||
is_new_conversation = not conversation_id
|
||||
opening = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||
opening, suggested_questions = None, None
|
||||
if not sub_agent:
|
||||
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||
conversation_id = await self._ensure_conversation(
|
||||
conversation_id=conversation_id,
|
||||
app_id=agent_config.app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
opening_statement=opening
|
||||
opening_statement=opening,
|
||||
suggested_questions=suggested_questions
|
||||
)
|
||||
|
||||
model_info = ModelInfo(
|
||||
@@ -657,11 +660,6 @@ class AgentRunService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=context,
|
||||
end_user_id=user_id,
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
@@ -845,14 +843,17 @@ class AgentRunService:
|
||||
|
||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||
is_new_conversation = not conversation_id
|
||||
opening = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||
opening, suggested_questions = None, None
|
||||
if not sub_agent:
|
||||
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||
conversation_id = await self._ensure_conversation(
|
||||
conversation_id=conversation_id,
|
||||
app_id=agent_config.app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
sub_agent=sub_agent,
|
||||
opening_statement=opening
|
||||
opening_statement=opening,
|
||||
suggested_questions=suggested_questions
|
||||
)
|
||||
|
||||
model_info = ModelInfo(
|
||||
@@ -911,11 +912,6 @@ class AgentRunService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=context,
|
||||
end_user_id=user_id,
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
@@ -1061,7 +1057,8 @@ class AgentRunService:
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str],
|
||||
sub_agent: bool = False,
|
||||
opening_statement: Optional[str] = None
|
||||
opening_statement: Optional[str] = None,
|
||||
suggested_questions: Optional[List[str]] = None
|
||||
) -> str:
|
||||
"""确保会话存在(创建或验证)
|
||||
|
||||
@@ -1072,6 +1069,7 @@ class AgentRunService:
|
||||
user_id: 用户ID
|
||||
sub_agent: 是否为子代理
|
||||
opening_statement: 开场白(新会话时作为第一条消息写入)
|
||||
suggested_questions: 预设问题列表
|
||||
|
||||
Returns:
|
||||
str: 会话ID
|
||||
@@ -1115,7 +1113,7 @@ class AgentRunService:
|
||||
conversation_id=uuid.UUID(new_conv_id),
|
||||
role="assistant",
|
||||
content=opening_statement,
|
||||
meta_data={}
|
||||
meta_data={"suggested_questions": suggested_questions}
|
||||
)
|
||||
logger.debug(f"已保存开场白到会话 {new_conv_id}")
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -49,10 +50,6 @@ from app.services.memory_konwledges_server import (
|
||||
)
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
@@ -68,24 +65,22 @@ class MemoryAgentService:
|
||||
if str(messages) == 'success':
|
||||
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||
# 记录成功的操作
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration, details={"message_length": len(message)})
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration, details={"message_length": len(message)})
|
||||
return context
|
||||
else:
|
||||
logger.warning(f"Write operation failed for group {end_user_id}")
|
||||
|
||||
# 记录失败的操作
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(
|
||||
operation="WRITE",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=f"写入失败: {messages[:100]}"
|
||||
)
|
||||
audit_logger.log_operation(
|
||||
operation="WRITE",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=f"写入失败: {messages[:100]}"
|
||||
)
|
||||
|
||||
raise ValueError(f"写入失败: {messages}")
|
||||
|
||||
@@ -338,10 +333,9 @@ class MemoryAgentService:
|
||||
logger.error(error_msg)
|
||||
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
@@ -401,10 +395,10 @@ class MemoryAgentService:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
async def read_memory(
|
||||
@@ -469,10 +463,9 @@ class MemoryAgentService:
|
||||
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
|
||||
|
||||
# 导入审计日志记录器
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
|
||||
|
||||
|
||||
|
||||
config_load_start = time.time()
|
||||
try:
|
||||
@@ -492,16 +485,15 @@ class MemoryAgentService:
|
||||
logger.error(error_msg)
|
||||
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
@@ -633,15 +625,15 @@ class MemoryAgentService:
|
||||
total_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
return {
|
||||
"answer": summary,
|
||||
@@ -651,16 +643,16 @@ class MemoryAgentService:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Read operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import desc, nullslast, or_, and_, cast, String
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.end_user_model import EndUser, EndUser as EndUserModel
|
||||
from app.models.memory_increment_model import MemoryIncrement
|
||||
|
||||
from app.repositories import (
|
||||
@@ -49,44 +50,40 @@ def get_current_workspace_type(
|
||||
|
||||
|
||||
def get_workspace_end_users(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User
|
||||
) -> List[EndUser]:
|
||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
|
||||
|
||||
返回结果按 created_at 从新到旧排序(NULL 值排在最后)
|
||||
"""
|
||||
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
# 查询应用(ORM)
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
|
||||
|
||||
if not apps_orm:
|
||||
business_logger.info("工作空间下没有应用")
|
||||
return []
|
||||
|
||||
|
||||
# 提取所有 app_id
|
||||
# app_ids = [app.id for app in apps_orm]
|
||||
|
||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||
# 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||
from app.models.end_user_model import EndUser as EndUserModel
|
||||
from sqlalchemy import desc, nullslast
|
||||
end_users_orm = db.query(EndUserModel).filter(
|
||||
EndUserModel.workspace_id == workspace_id
|
||||
).order_by(
|
||||
nullslast(desc(EndUserModel.created_at)),
|
||||
desc(EndUserModel.id)
|
||||
).all()
|
||||
|
||||
|
||||
# 转换为 Pydantic 模型(只在需要时转换)
|
||||
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
|
||||
|
||||
|
||||
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return end_users
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -94,6 +91,85 @@ def get_workspace_end_users(
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_end_users_paginated(
    db: Session,
    workspace_id: uuid.UUID,
    current_user: User,
    page: int,
    pagesize: int,
    keyword: Optional[str] = None
) -> Dict[str, Any]:
    """Return one page of a workspace's end users, with optional fuzzy search.

    Rows are ordered by ``created_at`` descending with NULL values last, and
    by ``id`` descending as a tie-breaker so paging is deterministic.

    When ``keyword`` is given (after stripping whitespace), a row matches if
    its ``other_name`` contains the keyword (case-insensitive), or -- only for
    rows whose ``other_name`` is NULL or empty -- if its ``id`` rendered as a
    string contains the keyword. The keyword is matched literally: ``%`` and
    ``_`` typed by the user are escaped and do not act as LIKE wildcards.

    Args:
        db: Database session.
        workspace_id: Workspace ID.
        current_user: Current user (only used for logging here).
        page: Page number, starting at 1.
        pagesize: Page size.
        keyword: Optional search keyword; blank or ``None`` disables filtering.

    Returns:
        dict: ``{"items": [...], "total": int}`` where ``items`` holds the
        validated ``EndUserSchema`` objects of the requested page and
        ``total`` is the number of rows matching the filter.
    """
    business_logger.info(f"获取工作空间宿主列表(分页): workspace_id={workspace_id}, keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}")

    try:
        # Base query: every end user of the workspace.
        base_query = db.query(EndUserModel).filter(
            EndUserModel.workspace_id == workspace_id
        )

        # Normalize the keyword so that None, "" and whitespace-only all mean
        # "no search".
        keyword = keyword.strip() if keyword else None

        if keyword:
            # Escape LIKE metacharacters so user input is matched literally.
            # "/" is used as the escape character (the same default SQLAlchemy
            # uses for autoescape) to avoid backslash-handling differences
            # between database backends. The escape char itself goes first.
            like_escape = "/"
            escaped_keyword = (
                keyword
                .replace(like_escape, like_escape * 2)
                .replace("%", f"{like_escape}%")
                .replace("_", f"{like_escape}_")
            )
            keyword_pattern = f"%{escaped_keyword}%"
            # other_name matching always applies; id matching only applies to
            # rows whose other_name is empty.
            base_query = base_query.filter(
                or_(
                    EndUserModel.other_name.ilike(keyword_pattern, escape=like_escape),
                    and_(
                        or_(
                            EndUserModel.other_name.is_(None),
                            EndUserModel.other_name == "",
                        ),
                        cast(EndUserModel.id, String).ilike(keyword_pattern, escape=like_escape),
                    ),
                )
            )
            business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_name;other_name 为空时匹配 id)")

        # Total number of matching rows (before pagination).
        total = base_query.count()

        if total == 0:
            business_logger.info("工作空间下没有宿主")
            return {"items": [], "total": 0}

        # Page query: newest first, NULL created_at last, id as secondary key.
        end_users_orm = base_query.order_by(
            nullslast(desc(EndUserModel.created_at)),
            desc(EndUserModel.id)
        ).offset((page - 1) * pagesize).limit(pagesize).all()

        # Convert ORM rows to Pydantic models.
        end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]

        business_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条")
        return {"items": end_users, "total": total}

    except HTTPException:
        # HTTP errors are passed through untouched.
        raise
    except Exception as e:
        business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
        raise
|
||||
|
||||
|
||||
def get_workspace_memory_increment(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
@@ -638,7 +714,24 @@ def get_rag_content(
|
||||
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# 4. 返回结果
|
||||
# 4. 将所有 page_content 拼接后按角色分割为对话列表
|
||||
merged_text = "\n".join(page_contents)
|
||||
conversations = []
|
||||
if merged_text.strip():
|
||||
import re
|
||||
# 在任意位置匹配 "user:" 或 "assistant:",不限于行首
|
||||
parts = re.split(r'(user|assistant):', merged_text)
|
||||
# parts 结构: ['', 'user', ' content...', 'assistant', ' content...', ...]
|
||||
i = 1
|
||||
while i < len(parts) - 1:
|
||||
role = parts[i].strip()
|
||||
content = parts[i + 1].strip()
|
||||
# 将 content 中的 \n 还原为真实换行
|
||||
content = content.replace("\\n", "\n")
|
||||
if role in ("user", "assistant") and content:
|
||||
conversations.append({"role": role, "content": content})
|
||||
i += 2
|
||||
|
||||
result = {
|
||||
"page": {
|
||||
"page": page,
|
||||
@@ -646,10 +739,10 @@ def get_rag_content(
|
||||
"total": global_total,
|
||||
"hasnext": offset_end < global_total,
|
||||
},
|
||||
"items": page_contents
|
||||
"items": conversations
|
||||
}
|
||||
|
||||
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)} 条")
|
||||
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(conversations)} 条对话")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -204,30 +204,35 @@ class MemoryForgetService:
|
||||
end_user_id: str,
|
||||
forgetting_threshold: float,
|
||||
min_days_since_access: int,
|
||||
limit: int = 20
|
||||
) -> list[Dict[str, Any]]:
|
||||
page: Optional[int] = None,
|
||||
pagesize: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取待遗忘节点列表
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)
|
||||
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。
|
||||
|
||||
Args:
|
||||
connector: Neo4j 连接器
|
||||
end_user_id: 组ID
|
||||
forgetting_threshold: 遗忘阈值
|
||||
min_days_since_access: 最小未访问天数
|
||||
limit: 返回节点数量限制
|
||||
|
||||
page: 页码(可选,从1开始)
|
||||
pagesize: 每页数量(可选)
|
||||
|
||||
Returns:
|
||||
list: 待遗忘节点列表
|
||||
dict: 包含待遗忘节点列表和分页信息的字典
|
||||
- items: 待遗忘节点列表
|
||||
- page: 分页信息(分页时)
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
|
||||
# 计算最小访问时间(ISO 8601 格式字符串,使用 UTC 时区)
|
||||
min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
|
||||
min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
|
||||
|
||||
query = """
|
||||
|
||||
# 基础查询(用于获取总数)
|
||||
count_query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
AND n.end_user_id = $end_user_id
|
||||
@@ -235,10 +240,22 @@ class MemoryForgetService:
|
||||
AND n.activation_value < $threshold
|
||||
AND n.last_access_time IS NOT NULL
|
||||
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
||||
RETURN
|
||||
RETURN count(n) as total
|
||||
"""
|
||||
|
||||
# 数据查询
|
||||
data_query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
AND n.end_user_id = $end_user_id
|
||||
AND n.activation_value IS NOT NULL
|
||||
AND n.activation_value < $threshold
|
||||
AND n.last_access_time IS NOT NULL
|
||||
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
||||
RETURN
|
||||
elementId(n) as node_id,
|
||||
labels(n)[0] as node_type,
|
||||
CASE
|
||||
CASE
|
||||
WHEN n:Statement THEN n.statement
|
||||
WHEN n:ExtractedEntity THEN n.name
|
||||
WHEN n:MemorySummary THEN n.content
|
||||
@@ -247,18 +264,32 @@ class MemoryForgetService:
|
||||
n.activation_value as activation_value,
|
||||
n.last_access_time as last_access_time
|
||||
ORDER BY n.activation_value ASC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
|
||||
# 如果启用分页,添加 SKIP 和 LIMIT
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
data_query += " SKIP $skip LIMIT $limit"
|
||||
|
||||
params = {
|
||||
'end_user_id': end_user_id,
|
||||
'threshold': forgetting_threshold,
|
||||
'min_access_time_str': min_access_time_str,
|
||||
'limit': limit
|
||||
'min_access_time_str': min_access_time_str
|
||||
}
|
||||
|
||||
results = await connector.execute_query(query, **params)
|
||||
|
||||
|
||||
# 获取总数(分页时需要)
|
||||
total = 0
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
count_results = await connector.execute_query(count_query, **params)
|
||||
if count_results:
|
||||
total = count_results[0]['total']
|
||||
|
||||
# 添加分页参数
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
params['skip'] = (page - 1) * pagesize
|
||||
params['limit'] = pagesize
|
||||
|
||||
results = await connector.execute_query(data_query, **params)
|
||||
|
||||
pending_nodes = []
|
||||
for result in results:
|
||||
# 将节点类型标签转换为小写
|
||||
@@ -267,7 +298,7 @@ class MemoryForgetService:
|
||||
node_type_label = 'entity'
|
||||
elif node_type_label == 'memorysummary':
|
||||
node_type_label = 'summary'
|
||||
|
||||
|
||||
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
|
||||
last_access_time = result['last_access_time']
|
||||
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
|
||||
@@ -278,7 +309,7 @@ class MemoryForgetService:
|
||||
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
|
||||
else:
|
||||
last_access_timestamp = 0
|
||||
|
||||
|
||||
pending_nodes.append({
|
||||
'node_id': str(result['node_id']),
|
||||
'node_type': node_type_label,
|
||||
@@ -286,8 +317,20 @@ class MemoryForgetService:
|
||||
'activation_value': result['activation_value'],
|
||||
'last_access_time': last_access_timestamp
|
||||
})
|
||||
|
||||
return pending_nodes
|
||||
|
||||
# 构建返回结果
|
||||
result: Dict[str, Any] = {'items': pending_nodes}
|
||||
|
||||
# 如果启用分页,添加分页信息
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
result['page'] = {
|
||||
'page': page,
|
||||
'pagesize': pagesize,
|
||||
'total': total,
|
||||
'hasnext': (page * pagesize) < total
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
async def trigger_forgetting_cycle(
|
||||
self,
|
||||
@@ -636,7 +679,7 @@ class MemoryForgetService:
|
||||
api_logger.error(f"获取历史趋势数据失败: {str(e)}")
|
||||
# 失败时返回空列表,不影响主流程
|
||||
|
||||
# 获取待遗忘节点列表(前20个满足遗忘条件的节点)
|
||||
# 获取待遗忘节点列表
|
||||
pending_nodes = []
|
||||
try:
|
||||
if end_user_id:
|
||||
@@ -652,8 +695,7 @@ class MemoryForgetService:
|
||||
connector=connector,
|
||||
end_user_id=end_user_id,
|
||||
forgetting_threshold=forgetting_threshold,
|
||||
min_days_since_access=int(min_days),
|
||||
limit=20
|
||||
min_days_since_access=int(min_days)
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点")
|
||||
@@ -661,24 +703,79 @@ class MemoryForgetService:
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取待遗忘节点失败: {str(e)}")
|
||||
# 失败时返回空列表,不影响主流程
|
||||
|
||||
# 构建统计信息
|
||||
|
||||
# 构建统计信息(不包含 pending_nodes,已分离到独立接口)
|
||||
stats = {
|
||||
'activation_metrics': activation_metrics,
|
||||
'node_distribution': node_distribution,
|
||||
'recent_trends': recent_trends,
|
||||
'pending_nodes': pending_nodes,
|
||||
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
|
||||
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
|
||||
f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}"
|
||||
f"trend_days={len(recent_trends)}"
|
||||
)
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def get_pending_nodes(
        self,
        db: Session,
        end_user_id: str,
        config_id: Optional[UUID] = None,
        page: int = 1,
        pagesize: int = 10
) -> Dict[str, Any]:
    """
    Return one page of nodes that are pending forgetting (standalone paged API).

    Selects nodes that satisfy the forgetting condition: activation value
    below the configured threshold and last access older than the configured
    minimum number of days.

    Args:
        db: Database session.
        end_user_id: End-user (group) ID; required.
        config_id: Optional config ID used to load the forgetting settings.
        page: 1-based page number (default 1). Values below 1 fall back to 1.
        pagesize: Page size (default 10). Values below 1 fall back to 10.

    Returns:
        dict: The result of ``_get_pending_forgetting_nodes``:
            - items: list of pending nodes
            - page: paging info (page, pagesize, total, hasnext)
    """
    # _get_pending_forgetting_nodes only applies SKIP/LIMIT (and only returns
    # the 'page' block) when both values are positive. A zero or negative
    # value would therefore run an unbounded query and drop the paging info,
    # so normalise the arguments here before delegating.
    requested_page, requested_pagesize = page, pagesize
    if page is None or page < 1:
        page = 1
    if pagesize is None or pagesize < 1:
        pagesize = 10
    if (page, pagesize) != (requested_page, requested_pagesize):
        api_logger.warning(
            f"分页参数无效: page={requested_page}, pagesize={requested_pagesize}, "
            f"使用 page={page}, pagesize={pagesize}"
        )

    # Load the forgetting-engine components; only the scheduler (for its
    # connector) and the config dict are needed here.
    _, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)

    connector = forgetting_scheduler.connector
    forgetting_threshold = config['forgetting_threshold']

    # Validate min_days_since_access; fall back to 7 days when it is missing,
    # non-numeric or negative.
    min_days = config.get('min_days_since_access')
    if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
        api_logger.warning(
            f"min_days_since_access 配置无效: {min_days}, 使用默认值 7"
        )
        min_days = 7

    # Delegate to the internal query helper with paging enabled.
    pending_nodes_result = await self._get_pending_forgetting_nodes(
        connector=connector,
        end_user_id=end_user_id,
        forgetting_threshold=forgetting_threshold,
        min_days_since_access=int(min_days),
        page=page,
        pagesize=pagesize
    )

    api_logger.info(
        f"成功获取待遗忘节点列表: end_user_id={end_user_id}, "
        f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}"
    )

    return pending_nodes_result
|
||||
|
||||
async def get_forgetting_curve(
|
||||
self,
|
||||
db: Session,
|
||||
|
||||
@@ -243,28 +243,9 @@ class MemoryPerceptualService:
|
||||
memory_config: MemoryConfig,
|
||||
file: FileInput
|
||||
):
|
||||
memories = self.repository.get_by_url(file.url)
|
||||
if memories:
|
||||
business_logger.info(f"Perceptual memory already exists: {file.url}")
|
||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||
memory_cache = memories[0]
|
||||
memory = self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||
file_path=memory_cache.file_path,
|
||||
file_name=memory_cache.file_name,
|
||||
file_ext=memory_cache.file_ext,
|
||||
summary=memory_cache.summary,
|
||||
meta_data=memory_cache.meta_data
|
||||
)
|
||||
self.db.commit()
|
||||
return memory
|
||||
else:
|
||||
for memory in memories:
|
||||
if memory.end_user_id == uuid.UUID(end_user_id):
|
||||
return memory
|
||||
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
||||
if model_config is None or llm is None:
|
||||
return None
|
||||
multimodel_service = MultimodalService(self.db, ModelInfo(
|
||||
model_name=model_config.model_name,
|
||||
provider=model_config.provider,
|
||||
@@ -286,15 +267,20 @@ class MemoryPerceptualService:
|
||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
|
||||
except FileNotFoundError:
|
||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||
except FileNotFoundError as e:
|
||||
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
|
||||
return None
|
||||
messages = [
|
||||
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
|
||||
{"role": RoleType.USER.value, "content": [
|
||||
{"type": "text", "text": "Summarize the following file"}, file_message
|
||||
]}
|
||||
]
|
||||
result = await llm.ainvoke(messages)
|
||||
try:
|
||||
result = await llm.ainvoke(messages)
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
|
||||
return None
|
||||
content = result.content
|
||||
final_output = ""
|
||||
if isinstance(content, list):
|
||||
|
||||
@@ -695,6 +695,37 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
|
||||
return result
|
||||
|
||||
|
||||
async def search_all_batch(end_user_ids: List[str]) -> Dict[str, int]:
    """Return the memory count for several end users with a single query.

    Simplified batch variant: only the per-user total is returned.

    Args:
        end_user_ids: End-user IDs to look up.

    Returns:
        Dict[str, int]: Mapping of user ID to memory count, in the form
        ``{"user_id": total_count}``. IDs that have no row in the query
        result map to 0. An empty input returns ``{}`` without querying.
    """
    if not end_user_ids:
        return {}

    rows = await _neo4j_connector.execute_query(
        MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
        end_user_ids=end_user_ids,
    )

    # Collapse the result rows into a {user_id: total} mapping.
    counts = {row["user_id"]: row["total"] for row in rows}

    # Users that produced no row get an explicit zero.
    for uid in end_user_ids:
        counts.setdefault(uid, 0)

    return counts
|
||||
|
||||
|
||||
async def analytics_hot_memory_tags(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
|
||||
@@ -69,7 +69,8 @@ class ModelConfigService:
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
def get_model_by_name(db: Session, name: str, provider: str | None = None,
|
||||
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""根据名称获取模型配置"""
|
||||
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
|
||||
if not model:
|
||||
@@ -77,21 +78,22 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
def search_models_by_name(
        db: Session,
        name: str,
        tenant_id: uuid.UUID | None = None,
        limit: int = 10,
) -> List[ModelConfig]:
    """Return the model configs whose name fuzzy-matches ``name``.

    Thin wrapper that delegates to ``ModelConfigRepository.search_by_name``,
    forwarding the optional tenant filter and the result cap unchanged.
    """
    matches = ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
    return matches
|
||||
|
||||
@staticmethod
|
||||
async def validate_model_config(
|
||||
db: Session,
|
||||
*,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
api_base: Optional[str] = None,
|
||||
model_type: str = "llm",
|
||||
test_message: str = "Hello",
|
||||
is_omni: bool = False
|
||||
db: Session,
|
||||
*,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
api_base: Optional[str] = None,
|
||||
model_type: str = "llm",
|
||||
test_message: str = "Hello",
|
||||
is_omni: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""验证模型配置是否有效
|
||||
|
||||
@@ -158,13 +160,13 @@ class ModelConfigService:
|
||||
# 统一使用 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||
embedding = RedBearEmbeddings(model_config)
|
||||
test_texts = [test_message, "测试文本"]
|
||||
|
||||
|
||||
# 火山引擎使用 embed_batch,其他使用 embed_documents
|
||||
if provider.lower() == "volcano":
|
||||
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
|
||||
else:
|
||||
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
@@ -200,11 +202,11 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
elif model_type_lower == "image":
|
||||
# 图片生成模型验证
|
||||
from app.core.models.generation import RedBearImageGenerator
|
||||
|
||||
|
||||
generator = RedBearImageGenerator(model_config)
|
||||
result = await generator.agenerate(
|
||||
prompt="a cute panda",
|
||||
@@ -212,7 +214,7 @@ class ModelConfigService:
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"成功生成图片,结果: {result}")
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "图片生成模型配置验证成功",
|
||||
@@ -224,21 +226,21 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
elif model_type_lower == "video":
|
||||
# 视频生成模型验证
|
||||
from app.core.models.generation import RedBearVideoGenerator
|
||||
|
||||
|
||||
generator = RedBearVideoGenerator(model_config)
|
||||
result = await generator.agenerate(
|
||||
prompt="a cute panda playing in bamboo forest",
|
||||
duration=5
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 视频生成是异步任务,返回任务ID
|
||||
task_id = result.get("task_id") if isinstance(result, dict) else None
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "视频生成模型配置验证成功",
|
||||
@@ -265,7 +267,6 @@ class ModelConfigService:
|
||||
# 提取详细的错误信息
|
||||
error_message = str(e)
|
||||
error_type = type(e).__name__
|
||||
print("=========error_message:",error_message.lower())
|
||||
# 特殊处理常见的错误类型
|
||||
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
|
||||
# 区域/国家限制(适用于所有提供商)
|
||||
@@ -354,14 +355,16 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate,
|
||||
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""更新模型配置"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
||||
@@ -370,25 +373,27 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate,
|
||||
tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""创建组合模型"""
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
# 检查 API Key 关联的模型配置类型
|
||||
for model_config in api_key.model_configs:
|
||||
# chat 和 llm 类型可以兼容
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = model_data.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
@@ -399,7 +404,7 @@ class ModelConfigService:
|
||||
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
|
||||
# BizCode.INVALID_PARAMETER
|
||||
# )
|
||||
|
||||
|
||||
# 创建组合模型
|
||||
model_config_data = {
|
||||
"tenant_id": tenant_id,
|
||||
@@ -418,49 +423,51 @@ class ModelConfigService:
|
||||
|
||||
model = ModelConfigRepository.create(db, model_config_data)
|
||||
db.flush()
|
||||
|
||||
|
||||
# 关联 API Keys
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
model.api_keys.append(api_key)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate,
|
||||
tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""更新组合模型"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
if not existing_model.is_composite:
|
||||
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
|
||||
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
for model_config in api_key.model_configs:
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = existing_model.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
|
||||
# 更新基本信息
|
||||
existing_model.name = model_data.name
|
||||
# existing_model.type = model_data.type
|
||||
@@ -471,14 +478,14 @@ class ModelConfigService:
|
||||
existing_model.is_public = model_data.is_public
|
||||
if "load_balance_strategy" in model_data.model_fields_set:
|
||||
existing_model.load_balance_strategy = model_data.load_balance_strategy
|
||||
|
||||
|
||||
# 更新 API Keys 关联
|
||||
existing_model.api_keys.clear()
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
existing_model.api_keys.append(api_key)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_model)
|
||||
return existing_model
|
||||
@@ -532,7 +539,7 @@ class ModelApiKeyService:
|
||||
"""根据provider为多个ModelConfig创建API Key"""
|
||||
created_keys = []
|
||||
failed_models = [] # 记录验证失败的模型
|
||||
|
||||
|
||||
for model_config_id in data.model_config_ids:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
@@ -540,10 +547,10 @@ class ModelApiKeyService:
|
||||
|
||||
data.is_omni = model_config.is_omni
|
||||
data.capability = model_config.capability
|
||||
|
||||
|
||||
# 从ModelBase获取model_name
|
||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||
|
||||
|
||||
# 检查是否存在API Key(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
@@ -553,7 +560,7 @@ class ModelApiKeyService:
|
||||
ModelApiKey.model_name == model_name,
|
||||
ModelConfig.tenant_id == model_config.tenant_id
|
||||
).first()
|
||||
|
||||
|
||||
if existing_key:
|
||||
# 如果已存在,重新激活并更新
|
||||
if existing_key.is_active:
|
||||
@@ -566,14 +573,14 @@ class ModelApiKeyService:
|
||||
existing_key.model_name = model_name
|
||||
existing_key.capability = data.capability
|
||||
existing_key.is_omni = data.is_omni
|
||||
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
|
||||
created_keys.append(existing_key)
|
||||
continue
|
||||
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
@@ -589,7 +596,7 @@ class ModelApiKeyService:
|
||||
# 记录验证失败的模型,但不抛出异常
|
||||
failed_models.append(model_name)
|
||||
continue
|
||||
|
||||
|
||||
# 创建API Key
|
||||
api_key_data = ModelApiKeyCreate(
|
||||
model_config_ids=[model_config_id],
|
||||
@@ -606,12 +613,12 @@ class ModelApiKeyService:
|
||||
)
|
||||
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
|
||||
created_keys.append(api_key_obj)
|
||||
|
||||
|
||||
if created_keys:
|
||||
db.commit()
|
||||
for key in created_keys:
|
||||
db.refresh(key)
|
||||
|
||||
|
||||
return created_keys, failed_models
|
||||
|
||||
@staticmethod
|
||||
@@ -626,7 +633,7 @@ class ModelApiKeyService:
|
||||
api_key_data.is_omni = model_config.is_omni
|
||||
if api_key_data.capability is None:
|
||||
api_key_data.capability = model_config.capability
|
||||
|
||||
|
||||
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
@@ -650,15 +657,15 @@ class ModelApiKeyService:
|
||||
existing_key.model_name = api_key_data.model_name
|
||||
existing_key.capability = api_key_data.capability
|
||||
existing_key.is_omni = api_key_data.is_omni
|
||||
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_key)
|
||||
return existing_key
|
||||
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
@@ -691,7 +698,7 @@ class ModelApiKeyService:
|
||||
# 获取关联的模型配置以获取模型类型
|
||||
if existing_api_key.model_configs:
|
||||
model_config = existing_api_key.model_configs[0]
|
||||
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name or existing_api_key.model_name,
|
||||
@@ -729,15 +736,15 @@ class ModelApiKeyService:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
return None
|
||||
|
||||
|
||||
api_keys = [key for key in model_config.api_keys if key.is_active]
|
||||
if not api_keys:
|
||||
return None
|
||||
|
||||
|
||||
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
|
||||
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||
|
||||
|
||||
# 否则返回第一个
|
||||
return api_keys[0]
|
||||
|
||||
@@ -760,20 +767,19 @@ class ModelApiKeyService:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
|
||||
class ModelBaseService:
|
||||
"""基础模型服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
|
||||
models = ModelBaseRepository.get_list(db, query)
|
||||
|
||||
|
||||
provider_groups = {}
|
||||
for m in models:
|
||||
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
|
||||
if tenant_id:
|
||||
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
|
||||
|
||||
|
||||
provider = m.provider
|
||||
if provider not in provider_groups:
|
||||
provider_groups[provider] = {
|
||||
@@ -781,7 +787,7 @@ class ModelBaseService:
|
||||
"models": []
|
||||
}
|
||||
provider_groups[provider]["models"].append(model_dict)
|
||||
|
||||
|
||||
return list(provider_groups.values())
|
||||
|
||||
@staticmethod
|
||||
@@ -823,10 +829,10 @@ class ModelBaseService:
|
||||
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
|
||||
if not model_base:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
|
||||
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
|
||||
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
model_config_data = {
|
||||
"model_id": model_base_id,
|
||||
"tenant_id": tenant_id,
|
||||
|
||||
@@ -12,6 +12,9 @@ import base64
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
import olefile
|
||||
import struct
|
||||
import zipfile
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
@@ -438,13 +441,13 @@ class MultimodalService:
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
return True, {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||
"text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
|
||||
}
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
|
||||
text = await self._extract_document_text(file)
|
||||
text = await self.extract_document_text(file)
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file.upload_file_id
|
||||
).first()
|
||||
@@ -542,7 +545,7 @@ class MultimodalService:
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
return f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
async def _extract_document_text(self, file: FileInput) -> str:
|
||||
async def extract_document_text(self, file: FileInput) -> str:
|
||||
"""
|
||||
提取文档文本内容
|
||||
|
||||
@@ -602,31 +605,75 @@ class MultimodalService:
|
||||
try:
|
||||
word_file = io.BytesIO(file_content)
|
||||
doc = Document(word_file)
|
||||
return '\n'.join(p.text for p in doc.paragraphs)
|
||||
text_lines = []
|
||||
for p in doc.paragraphs:
|
||||
text = p.text.strip()
|
||||
if text:
|
||||
text_lines.append(text)
|
||||
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
for cell in row.cells:
|
||||
text = cell.text.strip()
|
||||
if text:
|
||||
text_lines.append(text)
|
||||
|
||||
full_text = "\n".join(text_lines)
|
||||
return full_text.strip() or "[docx 文件无文本内容]"
|
||||
except Exception as e:
|
||||
logger.error(f"提取 docx 文本失败: {e}")
|
||||
logger.error(f"提取 docx 文本失败: {str(e)}", exc_info=True)
|
||||
return f"[docx 提取失败: {str(e)}]"
|
||||
|
||||
# 旧版 .doc(OLE2 格式)
|
||||
# 旧版 .doc(OLE2/CFB 格式),按 Word Binary Format 规范解析 piece table
|
||||
try:
|
||||
import olefile
|
||||
ole = olefile.OleFileIO(io.BytesIO(file_content))
|
||||
if not ole.exists('WordDocument'):
|
||||
return "[doc 提取失败: 未找到 WordDocument 流]"
|
||||
# 读取 WordDocument 流,提取可见 ASCII/Unicode 文本
|
||||
stream = ole.openstream('WordDocument').read()
|
||||
# Word Binary Format: 文本在流中以 UTF-16-LE 编码存储
|
||||
# 简单提取:过滤出可打印字符段
|
||||
try:
|
||||
text = stream.decode('utf-16-le', errors='ignore')
|
||||
except Exception:
|
||||
text = stream.decode('latin-1', errors='ignore')
|
||||
# 过滤控制字符,保留可打印内容
|
||||
import re
|
||||
text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
|
||||
text = re.sub(r' +', ' ', text).strip()
|
||||
word_stream = ole.openstream('WordDocument').read()
|
||||
|
||||
# FIB offset 0xA bit9 决定使用 0Table 还是 1Table
|
||||
fib_flags = struct.unpack_from('<H', word_stream, 0xA)[0]
|
||||
table_name = '1Table' if (fib_flags & 0x0200) else '0Table'
|
||||
table_stream = ole.openstream(table_name).read()
|
||||
|
||||
# 从 FIB 读取 fcClx/lcbClx 定位 piece table
|
||||
fc_clx, lcb_clx = struct.unpack_from("<II", word_stream, 0x1A2)
|
||||
clx = table_stream[fc_clx: fc_clx + lcb_clx]
|
||||
|
||||
# 解析 CLX,找到 PlcPcd(piece table)
|
||||
i, plc_pcd = 0, None
|
||||
while i < len(clx):
|
||||
clxt = clx[i]
|
||||
if clxt == 0x01:
|
||||
i += 3 + struct.unpack_from('<H', clx, i + 1)[0]
|
||||
elif clxt == 0x02:
|
||||
cb = struct.unpack_from('<I', clx, i + 1)[0]
|
||||
plc_pcd = clx[i + 5: i + 5 + cb]
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
if plc_pcd is None:
|
||||
raise ValueError("PlcPcd not found")
|
||||
|
||||
# PlcPcd: (n+1) 个 CP(4字节)+ n 个 PCD(8字节)
|
||||
n_pieces = (len(plc_pcd) - 4) // 12
|
||||
cp_array = [struct.unpack_from('<I', plc_pcd, k * 4)[0] for k in range(n_pieces + 1)]
|
||||
|
||||
parts = []
|
||||
for k in range(n_pieces):
|
||||
fc_value = struct.unpack_from('<I', plc_pcd, (n_pieces + 1) * 4 + k * 8 + 2)[0]
|
||||
is_ansi = bool(fc_value & 0x40000000)
|
||||
fc = fc_value & 0x3FFFFFFF
|
||||
char_count = cp_array[k + 1] - cp_array[k]
|
||||
|
||||
if is_ansi:
|
||||
parts.append(word_stream[fc: fc + char_count].decode('cp1252', errors='replace'))
|
||||
else:
|
||||
parts.append(word_stream[fc: fc + char_count * 2].decode('utf-16-le', errors='replace'))
|
||||
|
||||
ole.close()
|
||||
return text
|
||||
result = re.sub(r'[\x00-\x1f\x7f]', '', ''.join(parts))
|
||||
return result.strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"提取 doc 文本失败: {e}")
|
||||
return f"[doc 提取失败: {str(e)}]"
|
||||
|
||||
@@ -1,26 +1,24 @@
|
||||
"""基于分享链接的聊天服务"""
|
||||
import uuid
|
||||
import time
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, AsyncGenerator
|
||||
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models import MultiAgentConfig
|
||||
from app.models import ReleaseShare, AppRelease, Conversation
|
||||
from app.repositories import knowledge_repository
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import MultiAgentConfig
|
||||
from app.repositories import knowledge_repository
|
||||
import json
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -118,6 +116,7 @@ class SharedChatService:
|
||||
|
||||
return conversation
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def chat(
|
||||
self,
|
||||
share_token: str,
|
||||
@@ -136,10 +135,7 @@ class SharedChatService:
|
||||
config_id = actual_config_id
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id = None
|
||||
@@ -273,11 +269,6 @@ class SharedChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
)
|
||||
|
||||
# 保存消息
|
||||
@@ -324,6 +315,7 @@ class SharedChatService:
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def chat_stream(
|
||||
self,
|
||||
share_token: str,
|
||||
@@ -341,8 +333,6 @@ class SharedChatService:
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
import json
|
||||
|
||||
start_time = time.time()
|
||||
@@ -486,11 +476,6 @@ class SharedChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
total_tokens = chunk
|
||||
@@ -585,6 +570,7 @@ class SharedChatService:
|
||||
|
||||
return conversations, total
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def multi_agent_chat(
|
||||
self,
|
||||
share_token: str,
|
||||
@@ -680,6 +666,7 @@ class SharedChatService:
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def multi_agent_chat_stream(
|
||||
self,
|
||||
share_token: str,
|
||||
|
||||
@@ -138,7 +138,7 @@ class TenantService:
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"删除租户失败: {str(e)}")
|
||||
raise BusinessException(f"删除租户失败: {str(e)}", code=BizCode.DB_ERROR)
|
||||
raise BusinessException(f"删除租户失败:{str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
# 租户用户管理
|
||||
def get_tenant_users(
|
||||
@@ -147,6 +147,7 @@ class TenantService:
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
is_superuser: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> List[UserModel]:
|
||||
"""获取租户下的用户列表"""
|
||||
@@ -155,6 +156,7 @@ class TenantService:
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
search=search
|
||||
)
|
||||
|
||||
@@ -162,12 +164,14 @@ class TenantService:
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
is_active: Optional[bool] = None,
|
||||
is_superuser: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> int:
|
||||
"""统计租户下的用户数量"""
|
||||
return self.user_repo.count_users_by_tenant(
|
||||
tenant_id=tenant_id,
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
search=search
|
||||
)
|
||||
|
||||
|
||||
@@ -472,6 +472,21 @@ class UserMemoryService:
|
||||
# 定义允许更新的字段白名单
|
||||
allowed_fields = {'other_name', 'aliases', 'meta_data'}
|
||||
|
||||
# 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中
|
||||
_user_placeholder_names = {'用户', '我', 'User', 'I'}
|
||||
|
||||
# 过滤 other_name:不允许设置为占位名称
|
||||
if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names:
|
||||
logger.warning(f"拒绝将占位名称 '{update_data['other_name']}' 设置为 other_name")
|
||||
del update_data['other_name']
|
||||
|
||||
# 过滤 aliases:移除占位名称和非字符串值
|
||||
if 'aliases' in update_data and update_data['aliases']:
|
||||
update_data['aliases'] = [
|
||||
a for a in update_data['aliases']
|
||||
if isinstance(a, str) and a.strip() and a.strip() not in _user_placeholder_names
|
||||
]
|
||||
|
||||
# 检查是否更新了 aliases 字段
|
||||
aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases
|
||||
|
||||
|
||||
@@ -561,6 +561,24 @@ class WorkflowService:
|
||||
storage_type = 'neo4j'
|
||||
return storage_type, user_rag_memory_id
|
||||
|
||||
def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None:
|
||||
executions = self.execution_repo.get_by_conversation_id(
|
||||
conversation_id=conversation_id,
|
||||
status="completed",
|
||||
limit_count=1
|
||||
)
|
||||
|
||||
if executions:
|
||||
last_state = executions[0].output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
# input_data["conv"] = conv_vars
|
||||
# input_data["conv_messages"] = last_state.get("messages") or []
|
||||
conv_messages = last_state.get("messages") or []
|
||||
return conv_vars, conv_messages
|
||||
return None
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
async def run(
|
||||
@@ -634,18 +652,11 @@ class WorkflowService:
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
|
||||
for exec_res in executions:
|
||||
if exec_res.status == "completed":
|
||||
last_state = exec_res.output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
|
||||
history = self._get_history_info(conversation_id_uuid)
|
||||
if history:
|
||||
conv_vars, conv_messages = history
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = conv_messages
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
|
||||
result = await execute_workflow(
|
||||
@@ -807,17 +818,11 @@ class WorkflowService:
|
||||
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||
input_data["files"] = files
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
|
||||
for exec_res in executions:
|
||||
if exec_res.status == "completed":
|
||||
last_state = exec_res.output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
history = self._get_history_info(conversation_id_uuid)
|
||||
if history:
|
||||
conv_vars, conv_messages = history
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = conv_messages
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
message_id = uuid.uuid4()
|
||||
async for event in execute_workflow_stream(
|
||||
|
||||
Reference in New Issue
Block a user