diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 43f177ef..3b4e5a25 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.logging_config import get_business_logger -from app.core.response_utils import success +from app.core.response_utils import success, fail from app.db import get_db from app.dependencies import get_current_user, cur_workspace_access_guard from app.models import User @@ -661,6 +661,11 @@ async def draft_run( data=result, msg="工作流任务执行成功" ) + else: + return fail( + msg="未知应用类型", + code=422 + ) @router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行") diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index b7da943c..46fe3043 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -9,7 +9,7 @@ from app.db import get_db from app.dependencies import cur_workspace_access_guard, get_current_user from app.models import ModelApiKey from app.models.user_model import User -from app.repositories import knowledge_repository +from app.repositories import knowledge_repository, WorkspaceRepository from app.schemas.memory_agent_schema import UserInput, Write_UserInput from app.schemas.response_schema import ApiResponse from app.services import task_service, workspace_service @@ -616,8 +616,10 @@ async def get_knowledge_type_stats_api( @router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse) async def get_hot_memory_tags_by_user_api( end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), + language_type: Optional[str] ="zh", limit: int = Query(20, description="返回标签数量限制"), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + db: Session=Depends(get_db), ): """ 获取指定用户的热门记忆标签 @@ -628,10 +630,22 @@ async def get_hot_memory_tags_by_user_api( ... ] """ + + workspace_id=current_user.current_workspace_id + workspace_repo = WorkspaceRepository(db) + workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) + + if workspace_models: + model_id = workspace_models.get("llm", None) + else: + model_id = None + api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}") try: result = await memory_agent_service.get_hot_memory_tags_by_user( end_user_id=end_user_id, + language_type=language_type, + model_id=model_id, limit=limit ) return success(data=result, msg="获取热门记忆标签成功") diff --git a/api/app/controllers/memory_short_term_controller.py b/api/app/controllers/memory_short_term_controller.py index 64991f4d..9cf66749 100644 --- a/api/app/controllers/memory_short_term_controller.py +++ b/api/app/controllers/memory_short_term_controller.py @@ -20,6 +20,7 @@ router = APIRouter( @router.get("/short_term") async def short_term_configs( end_user_id: str, + language_type:Optional[str] = "zh", current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index a96c7a52..d99eb47e 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -12,6 +12,7 @@ from app.core.logging_config import get_api_logger from app.core.response_utils import success, fail from app.core.error_codes import BizCode from app.core.api_key_utils import timestamp_to_datetime +from app.services.memory_base_service import Translation_English from app.services.user_memory_service import ( UserMemoryService, analytics_memory_types, @@ -20,7 +21,7 @@ from app.services.user_memory_service import ( from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest - +from app.repositories.workspace_repository import WorkspaceRepository from app.schemas.end_user_schema import ( EndUserProfileResponse, EndUserProfileUpdate, @@ -44,6 +45,7 @@ router = APIRouter( @router.get("/analytics/memory_insight/report", response_model=ApiResponse) async def get_memory_insight_report_api( end_user_id: str, + language_type: str = "zh", current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: @@ -53,10 +55,18 @@ async def get_memory_insight_report_api( 此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。 如需生成新的洞察报告,请使用专门的生成接口。 """ + workspace_id = current_user.current_workspace_id + workspace_repo = WorkspaceRepository(db) + workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) + + if workspace_models: + model_id = workspace_models.get("llm", None) + else: + model_id = None api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}") try: # 调用服务层获取缓存数据 - result = await user_memory_service.get_cached_memory_insight(db, end_user_id) + result = await user_memory_service.get_cached_memory_insight(db, end_user_id,model_id,language_type) if result["is_cached"]: api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}") @@ -72,6 +82,7 @@ async def get_memory_insight_report_api( @router.get("/analytics/user_summary", response_model=ApiResponse) async def get_user_summary_api( end_user_id: str, + language_type: str="zh", current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: @@ -81,10 +92,18 @@ async def get_user_summary_api( 此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。 如需生成新的用户摘要,请使用专门的生成接口。 """ + workspace_id = current_user.current_workspace_id + workspace_repo = WorkspaceRepository(db) + workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) + + if workspace_models: + model_id = workspace_models.get("llm", None) + else: + model_id = None api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}") try: # 调用服务层获取缓存数据 - result = await user_memory_service.get_cached_user_summary(db, end_user_id) + result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language_type) if result["is_cached"]: api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") @@ -253,7 +272,6 @@ async def get_graph_data_api( depth=depth, center_node_id=center_node_id ) - # 检查是否有错误消息 if "message" in result and result["statistics"]["total_nodes"] == 0: api_logger.warning(f"图数据查询返回空结果: {result.get('message')}") @@ -278,7 +296,13 @@ async def get_end_user_profile( db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id + workspace_repo = WorkspaceRepository(db) + workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) + if workspace_models: + model_id = workspace_models.get("llm", None) + else: + model_id = None # 检查用户是否已选择工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间") @@ -296,7 +320,6 @@ async def get_end_user_profile( if not end_user: api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}") return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}") - # 构建响应数据 profile_data = EndUserProfileResponse( id=end_user.id, @@ -396,12 +419,21 @@ async def update_end_user_profile( return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e)) @router.get("/memory_space/timeline_memories", response_model=ApiResponse) -async def memory_space_timeline_of_shared_memories(id: str, label: str, +async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str="zh", current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): + workspace_id=current_user.current_workspace_id + workspace_repo = WorkspaceRepository(db) + workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) + + if workspace_models: + model_id = workspace_models.get("llm", None) + else: + model_id = None MemoryEntity = MemoryEntityService(id, label) - timeline_memories_result = await MemoryEntity.get_timeline_memories_server() + timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language_type) + return success(data=timeline_memories_result, msg="共同记忆时间线") @router.get("/memory_space/relationship_evolution", response_model=ApiResponse) async def memory_space_relationship_evolution(id: str, label: str, diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index ad03fec1..6721d7b0 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -54,6 +54,8 @@ class WorkflowExecutor: self.edges = workflow_config.get("edges", []) self.execution_config = workflow_config.get("execution_config", {}) + self.start_node_id = None + self.checkpoint_config = RunnableConfig( configurable={ "thread_id": uuid.uuid4(), @@ -131,77 +133,12 @@ class WorkflowExecutor: for node in self.workflow_config.get("nodes") if node.get("type") in [NodeType.LOOP, NodeType.ITERATION] ], # loop, iteration node id - "looping": False # loop runing flag, only use in loop node,not use in main loop + "looping": False, # loop runing flag, only use in loop node,not use in main loop + "activate": { + self.start_node_id: True + } } - def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]: - """分析 End 节点的前缀配置 - - 检查每个 End 节点的模板,找到直接上游节点的引用, - 提取该引用之前的前缀部分。 - - Returns: - 元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合}) - """ - import re - - prefixes = {} - adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点 - - # 找到所有 End 节点 - end_nodes = [node for node in self.nodes if node.get("type") == "end"] - logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点") - - for end_node in end_nodes: - end_node_id = end_node.get("id") - output_template = end_node.get("config", {}).get("output") - - logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}") - - if not output_template: - continue - - # 找到所有直接连接到 End 节点的上游节点 - direct_upstream_nodes = [] - for edge in self.edges: - if edge.get("target") == end_node_id: - source_node_id = edge.get("source") - direct_upstream_nodes.append(source_node_id) - - logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}") - - # 查找模板中引用了哪些节点 - # 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格) - pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}' - matches = list(re.finditer(pattern, output_template)) - - logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用") - - # 找到第一个直接上游节点的引用 - for match in matches: - referenced_node_id = match.group(1) - logger.info(f"[前缀分析] 检查引用: {referenced_node_id}") - - if referenced_node_id in direct_upstream_nodes: - # 这是直接上游节点的引用,提取前缀 - prefix = output_template[:match.start()] - - logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'") - - # 标记这个节点为"相邻且被引用" - adjacent_and_referenced.add(referenced_node_id) - - if prefix: - prefixes[referenced_node_id] = prefix - logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'") - - # 只处理第一个直接上游节点的引用 - break - - logger.info(f"[前缀分析] 最终配置: {prefixes}") - logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}") - return prefixes, adjacent_and_referenced - def _build_final_output(self, result, elapsed_time): node_outputs = result.get("node_outputs", {}) final_output = self._extract_final_output(node_outputs) @@ -231,10 +168,12 @@ class WorkflowExecutor: 编译后的状态图 """ logger.info(f"开始构建工作流图: execution_id={self.execution_id}") - graph = GraphBuilder( + builder = GraphBuilder( self.workflow_config, stream=stream, - ).build() + ) + self.start_node_id = builder.start_node_id + graph = builder.build() logger.info(f"工作流图构建完成: execution_id={self.execution_id}") return graph @@ -375,13 +314,15 @@ class WorkflowExecutor: payload = data.get("payload", {}) node_name = payload.get("name") + if node_name and node_name.startswith("nop"): + continue + if event_type == "task": # Node starts execution inputv = payload.get("input", {}) - variables = inputv.get("variables", {}) - variables_sys = variables.get("sys", {}) + if not inputv.get("activate", {}).get(node_name): + continue conversation_id = input_data.get("conversation_id") - execution_id = variables_sys.get("execution_id") logger.info(f"[NODE-START] Node starts execution: {node_name} " f"- execution_id: {self.execution_id}") @@ -390,18 +331,17 @@ class WorkflowExecutor: "data": { "node_id": node_name, "conversation_id": conversation_id, - "execution_id": execution_id, - "timestamp": data.get("timestamp") + "execution_id": self.execution_id, + "timestamp": data.get("timestamp"), } } elif event_type == "task_result": # Node execution completed result = payload.get("result", {}) - inputv = result.get("input", {}) - variables = inputv.get("variables", {}) - variables_sys = variables.get("sys", {}) + if not result.get("activate", {}).get(node_name): + continue + conversation_id = input_data.get("conversation_id") - execution_id = variables_sys.get("execution_id") logger.info(f"[NODE-END] Node execution completed: {node_name} " f"- execution_id: {self.execution_id}") @@ -410,7 +350,7 @@ class WorkflowExecutor: "data": { "node_id": node_name, "conversation_id": conversation_id, - "execution_id": execution_id, + "execution_id": self.execution_id, "timestamp": data.get("timestamp"), "state": result.get("node_outputs", {}).get(node_name), } diff --git a/api/app/core/workflow/graph_builder.py b/api/app/core/workflow/graph_builder.py index b75b867e..5b9388fc 100644 --- a/api/app/core/workflow/graph_builder.py +++ b/api/app/core/workflow/graph_builder.py @@ -1,14 +1,16 @@ import logging import uuid +from collections import defaultdict from typing import Any -from langgraph.graph.state import CompiledStateGraph, StateGraph -from langgraph.graph import START, END from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import START, END +from langgraph.graph.state import CompiledStateGraph, StateGraph +from langgraph.types import Send from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.nodes import WorkflowState, NodeFactory -from app.core.workflow.nodes.enums import NodeType +from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES logger = logging.getLogger(__name__) @@ -28,7 +30,10 @@ class GraphBuilder: self.start_node_id = None self.end_node_ids = [] - self.graph: StateGraph | CompiledStateGraph | None = None + self.graph = StateGraph(WorkflowState) + self.add_nodes() + self.add_edges() + # EDGES MUST BE ADDED AFTER NODES ARE ADDED. @property def nodes(self) -> list[dict[str, Any]]: @@ -39,74 +44,98 @@ class GraphBuilder: return self.workflow_config.get("edges", []) def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]: - """分析 End 节点的前缀配置 + """ + Analyze the prefix configuration for End nodes. - 检查每个 End 节点的模板,找到直接上游节点的引用, - 提取该引用之前的前缀部分。 + This function scans each End node's output template, identifies + references to its direct upstream nodes, and extracts the prefix + string appearing before the first reference. Returns: - 元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合}) + tuple: + - dict[str, str]: Mapping from upstream node ID to its End node prefix + - set[str]: Set of node IDs that are directly adjacent to End nodes and referenced """ import re prefixes = {} - adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点 + adjacent_and_referenced = set() # Record nodes directly adjacent to End and referenced # 找到所有 End 节点 end_nodes = [node for node in self.nodes if node.get("type") == "end"] - logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点") + logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes") for end_node in end_nodes: end_node_id = end_node.get("id") output_template = end_node.get("config", {}).get("output") - logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}") + logger.info(f"[Prefix Analysis] End node {end_node_id} template: {output_template}") if not output_template: continue - # 查找模板中引用了哪些节点 - # 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格) + # Find all node references in the template + # Matches {{node_id.xxx}} or {{ node_id.xxx }} format (allowing spaces) pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}' matches = list(re.finditer(pattern, output_template)) - logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用") + logger.info(f"[Prefix Analysis] 模板中找到 {len(matches)} 个节点引用") - # 找到所有直接连接到 End 节点的上游节点 + # Identify all direct upstream nodes connected to the End node direct_upstream_nodes = [] for edge in self.edges: if edge.get("target") == end_node_id: source_node_id = edge.get("source") direct_upstream_nodes.append(source_node_id) - logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}") + logger.info(f"[Prefix Analysis] Direct upstream nodes of End node: {direct_upstream_nodes}") # 找到第一个直接上游节点的引用 for match in matches: referenced_node_id = match.group(1) - logger.info(f"[前缀分析] 检查引用: {referenced_node_id}") + logger.info(f"[Prefix Analysis] Checking reference: {referenced_node_id}") if referenced_node_id in direct_upstream_nodes: # 这是直接上游节点的引用,提取前缀 prefix = output_template[:match.start()] - logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'") + logger.info(f"[Prefix Analysis] " + f"✅ Found reference to direct upstream node {referenced_node_id}, prefix: '{prefix}'") # 标记这个节点为"相邻且被引用" adjacent_and_referenced.add(referenced_node_id) if prefix: prefixes[referenced_node_id] = prefix - logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'") + logger.info(f"[Prefix Analysis] " + f"✅ Assign prefix for node {referenced_node_id}: '{prefix[:50]}...'") # 只处理第一个直接上游节点的引用 break - logger.info(f"[前缀分析] 最终配置: {prefixes}") - logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}") + logger.info(f"[Prefix Analysis] Final prefixes: {prefixes}") + logger.info(f"[Prefix Analysis] Nodes adjacent to End and referenced: {adjacent_and_referenced}") return prefixes, adjacent_and_referenced def add_nodes(self): + """Add all nodes from the workflow configuration to the state graph. + + This method handles: + - Creation of node instances using NodeFactory. + - Special handling for start, end, and cycle nodes. + - Injection of End node prefixes for streaming mode. + - Marking nodes as adjacent to End nodes if referenced. + - Wrapping node run methods as async functions or async generators + depending on streaming mode. + + Notes: + Loop nodes (nodes with `cycle` property) are handled separately + via CycleGraphNode when building subgraphs. + + Returns: + None + """ + # Analyze End node prefixes if in stream mode end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set()) for node in self.nodes: @@ -114,21 +143,21 @@ class GraphBuilder: node_id = node.get("id") cycle_node = node.get("cycle") if cycle_node: - # 处于循环子图中的节点由 CycleGraphNode 进行构建处理 + # Nodes within a loop subgraph are constructed by CycleGraphNode if not self.subgraph: continue - # 记录 start 和 end 节点 ID + # Record start and end node IDs if node_type in [NodeType.START, NodeType.CYCLE_START]: self.start_node_id = node_id elif node_type == NodeType.END: self.end_node_ids.append(node_id) - # 创建节点实例(现在 start 和 end 也会被创建) + # Create node instance (start and end nodes are also created) # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph node_instance = NodeFactory.create_node(node, self.workflow_config) - if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]: + if node_type in BRANCH_NODES: # Find all edges whose source is the current node related_edge = [edge for edge in self.edges if edge.get("source") == node_id] @@ -142,26 +171,23 @@ class GraphBuilder: related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" if node_instance: - # 如果是流式模式,且节点有 End 前缀配置,注入配置 + # Inject End node prefix configuration if in stream mode if self.stream and node_id in end_prefixes: - # 将 End 前缀配置注入到节点实例 node_instance._end_node_prefix = end_prefixes[node_id] - logger.info(f"为节点 {node_id} 注入 End 前缀配置") + logger.info(f"Injected End prefix for node {node_id}") - # 如果是流式模式,标记节点是否与 End 相邻且被引用 + # Mark nodes as adjacent and referenced to End node in stream mode if self.stream: node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced if node_id in adjacent_and_referenced: - logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用") + logger.info(f"Node {node_id} marked as adjacent and referenced to End node") - # 包装节点的 run 方法 - # 使用函数工厂避免闭包问题 + # Wrap node's run method to avoid closure issues if self.stream: - # 流式模式:创建 async generator 函数 - # LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state + # Stream mode: create an async generator function + # LangGraph collects all yielded values; the last yielded dictionary is merged into the state def make_stream_func(inst): async def node_func(state: WorkflowState): - # logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}") async for item in inst.run_stream(state): yield item @@ -169,7 +195,7 @@ class GraphBuilder: self.graph.add_node(node_id, make_stream_func(node_instance)) else: - # 非流式模式:创建 async function + # Non-stream mode: create an async function def make_func(inst): async def node_func(state: WorkflowState): return await inst.run(state) @@ -178,45 +204,110 @@ class GraphBuilder: self.graph.add_node(node_id, make_func(node_instance)) - logger.debug(f"添加节点: {node_id} (type={node_type}, stream={self.stream})") + logger.debug(f"Added node: {node_id} (type={node_type}, stream={self.stream})") def add_edges(self): + """Add all edges (normal, waiting, and conditional) to the state graph. + + This method handles: + - Connecting the START node to the workflow's start node. + - Collecting waiting edges for nodes with multiple sources. + - Collecting conditional edges for routing to NOP nodes. + - Adding NOP nodes for conditional branches to allow later merging. + - Wrapping routing logic in a router function that evaluates conditions. + - Connecting End nodes to the global END node. + + Notes: + - NOP nodes are used to ensure that multiple branches can merge + correctly without modifying the workflow state. + - Waiting edges are automatically handled by LangGraph to schedule + nodes only after all sources are activated. + + Returns: + None + """ + # Connect the START node to the workflow's start node if self.start_node_id: self.graph.add_edge(START, self.start_node_id) - logger.debug(f"添加边: START -> {self.start_node_id}") + logger.debug(f"Added edge: START -> {self.start_node_id}") + + # Collect all sources for each target node for normal/waiting edges + waiting_edges = defaultdict(list) + # Collect all conditional edges for each source node to construct routing + conditional_edges = defaultdict(list) for edge in self.edges: source = edge.get("source") target = edge.get("target") - edge_type = edge.get("type") condition = edge.get("condition") + edge_type = edge.get("type") - # 跳过从 start 节点出发的边(因为已经从 START 连接到 start) - if source == self.start_node_id: - # 但要连接 start 到下一个节点 - self.graph.add_edge(source, target) - logger.debug(f"添加边: {source} -> {target}") - continue - - # # 处理到 end 节点的边 - # if target in end_node_ids: - # # 连接到 end 节点 - # workflow.add_edge(source, target) - # logger.debug(f"添加边: {source} -> {target}") - # continue - - # 跳过错误边(在节点内部处理) + # Skip error edges (handled within nodes) if edge_type == "error": continue if condition: - # 条件边 - def make_router(cond, tgt): - """Dynamically generate a conditional router function to ensure each branch has a unique name.""" + # Conditional edges: group by source node + conditional_edges[source].append({ + "target": target, + "condition": condition, + "label": edge.get("label") + }) + else: + # Normal edges: group by target node (used for waiting edges) + waiting_edges[target].append(source) - def router_fn(state: WorkflowState): + # Add conditional edges + for source_node, branches in conditional_edges.items(): + def make_router(src, branch_list): + """reate a router function for each source node that routes to a NOP node for later merging.""" + def make_branch_node(node_name, targets): + def node(s): + # NOTE: NOP NODE MUST NOT MODIFY STATE + return { + "activate": { + node_id: s["activate"][node_name] + for node_id in targets + } + } + + return node + + unique_branch = {} + for branch in branch_list: + if branch.get("label") not in unique_branch.keys(): + nop_node_name = f"nop_{uuid.uuid4().hex[:8]}" + logger.info(f"Binding NOP: {source_node} {branch.get('label')} -> {nop_node_name}") + unique_branch[branch["label"]] = { + "condition": branch["condition"], + "node": { + "name": nop_node_name, + }, + "target": [branch["target"]] + } + else: + unique_branch[branch["label"]]["target"].append(branch["target"]) + + # Add NOP nodes and connect them to downstream nodes + for label, branch_info in unique_branch.items(): + self.graph.add_node( + branch_info["node"]["name"], + make_branch_node( + branch_info["node"]["name"], + branch_info["target"] + ) + ) + for target in branch_info["target"]: + waiting_edges[target].append(branch_info["node"]["name"]) + + def router_fn(state: WorkflowState) -> list[Send]: + branch_activate = [] + new_state = state.copy() + new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate + + for label, branch in unique_branch.items(): if evaluate_condition( - cond, + branch["condition"], state.get("variables", {}), state.get("runtime_vars", {}), { @@ -225,30 +316,45 @@ class GraphBuilder: "user_id": state.get("user_id") } ): - return tgt - return END + logger.debug(f"Conditional routing {src}: selected branch {label}") + new_state["activate"][branch["node"]["name"]] = True + continue + new_state["activate"][branch["node"]["name"]] = False + for label, branch in unique_branch.items(): + branch_activate.append( + Send( + branch['node']['name'], + new_state + ) + ) + return branch_activate - # 动态修改函数名,避免重复 - router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{tgt}" - return router_fn + # Dynamically set function name + router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{src}" + return router_fn - router_fn = make_router(condition, target) - self.graph.add_conditional_edges(source, router_fn) - logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") + router_fn = make_router(source_node, branches) + self.graph.add_conditional_edges(source_node, router_fn) + logger.debug(f"Added conditional edges: {source_node} -> {[b['target'] for b in branches]}") + + # Add normal/waiting edges + for target, sources in waiting_edges.items(): + if len(sources) == 1: + # Single source: normal edge + self.graph.add_edge(sources[0], target) + logger.debug(f"Added edge: {sources[0]} -> {target}") else: - # 普通边 - self.graph.add_edge(source, target) - logger.debug(f"添加边: {source} -> {target}") + # Multiple sources: waiting edge + self.graph.add_edge(sources, target) + logger.debug(f"Added waiting edge: {sources} -> {target}") - # 从 end 节点连接到 END + # Connect End nodes to the global END node for end_node_id in self.end_node_ids: self.graph.add_edge(end_node_id, END) - logger.debug(f"添加边: {end_node_id} -> END") + logger.debug(f"Added edge: {end_node_id} -> END") return def build(self) -> CompiledStateGraph: - self.graph = StateGraph(WorkflowState) - self.add_nodes() - self.add_edges() # 添加边必须在添加节点之后 checkpointer = InMemorySaver() - return self.graph.compile(checkpointer=checkpointer) + self.graph = self.graph.compile(checkpointer=checkpointer) + return self.graph diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index 96f68ce8..6f2583b4 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -14,6 +14,7 @@ logger = logging.getLogger(__name__) class AssignerNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) + self.variable_updater = True self.typed_config: AssignerNodeConfig | None = None async def execute(self, state: WorkflowState) -> Any: diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index b31213d8..0c015c89 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -7,18 +7,26 @@ import asyncio import logging from abc import ABC, abstractmethod -from typing import Any +from typing import Any, AsyncGenerator from langchain_core.messages import AIMessage from langgraph.config import get_stream_writer from typing_extensions import TypedDict, Annotated from app.core.config import settings +from app.core.workflow.nodes.enums import BRANCH_NODES from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) +def merget_activate_state(x, y): + return { + k: x.get(k, False) or y.get(k, False) + for k in set(x) | set(y) + } + + class WorkflowState(TypedDict): """Workflow state @@ -60,6 +68,9 @@ class WorkflowState(TypedDict): # Format: {node_id: {"chunks": [...], "full_content": "..."}} streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}] + # node activate status + activate: Annotated[dict[str, bool], merget_activate_state] + class BaseNode(ABC): """节点基类 @@ -84,6 +95,47 @@ class BaseNode(ABC): self.config = node_config.get("config") or {} self.error_handling = node_config.get("error_handling") or {} + self.variable_updater = False + + def check_activate(self, state: WorkflowState): + """Check if the current node is activated in the workflow state. + + Args: + state (WorkflowState): The current workflow state containing the 'activate' dict. + + Returns: + bool: True if the node is activated, False otherwise. + """ + return state["activate"][self.node_id] + + def trans_activate(self, state: WorkflowState): + """Transform the activation state for downstream nodes. + + This method collects all downstream nodes (excluding branch nodes) + connected to the current node and returns a dict indicating whether + each of these nodes should be activated based on the current node's state. + The current node itself is also included in the returned activation dict. + + Args: + state (WorkflowState): The current workflow state. + + Returns: + dict: A dict with a single key 'activate', mapping node IDs to + their activation status (True/False). + """ + edges = self.workflow_config.get("edges") + under_stream_nodes = [ + edge.get("target") + for edge in edges + if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES + ] + return { + "activate": { + node_id: self.check_activate(state) + for node_id in under_stream_nodes + } | {self.node_id: self.check_activate(state)} + } + @abstractmethod async def execute(self, state: WorkflowState) -> Any: """执行节点业务逻辑(非流式) @@ -99,13 +151,13 @@ class BaseNode(ABC): Examples: >>> # LLM 节点 - >>> return "这是 AI 的回复" + >>> "这是 AI 的回复" >>> # Transform 节点 - >>> return {"processed_data": [...]} + >>> {"processed_data": [...]} >>> # Start/End 节点 - >>> return {"message": "开始", "conversation_id": "xxx"} + >>> {"message": "开始", "conversation_id": "xxx"} """ pass @@ -126,14 +178,14 @@ class BaseNode(ABC): 业务数据(chunk)或完成标记 Examples: - >>> # 流式 LLM 节点 - >>> full_response = "" - >>> async for chunk in llm.astream(prompt): - ... full_response += chunk - ... yield chunk # yield 文本片段 - >>> - >>> # 最后 yield 完成标记 - >>> yield {"__final__": True, "result": AIMessage(content=full_response)} + # 流式 LLM 节点 + full_response = "" + async for chunk in llm.astream(prompt): + full_response += chunk + yield chunk # yield 文本片段 + + # 最后 yield 完成标记 + yield {"__final__": True, "result": AIMessage(content=full_response)} """ result = await self.execute(state) # 默认实现:直接 yield 完成标记 @@ -146,7 +198,7 @@ class BaseNode(ABC): 是否支持流式输出 """ # 检查子类是否重写了 execute_stream 方法 - return self.execute_stream.__func__ != BaseNode.execute_stream.__func__ + return self.__class__.execute_stream is not BaseNode.execute_stream def get_timeout(self) -> int: """获取超时时间(秒) @@ -172,6 +224,9 @@ class BaseNode(ABC): Returns: 标准化的状态更新字典 """ + if not self.check_activate(state): + return self.trans_activate(state) + import time start_time = time.time() @@ -204,12 +259,11 @@ class BaseNode(ABC): return { **wrapped_output, "messages": state["messages"], - "variables": state["variables"], "runtime_vars": { self.node_id: runtime_var }, "looping": state["looping"] - } + } | self.trans_activate(state) except TimeoutError: elapsed_time = time.time() - start_time @@ -220,7 +274,7 @@ class BaseNode(ABC): logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) return self._wrap_error(str(e), elapsed_time, state) - async def run_stream(self, state: WorkflowState): + async def run_stream(self, state: WorkflowState) -> AsyncGenerator[dict[str, Any], Any]: """Execute node with error handling and output wrapping (streaming) This method is called by the Executor and is responsible for: @@ -241,6 +295,11 @@ class BaseNode(ABC): Yields: State updates with streaming buffer and final result """ + if not self.check_activate(state): + yield self.trans_activate(state) + logger.info(f"跳过节点{self.node_id}") + return + import time start_time = time.time() @@ -358,7 +417,6 @@ class BaseNode(ABC): state_update = { **final_output, "messages": state["messages"], - "variables": state["variables"], "runtime_vars": { self.node_id: runtime_var }, @@ -377,7 +435,7 @@ class BaseNode(ABC): # Finally yield state update # LangGraph will merge this into state - yield state_update + yield state_update | self.trans_activate(state) except TimeoutError: elapsed_time = time.time() - start_time @@ -427,12 +485,13 @@ class BaseNode(ABC): "token_usage": token_usage, "error": None } - - return { - "node_outputs": { - self.node_id: node_output - } + final_output = { + "node_outputs": {self.node_id: node_output}, } + if self.variable_updater: + final_output = final_output | {"variables": state["variables"]} + + return final_output def _wrap_error( self, diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index c294bb11..aaf49a11 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -26,6 +26,9 @@ class NodeType(StrEnum): MEMORY_WRITE = "memory-write" +BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER] + + class ComparisonOperator(StrEnum): EMPTY = "empty" NOT_EMPTY = "not_empty" diff --git a/api/app/schemas/emotion_schema.py b/api/app/schemas/emotion_schema.py index 5175fed1..cfa65b0f 100644 --- a/api/app/schemas/emotion_schema.py +++ b/api/app/schemas/emotion_schema.py @@ -11,6 +11,7 @@ class EmotionTagsRequest(BaseModel): start_date: Optional[str] = Field(None, description="开始日期(ISO格式,如:2024-01-01)") end_date: Optional[str] = Field(None, description="结束日期(ISO格式,如:2024-12-31)") limit: int = Field(10, ge=1, le=100, description="返回数量限制") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") class EmotionWordcloudRequest(BaseModel): @@ -18,20 +19,24 @@ class EmotionWordcloudRequest(BaseModel): group_id: str = Field(..., description="组ID") emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)") limit: int = Field(50, ge=1, le=200, description="返回词语数量") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") class EmotionHealthRequest(BaseModel): """获取情绪健康指数请求""" group_id: str = Field(..., description="组ID") time_range: str = Field("30d", description="时间范围(7d/30d/90d)") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") class EmotionSuggestionsRequest(BaseModel): """获取个性化情绪建议请求""" group_id: str = Field(..., description="组ID") config_id: Optional[int] = Field(None, description="配置ID(用于指定LLM模型)") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") class EmotionGenerateSuggestionsRequest(BaseModel): """生成个性化情绪建议请求""" end_user_id: str = Field(..., description="终端用户ID") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") diff --git a/api/app/schemas/end_user_schema.py b/api/app/schemas/end_user_schema.py index c9f9146d..6f7498a0 100644 --- a/api/app/schemas/end_user_schema.py +++ b/api/app/schemas/end_user_schema.py @@ -44,6 +44,7 @@ class EndUserProfileResponse(BaseModel): updatetime_profile: Optional[datetime.datetime] = Field(description="核心档案信息最后更新时间", default=None) + class EndUserProfileUpdate(BaseModel): """终端用户基本信息更新请求模型""" end_user_id: str = Field(description="终端用户ID") diff --git a/api/app/schemas/memory_episodic_schema.py b/api/app/schemas/memory_episodic_schema.py index 832bf34b..74e68837 100644 --- a/api/app/schemas/memory_episodic_schema.py +++ b/api/app/schemas/memory_episodic_schema.py @@ -51,6 +51,7 @@ class EpisodicMemoryOverviewRequest(BaseModel): """情景记忆总览查询请求""" end_user_id: str = Field(..., description="终端用户ID") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") time_range: str = Field( default="all", description="时间范围筛选,可选值:all, today, this_week, this_month" @@ -70,3 +71,4 @@ class EpisodicMemoryDetailsRequest(BaseModel): end_user_id: str = Field(..., description="终端用户ID") summary_id: str = Field(..., description="情景记忆摘要ID") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") diff --git a/api/app/schemas/memory_explicit_schema.py b/api/app/schemas/memory_explicit_schema.py index c2b51a81..823a3116 100644 --- a/api/app/schemas/memory_explicit_schema.py +++ b/api/app/schemas/memory_explicit_schema.py @@ -1,15 +1,19 @@ """ 显性记忆的请求和响应模型 """ +from typing import Optional + from pydantic import BaseModel, Field class ExplicitMemoryOverviewRequest(BaseModel): """显性记忆总览查询请求""" end_user_id: str = Field(..., description="终端用户ID") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") class ExplicitMemoryDetailsRequest(BaseModel): """显性记忆详情查询请求""" end_user_id: str = Field(..., description="终端用户ID") memory_id: str = Field(..., description="记忆ID(情景记忆或语义记忆的ID)") + language_type: Optional[str] = Field("zh", description="语言类型(zh/en)") diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 2ac9ac05..68acab1d 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1445,7 +1445,7 @@ class AppService: target_workspace_ids: List[uuid.UUID], user_id: uuid.UUID, workspace_id: Optional[uuid.UUID] = None - ) -> AppShare: + ) -> list[AppShare]: """分享应用到其他工作空间 Args: diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index c9230a26..fd0cb0eb 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -26,6 +26,7 @@ from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import ConfigurationError +from app.services.memory_base_service import Translation_English from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, @@ -692,7 +693,9 @@ class MemoryAgentService: async def get_hot_memory_tags_by_user( self, end_user_id: Optional[str] = None, - limit: int = 20 + limit: int = 20, + model_id: Optional[str] = None, + language_type: Optional[str] = "zh" ) -> List[Dict[str, Any]]: """ 获取指定用户的热门记忆标签 @@ -710,7 +713,13 @@ class MemoryAgentService: try: # by_user=False 表示按 group_id 查询(在Neo4j中,group_id就是用户维度) tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False) - payload = [{"name": t, "frequency": f} for t, f in tags] + payload=[] + for tag, freq in tags: + if language_type!="zh": + tag=await Translation_English(model_id, tag) + payload.append({"name": tag, "frequency": freq}) + else: + payload.append({"name": tag, "frequency": freq}) return payload except Exception as e: logger.error(f"热门记忆标签查询失败: {e}") diff --git a/api/app/services/memory_base_service.py b/api/app/services/memory_base_service.py index 6f844ae9..25a8281d 100644 --- a/api/app/services/memory_base_service.py +++ b/api/app/services/memory_base_service.py @@ -3,17 +3,268 @@ Memory Base Service 提供记忆服务的基础功能和共享辅助方法。 """ - +import asyncio +import re from datetime import datetime from typing import Optional - +from pydantic import BaseModel from app.core.logging_config import get_logger from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.services.emotion_analytics_service import EmotionAnalyticsService - +from app.core.memory.llm_tools.openai_client import OpenAIClient +from app.core.models.base import RedBearModelConfig +from app.services.memory_config_service import MemoryConfigService +from app.db import get_db_context logger = get_logger(__name__) +class TranslationResponse(BaseModel): + """翻译响应模型""" + data: str + +class MemoryTransService: + """记忆翻译服务,提供中英文翻译功能""" + + def __init__(self, llm_client=None, model_id: Optional[str] = None): + """ + 初始化翻译服务 + + Args: + llm_client: LLM客户端实例或模型ID字符串(可选) + model_id: 模型ID,用于初始化LLM客户端(可选) + + Note: + - 如果llm_client是字符串,会被当作model_id使用 + - 如果同时提供llm_client和model_id,优先使用llm_client + """ + # 处理llm_client参数:如果是字符串,当作model_id + if isinstance(llm_client, str): + self.model_id = llm_client + self.llm_client = None + else: + self.llm_client = llm_client + self.model_id = model_id + + self._initialized = False + + def _ensure_llm_client(self): + """确保LLM客户端已初始化""" + if self._initialized: + return + + if self.llm_client is None: + if self.model_id: + with get_db_context() as db: + config_service = MemoryConfigService(db) + model_config = config_service.get_model_config(self.model_id) + + extra_params = { + "temperature": 0.2, + "max_tokens": 400, + "top_p": 0.8, + "stream": False, + } + + self.llm_client = OpenAIClient( + RedBearModelConfig( + model_name=model_config.get("model_name"), + provider=model_config.get("provider"), + api_key=model_config.get("api_key"), + base_url=model_config.get("base_url"), + timeout=model_config.get("timeout", 30), + max_retries=model_config.get("max_retries", 3), + extra_params=extra_params + ), + type_=model_config.get("type") + ) + else: + raise ValueError("必须提供 llm_client 或 model_id 之一") + + self._initialized = True + + async def translate_to_english(self, text: str) -> str: + """ + 将中文翻译为英文 + + Args: + text: 要翻译的中文文本 + + Returns: + 翻译后的英文文本 + """ + self._ensure_llm_client() + + translation_messages = [ + { + "role": "user", + "content": f"{text}\n\n中文翻译为英文,输出格式为{{\"data\":\"翻译后的内容\"}}" + } + ] + + try: + response = await self.llm_client.response_structured( + messages=translation_messages, + response_model=TranslationResponse + ) + return response.data + except Exception as e: + logger.error(f"翻译失败: {str(e)}") + return text # 翻译失败时返回原文 + + async def is_english(self, text: str) -> bool: + """ + 检查文本是否为英文 + + Args: + text: 要检查的文本(必须是字符串) + + Returns: + True 如果文本主要是英文,False 否则 + + Note: + - 只接受字符串类型 + - 检查是否主要由英文字母和常见标点组成 + - 允许数字、空格和常见标点符号 + """ + if not isinstance(text, str): + raise TypeError(f"is_english 只接受字符串类型,收到: {type(text).__name__}") + + if not text.strip(): + return True # 空字符串视为英文 + + # 更宽松的英文检查:允许字母、数字、空格和常见标点 + # 如果文本中英文字符占比超过 80%,认为是英文 + english_chars = sum(1 for c in text if c.isascii() and (c.isalnum() or c.isspace() or c in '.,!?;:\'"()-')) + total_chars = len(text) + + if total_chars == 0: + return True + + return (english_chars / total_chars) >= 0.8 + async def Translate(self, text: str, target_language: str = "en") -> str: + """ + 通用翻译方法(保持向后兼容) + + Args: + text: 要翻译的文本 + target_language: 目标语言,"en"表示英文,"zh"表示中文 + + Returns: + 翻译后的文本 + """ + if target_language == "en": + return await self.translate_to_english(text) + else: + logger.warning(f"不支持的目标语言: {target_language},返回原文") + return text + + # 测试翻译服务 +async def Translation_English(modid, text, fields=None): + """ + 将数据翻译为英文(支持字段级翻译) + + Args: + modid: 模型ID + text: 要翻译的数据(可以是字符串、字典或列表) + fields: 需要翻译的字段列表(可选) + 如果为None,默认翻译: ['content', 'summary', 'statement', 'description', + 'name', 'aliases', 'caption', 'emotion_keywords'] + + Returns: + 翻译后的数据,保持原有结构 + + Note: + - 对于字符串:直接翻译 + - 对于列表:递归处理每个元素,保持列表长度和索引不变 + - 对于字典:只翻译指定字段(fields参数) + - 对于其他类型:原样返回 + """ + trans_service = MemoryTransService(modid) + + # 处理字符串类型 + if isinstance(text, str): + # 空字符串直接返回 + if not text.strip(): + return text + + try: + is_eng = await trans_service.is_english(text) + if not is_eng: + english_result = await trans_service.Translate(text) + return english_result + return text + except Exception as e: + logger.warning(f"翻译字符串失败: {e}") + return text + + # 处理列表类型 + elif isinstance(text, list): + english_result = [] + for item in text: + # 递归处理列表中的每个元素 + if isinstance(item, str): + # 字符串元素:检查是否需要翻译 + if not item.strip(): + english_result.append(item) + continue + + try: + is_eng = await trans_service.is_english(item) + if not is_eng: + translated = await trans_service.Translate(item) + english_result.append(translated) + else: + # 保留英文项,不改变列表长度 + english_result.append(item) + except Exception as e: + logger.warning(f"翻译列表项失败: {e}") + english_result.append(item) + + elif isinstance(item, dict): + # 字典元素:递归调用自己处理字典 + translated_dict = await Translation_English(modid, item, fields) + english_result.append(translated_dict) + + elif isinstance(item, list): + # 嵌套列表:递归处理 + translated_list = await Translation_English(modid, item, fields) + english_result.append(translated_list) + + else: + # 其他类型(数字、布尔值等):原样保留 + english_result.append(item) + + return english_result + + # 处理字典类型 + elif isinstance(text, dict): + # 确定要翻译的字段 + if fields is None: + # 默认翻译字段 + fields = [ + 'content', 'summary', 'statement', 'description', + 'name', 'aliases', 'caption', 'emotion_keywords', + 'text', 'title', 'label', 'type' # 添加常用字段 + ] + + # 创建副本,避免修改原始数据 + result = text.copy() + + for field in fields: + if field in result and result[field] is not None: + # 递归翻译字段值(可能是字符串、列表或嵌套字典) + try: + result[field] = await Translation_English(modid, result[field], fields) + except Exception as e: + logger.warning(f"翻译字段 {field} 失败: {e}") + # 翻译失败时保留原值 + continue + + return result + + # 其他类型(数字、布尔值、None等):原样返回 + else: + return text class MemoryBaseService: """记忆服务基类,提供共享的辅助方法""" @@ -294,4 +545,4 @@ class MemoryBaseService: except Exception as e: logger.error(f"获取遗忘记忆数量时出错: {str(e)}", exc_info=True) - return 0 + return 0 \ No newline at end of file diff --git a/api/app/services/memory_entity_relationship_service.py b/api/app/services/memory_entity_relationship_service.py index eedb7c29..9b5f3c99 100644 --- a/api/app/services/memory_entity_relationship_service.py +++ b/api/app/services/memory_entity_relationship_service.py @@ -16,6 +16,7 @@ import json from datetime import datetime from app.schemas.memory_episodic_schema import EmotionType +from app.services.memory_base_service import Translation_English logger = logging.getLogger(__name__) @@ -24,7 +25,7 @@ class MemoryEntityService: self.id = id self.table = table self.connector = Neo4jConnector() - async def get_timeline_memories_server(self): + async def get_timeline_memories_server(self,model_id, language_type): """ 获取时间线记忆数据 @@ -48,10 +49,10 @@ class MemoryEntityService: logger.info(f"获取时间线记忆数据 - ID: {self.id}, Table: {self.table}") # 根据表类型选择查询 - if self.table == 'Statement': + if self.table == 'Statement': # Statement只需要输入ID,使用简化查询 results = await self.connector.execute_query(Memory_Timeline_Statement, id=self.id) - elif self.table == 'ExtractedEntity': + elif self.table == 'ExtractedEntity': # ExtractedEntity类型查询 results = await self.connector.execute_query(Memory_Timeline_ExtractedEntity, id=self.id) else: @@ -62,7 +63,7 @@ class MemoryEntityService: logger.info(f"时间线查询结果类型: {type(results)}, 长度: {len(results) if isinstance(results, list) else 'N/A'}") # 处理查询结果 - timeline_data = self._process_timeline_results(results) + timeline_data =await self._process_timeline_results(results, model_id, language_type) logger.info(f"成功获取时间线记忆数据: 总计 {len(timeline_data.get('timelines_memory', []))} 条") @@ -71,12 +72,14 @@ class MemoryEntityService: except Exception as e: logger.error(f"获取时间线记忆数据失败: {str(e)}", exc_info=True) return str(e) - def _process_timeline_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]: + async def _process_timeline_results(self, results: List[Dict[str, Any]], model_id: str, language_type: str) -> Dict[str, Any]: """ 处理时间线查询结果 Args: results: Neo4j查询结果 + model_id: 模型ID用于翻译 + language_type: 语言类型 ('zh' 或其他) Returns: 处理后的时间线数据字典 @@ -104,19 +107,19 @@ class MemoryEntityService: # 处理MemorySummary summary = data.get('MemorySummary') if summary is not None: - processed_summary = self._process_field_value(summary, "MemorySummary") + processed_summary = await self._process_field_value(summary, "MemorySummary") memory_summary_list.extend(processed_summary) # 处理Statement statement = data.get('statement') if statement is not None: - processed_statement = self._process_field_value(statement, "Statement") + processed_statement = await self._process_field_value(statement, "Statement") statement_list.extend(processed_statement) # 处理ExtractedEntity extracted_entity = data.get('ExtractedEntity') if extracted_entity is not None: - processed_entity = self._process_field_value(extracted_entity, "ExtractedEntity") + processed_entity = await self._process_field_value(extracted_entity, "ExtractedEntity") extracted_entity_list.extend(processed_entity) # 去重 - 现在处理的是字典列表,需要更智能的去重 @@ -128,6 +131,8 @@ class MemoryEntityService: all_timeline_data = memory_summary_list + statement_list all_timeline_data = self._merge_same_text_items(all_timeline_data) + # 如果需要翻译(非中文),对整个结果进行翻译 + result = { "MemorySummary": memory_summary_list, "Statement": statement_list, @@ -233,7 +238,7 @@ class MemoryEntityService: except Exception: return False - def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]: + async def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]: """ 处理字段值,支持字符串、列表等类型 @@ -251,13 +256,13 @@ class MemoryEntityService: # 如果是列表,处理每个元素 for item in value: if self._is_valid_item(item): - processed_item = self._process_single_item(item) + processed_item = await self._process_single_item(item) if processed_item: processed_values.append(processed_item) elif isinstance(value, dict): # 如果是字典,直接处理 if self._is_valid_item(value): - processed_item = self._process_single_item(value) + processed_item = await self._process_single_item(value) if processed_item: processed_values.append(processed_item) elif isinstance(value, str): @@ -304,7 +309,7 @@ class MemoryEntityService: return (str(item).strip() != '' and "MemorySummaryChunk" not in str(item)) - def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]: + async def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]: """ 处理单个项目 @@ -369,6 +374,117 @@ class MemoryEntityService: logger.warning(f"转换时间格式失败: {e}, 原始值: {dt}") return str(dt) if dt is not None else None + async def _translate_list( + self, + data_list: List[Dict[str, Any]], + model_id: str, + fields: List[str] + ) -> List[Dict[str, Any]]: + """ + 翻译列表中每个字典的指定字段(并发有限度以降低整体延迟) + + Args: + data_list: 要翻译的字典列表 + model_id: 模型ID + fields: 需要翻译的字段列表 + + Returns: + 翻译后的字典列表 + """ + # 空列表或无字段时直接返回 + if not data_list or not fields: + return data_list + + import asyncio + + # 并发限制,避免一次性发起过多请求 + # 可根据实际情况调整(建议 5-10) + concurrency_limit = 5 + semaphore = asyncio.Semaphore(concurrency_limit) + + async def translate_single_field( + index: int, + field: str, + value: Any, + ) -> Optional[tuple]: + """ + 翻译单个字段并返回 (索引, 字段名, 翻译结果) + + Returns: + (index, field, translated_value) 或 None(如果跳过) + """ + # 跳过空值 + if value is None or value == "": + return None + + # 统一转成字符串再翻译,防止非字符串类型导致错误 + text = str(value) + + try: + async with semaphore: + # 调用 Translation_English 进行翻译 + # 注意:Translation_English 的参数顺序是 (model_id, text) + translated = await Translation_English(model_id, text) + + # 如果翻译结果为空,保留原值 + if translated is None or translated == "": + return None + + return index, field, translated + except Exception as e: + logger.warning(f"翻译字段 {field} (索引 {index}) 失败: {e}") + return None + + # 构造所有需要翻译的任务 + tasks = [] + for idx, item in enumerate(data_list): + # 防御性检查:确保 item 是字典 + if not isinstance(item, dict): + continue + + for field in fields: + if field not in item: + continue + + value = item.get(field) + + # 对于 None 或空字符串的值,直接跳过,不创建任务 + if value is None or value == "": + continue + + tasks.append( + asyncio.create_task( + translate_single_field(idx, field, value) + ) + ) + + # 如果没有需要翻译的任务,直接返回原列表 + if not tasks: + return data_list + + # 使用 gather 并发执行翻译任务(受 semaphore 限制) + # return_exceptions=True 可以防止单个任务失败导致整体失败 + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 创建深拷贝以避免修改原始数据 + translated_list = [item.copy() if isinstance(item, dict) else item for item in data_list] + + # 将翻译结果回填到列表 + for result in results: + # 跳过 None 结果和异常 + if result is None or isinstance(result, Exception): + if isinstance(result, Exception): + logger.warning(f"翻译任务异常: {result}") + continue + + idx, field, translated = result + + # 防御性检查索引范围 + if 0 <= idx < len(translated_list) and isinstance(translated_list[idx], dict): + translated_list[idx][field] = translated + + return translated_list + @@ -426,15 +542,19 @@ class MemoryEmotion: # 如果解析失败,返回原始字符串 return iso_string - async def get_emotion(self) -> Dict[str, Any]: + async def get_emotion(self, model_id: str = None, language_type: str = 'zh') -> Dict[str, Any]: """ 获取情绪随时间变化数据 + Args: + model_id: 模型ID用于翻译 + language_type: 语言类型 ('zh' 或其他) + Returns: 包含情绪数据的字典 """ try: - logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}") + logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}, language_type={language_type}") if self.table == 'Statement': results = await self.connector.execute_query(Memory_Space_Emotion_Statement, id=self.id) @@ -450,6 +570,10 @@ class MemoryEmotion: # 转换Neo4j类型 final_data = self._convert_neo4j_types(emotion_data) + # 如果需要翻译(非中文) + if language_type != 'zh' and model_id and final_data: + final_data = await self._translate_emotion_data(final_data, model_id) + logger.info(f"成功获取 {len(final_data)} 条情绪数据") return final_data @@ -590,16 +714,14 @@ class MemoryInteraction: """ try: logger.info(f"获取交互数据 - ID: {self.id}, Table: {self.table}") - ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id) if ori_data!=[]: # name = ori_data[0]['name'] - group_id = ori_data[0]['group_id'] + group_id = [i['group_id'] for i in ori_data][0] Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id) if not Space_User: return [] user_id=Space_User[0]['id'] - results = await self.connector.execute_query(Memory_Space_Associative, id=self.id,user_id=user_id) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 9221ab06..ae07256a 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -18,7 +18,7 @@ from app.repositories.end_user_repository import EndUserRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping from app.services.implicit_memory_service import ImplicitMemoryService -from app.services.memory_base_service import MemoryBaseService +from app.services.memory_base_service import MemoryBaseService, MemoryTransService, Translation_English from app.services.memory_config_service import MemoryConfigService from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_short_service import ShortService @@ -360,7 +360,9 @@ class UserMemoryService: async def get_cached_memory_insight( self, db: Session, - end_user_id: str + end_user_id: str, + model_id: str, + language_type: str ) -> Dict[str, Any]: """ 从数据库获取缓存的记忆洞察(四个维度) @@ -419,11 +421,18 @@ class UserMemoryService: key_findings_array = [] logger.info(f"成功获取 end_user_id {end_user_id} 的缓存记忆洞察(四维度)") + memory_insight=end_user.memory_insight + behavior_pattern=end_user.behavior_pattern + growth_trajectory=end_user.growth_trajectory + if language_type!='zh': + memory_insight=await Translation_English(model_id,memory_insight) + behavior_pattern=await Translation_English(model_id,behavior_pattern) + growth_trajectory=await Translation_English(model_id,growth_trajectory) return { - "memory_insight": end_user.memory_insight, # 总体概述存储在 memory_insight - "behavior_pattern": end_user.behavior_pattern, + "memory_insight":memory_insight, # 总体概述存储在 memory_insight + "behavior_pattern":behavior_pattern, "key_findings": key_findings_array, # 返回数组 - "growth_trajectory": end_user.growth_trajectory, + "growth_trajectory": growth_trajectory, "updated_at": self._datetime_to_timestamp(end_user.memory_insight_updated_at), "is_cached": True } @@ -457,7 +466,9 @@ class UserMemoryService: async def get_cached_user_summary( self, db: Session, - end_user_id: str + end_user_id: str, + model_id:str, + language_type:str="zh" ) -> Dict[str, Any]: """ 从数据库获取缓存的用户摘要(四个部分) @@ -481,7 +492,6 @@ class UserMemoryService: user_uuid = uuid.UUID(end_user_id) repo = EndUserRepository(db) end_user = repo.get_by_id(user_uuid) - if not end_user: logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") return { @@ -495,20 +505,29 @@ class UserMemoryService: } # 检查是否有缓存数据(至少有一个字段不为空) + user_summary=end_user.user_summary + personality_traits=end_user.personality_traits + core_values=end_user.core_values + one_sentence_summary=end_user.one_sentence_summary + if language_type!='zh': + user_summary=await Translation_English(model_id, user_summary) + personality_traits = await Translation_English(model_id, personality_traits) + core_values = await Translation_English(model_id, core_values) + one_sentence_summary = await Translation_English(model_id, one_sentence_summary) has_cache = any([ - end_user.user_summary, - end_user.personality_traits, - end_user.core_values, - end_user.one_sentence_summary + user_summary, + personality_traits, + core_values, + one_sentence_summary ]) if has_cache: logger.info(f"成功获取 end_user_id {end_user_id} 的缓存用户摘要") return { - "user_summary": end_user.user_summary, - "personality": end_user.personality_traits, - "core_values": end_user.core_values, - "one_sentence": end_user.one_sentence_summary, + "user_summary": user_summary, + "personality": personality_traits, + "core_values":core_values, + "one_sentence": one_sentence_summary, "updated_at": self._datetime_to_timestamp(end_user.user_summary_updated_at), "is_cached": True } @@ -1367,7 +1386,6 @@ async def analytics_memory_types( return memory_types - async def analytics_graph_data( db: Session, end_user_id: str, @@ -1557,7 +1575,7 @@ async def analytics_graph_data( f"成功获取图数据: end_user_id={end_user_id}, " f"nodes={len(nodes)}, edges={len(edges)}" ) - + return { "nodes": nodes, "edges": edges, @@ -1606,11 +1624,7 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_ # 获取该节点类型的白名单字段 allowed_fields = field_whitelist.get(label, []) - - # 如果没有定义白名单,返回空字典(或者可以返回所有字段) - # if not allowed_fields: - # # 对于未定义的节点类型,只返回基本字段 - # allowed_fields = ["name", "created_at", "caption"] + count_neo4j=f"""MATCH (n)-[r]-(m) WHERE elementId(n) ="{node_id}" RETURN count(r) AS rel_count;""" node_results = await (_neo4j_connector.execute_query(count_neo4j)) # 提取白名单中的字段 @@ -1618,13 +1632,12 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_ for field in allowed_fields: if field in properties: value = properties[field] - if str(field) == 'entity_type': + if str(field) == 'entity_type': value=type_mapping.get(value,'') if str(field)=="emotion_type": value=EmotionType.EMOTION_MAPPING.get(value) - if str(field)=="emotion_subject": + if str(field)=="emotion_subject": value=EmotionSubject.SUBJECT_MAPPING.get(value) - # 清理 Neo4j 特殊类型 filtered_props[field] = _clean_neo4j_value(value) filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0] return filtered_props