Merge remote-tracking branch 'origin/develop' into develop

This commit is contained in:
lixinyue
2026-01-21 11:53:25 +08:00
18 changed files with 807 additions and 239 deletions

View File

@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.core.response_utils import success, fail
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models import User
@@ -661,6 +661,11 @@ async def draft_run(
data=result,
msg="工作流任务执行成功"
)
else:
return fail(
msg="未知应用类型",
code=422
)
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")

View File

@@ -9,7 +9,7 @@ from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey
from app.models.user_model import User
from app.repositories import knowledge_repository
from app.repositories import knowledge_repository, WorkspaceRepository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse
from app.services import task_service, workspace_service
@@ -616,8 +616,10 @@ async def get_knowledge_type_stats_api(
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
async def get_hot_memory_tags_by_user_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
language_type: Optional[str] ="zh",
limit: int = Query(20, description="返回标签数量限制"),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
db: Session=Depends(get_db),
):
"""
获取指定用户的热门记忆标签
@@ -628,10 +630,22 @@ async def get_hot_memory_tags_by_user_api(
...
]
"""
workspace_id=current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
try:
result = await memory_agent_service.get_hot_memory_tags_by_user(
end_user_id=end_user_id,
language_type=language_type,
model_id=model_id,
limit=limit
)
return success(data=result, msg="获取热门记忆标签成功")

View File

@@ -20,6 +20,7 @@ router = APIRouter(
@router.get("/short_term")
async def short_term_configs(
end_user_id: str,
language_type:Optional[str] = "zh",
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):

View File

@@ -12,6 +12,7 @@ from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.core.api_key_utils import timestamp_to_datetime
from app.services.memory_base_service import Translation_English
from app.services.user_memory_service import (
UserMemoryService,
analytics_memory_types,
@@ -20,7 +21,7 @@ from app.services.user_memory_service import (
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import GenerateCacheRequest
from app.repositories.workspace_repository import WorkspaceRepository
from app.schemas.end_user_schema import (
EndUserProfileResponse,
EndUserProfileUpdate,
@@ -44,6 +45,7 @@ router = APIRouter(
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
async def get_memory_insight_report_api(
end_user_id: str,
language_type: str = "zh",
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
@@ -53,10 +55,18 @@ async def get_memory_insight_report_api(
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
如需生成新的洞察报告,请使用专门的生成接口。
"""
workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# 调用服务层获取缓存数据
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
result = await user_memory_service.get_cached_memory_insight(db, end_user_id,model_id,language_type)
if result["is_cached"]:
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
@@ -72,6 +82,7 @@ async def get_memory_insight_report_api(
@router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api(
end_user_id: str,
language_type: str="zh",
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
@@ -81,10 +92,18 @@ async def get_user_summary_api(
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
如需生成新的用户摘要,请使用专门的生成接口。
"""
workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# 调用服务层获取缓存数据
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language_type)
if result["is_cached"]:
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
@@ -253,7 +272,6 @@ async def get_graph_data_api(
depth=depth,
center_node_id=center_node_id
)
# 检查是否有错误消息
if "message" in result and result["statistics"]["total_nodes"] == 0:
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}")
@@ -278,7 +296,13 @@ async def get_end_user_profile(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
@@ -296,7 +320,6 @@ async def get_end_user_profile(
if not end_user:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
# 构建响应数据
profile_data = EndUserProfileResponse(
id=end_user.id,
@@ -396,12 +419,21 @@ async def update_end_user_profile(
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
async def memory_space_timeline_of_shared_memories(id: str, label: str,
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str="zh",
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
workspace_id=current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
MemoryEntity = MemoryEntityService(id, label)
timeline_memories_result = await MemoryEntity.get_timeline_memories_server()
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language_type)
return success(data=timeline_memories_result, msg="共同记忆时间线")
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
async def memory_space_relationship_evolution(id: str, label: str,

View File

@@ -54,6 +54,8 @@ class WorkflowExecutor:
self.edges = workflow_config.get("edges", [])
self.execution_config = workflow_config.get("execution_config", {})
self.start_node_id = None
self.checkpoint_config = RunnableConfig(
configurable={
"thread_id": uuid.uuid4(),
@@ -131,77 +133,12 @@ class WorkflowExecutor:
for node in self.workflow_config.get("nodes")
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
], # loop, iteration node id
"looping": False # loop runing flag, only use in loop node,not use in main loop
"looping": False, # loop runing flag, only use in loop node,not use in main loop
"activate": {
self.start_node_id: True
}
}
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
"""分析 End 节点的前缀配置
检查每个 End 节点的模板,找到直接上游节点的引用,
提取该引用之前的前缀部分。
Returns:
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
"""
import re
prefixes = {}
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
# 找到所有 End 节点
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
for end_node in end_nodes:
end_node_id = end_node.get("id")
output_template = end_node.get("config", {}).get("output")
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
if not output_template:
continue
# 找到所有直接连接到 End 节点的上游节点
direct_upstream_nodes = []
for edge in self.edges:
if edge.get("target") == end_node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
# 查找模板中引用了哪些节点
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}'
matches = list(re.finditer(pattern, output_template))
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
# 找到第一个直接上游节点的引用
for match in matches:
referenced_node_id = match.group(1)
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
if referenced_node_id in direct_upstream_nodes:
# 这是直接上游节点的引用,提取前缀
prefix = output_template[:match.start()]
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
# 标记这个节点为"相邻且被引用"
adjacent_and_referenced.add(referenced_node_id)
if prefix:
prefixes[referenced_node_id] = prefix
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
# 只处理第一个直接上游节点的引用
break
logger.info(f"[前缀分析] 最终配置: {prefixes}")
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
return prefixes, adjacent_and_referenced
def _build_final_output(self, result, elapsed_time):
node_outputs = result.get("node_outputs", {})
final_output = self._extract_final_output(node_outputs)
@@ -231,10 +168,12 @@ class WorkflowExecutor:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
graph = GraphBuilder(
builder = GraphBuilder(
self.workflow_config,
stream=stream,
).build()
)
self.start_node_id = builder.start_node_id
graph = builder.build()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
return graph
@@ -375,13 +314,15 @@ class WorkflowExecutor:
payload = data.get("payload", {})
node_name = payload.get("name")
if node_name and node_name.startswith("nop"):
continue
if event_type == "task":
# Node starts execution
inputv = payload.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
if not inputv.get("activate", {}).get(node_name):
continue
conversation_id = input_data.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[NODE-START] Node starts execution: {node_name} "
f"- execution_id: {self.execution_id}")
@@ -390,18 +331,17 @@ class WorkflowExecutor:
"data": {
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
"execution_id": self.execution_id,
"timestamp": data.get("timestamp"),
}
}
elif event_type == "task_result":
# Node execution completed
result = payload.get("result", {})
inputv = result.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
if not result.get("activate", {}).get(node_name):
continue
conversation_id = input_data.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[NODE-END] Node execution completed: {node_name} "
f"- execution_id: {self.execution_id}")
@@ -410,7 +350,7 @@ class WorkflowExecutor:
"data": {
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"execution_id": self.execution_id,
"timestamp": data.get("timestamp"),
"state": result.get("node_outputs", {}).get(node_name),
}

View File

@@ -1,14 +1,16 @@
import logging
import uuid
from collections import defaultdict
from typing import Any
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.graph import START, END
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.types import Send
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
logger = logging.getLogger(__name__)
@@ -28,7 +30,10 @@ class GraphBuilder:
self.start_node_id = None
self.end_node_ids = []
self.graph: StateGraph | CompiledStateGraph | None = None
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
@property
def nodes(self) -> list[dict[str, Any]]:
@@ -39,74 +44,98 @@ class GraphBuilder:
return self.workflow_config.get("edges", [])
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
"""分析 End 节点的前缀配置
"""
Analyze the prefix configuration for End nodes.
检查每个 End 节点的模板,找到直接上游节点的引用,
提取该引用之前的前缀部分。
This function scans each End node's output template, identifies
references to its direct upstream nodes, and extracts the prefix
string appearing before the first reference.
Returns:
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
tuple:
- dict[str, str]: Mapping from upstream node ID to its End node prefix
- set[str]: Set of node IDs that are directly adjacent to End nodes and referenced
"""
import re
prefixes = {}
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
adjacent_and_referenced = set() # Record nodes directly adjacent to End and referenced
# 找到所有 End 节点
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[前缀分析] 找到 {len(end_nodes)} End 节点")
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
for end_node in end_nodes:
end_node_id = end_node.get("id")
output_template = end_node.get("config", {}).get("output")
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
logger.info(f"[Prefix Analysis] End node {end_node_id} template: {output_template}")
if not output_template:
continue
# 查找模板中引用了哪些节点
# 匹配 {{node_id.xxx}} {{ node_id.xxx }} 格式(支持空格)
# Find all node references in the template
# Matches {{node_id.xxx}} or {{ node_id.xxx }} format (allowing spaces)
pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
matches = list(re.finditer(pattern, output_template))
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
logger.info(f"[Prefix Analysis] 模板中找到 {len(matches)} 个节点引用")
# 找到所有直接连接到 End 节点的上游节点
# Identify all direct upstream nodes connected to the End node
direct_upstream_nodes = []
for edge in self.edges:
if edge.get("target") == end_node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
logger.info(f"[Prefix Analysis] Direct upstream nodes of End node: {direct_upstream_nodes}")
# 找到第一个直接上游节点的引用
for match in matches:
referenced_node_id = match.group(1)
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
logger.info(f"[Prefix Analysis] Checking reference: {referenced_node_id}")
if referenced_node_id in direct_upstream_nodes:
# 这是直接上游节点的引用,提取前缀
prefix = output_template[:match.start()]
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
logger.info(f"[Prefix Analysis] "
f"✅ Found reference to direct upstream node {referenced_node_id}, prefix: '{prefix}'")
# 标记这个节点为"相邻且被引用"
adjacent_and_referenced.add(referenced_node_id)
if prefix:
prefixes[referenced_node_id] = prefix
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
logger.info(f"[Prefix Analysis] "
f"✅ Assign prefix for node {referenced_node_id}: '{prefix[:50]}...'")
# 只处理第一个直接上游节点的引用
break
logger.info(f"[前缀分析] 最终配置: {prefixes}")
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
logger.info(f"[Prefix Analysis] Final prefixes: {prefixes}")
logger.info(f"[Prefix Analysis] Nodes adjacent to End and referenced: {adjacent_and_referenced}")
return prefixes, adjacent_and_referenced
def add_nodes(self):
"""Add all nodes from the workflow configuration to the state graph.
This method handles:
- Creation of node instances using NodeFactory.
- Special handling for start, end, and cycle nodes.
- Injection of End node prefixes for streaming mode.
- Marking nodes as adjacent to End nodes if referenced.
- Wrapping node run methods as async functions or async generators
depending on streaming mode.
Notes:
Loop nodes (nodes with `cycle` property) are handled separately
via CycleGraphNode when building subgraphs.
Returns:
None
"""
# Analyze End node prefixes if in stream mode
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())
for node in self.nodes:
@@ -114,21 +143,21 @@ class GraphBuilder:
node_id = node.get("id")
cycle_node = node.get("cycle")
if cycle_node:
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
# Nodes within a loop subgraph are constructed by CycleGraphNode
if not self.subgraph:
continue
# 记录 start end 节点 ID
# Record start and end node IDs
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
# 创建节点实例(现在 start end 也会被创建)
# Create node instance (start and end nodes are also created)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]:
if node_type in BRANCH_NODES:
# Find all edges whose source is the current node
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
@@ -142,26 +171,23 @@ class GraphBuilder:
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# 如果是流式模式,且节点有 End 前缀配置,注入配置
# Inject End node prefix configuration if in stream mode
if self.stream and node_id in end_prefixes:
# 将 End 前缀配置注入到节点实例
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
logger.info(f"Injected End prefix for node {node_id}")
# 如果是流式模式,标记节点是否与 End 相邻且被引用
# Mark nodes as adjacent and referenced to End node in stream mode
if self.stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
logger.info(f"Node {node_id} marked as adjacent and referenced to End node")
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
# Wrap node's run method to avoid closure issues
if self.stream:
# 流式模式:创建 async generator 函数
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
# Stream mode: create an async generator function
# LangGraph collects all yielded values; the last yielded dictionary is merged into the state
def make_stream_func(inst):
async def node_func(state: WorkflowState):
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
async for item in inst.run_stream(state):
yield item
@@ -169,7 +195,7 @@ class GraphBuilder:
self.graph.add_node(node_id, make_stream_func(node_instance))
else:
# 非流式模式:创建 async function
# Non-stream mode: create an async function
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
@@ -178,45 +204,110 @@ class GraphBuilder:
self.graph.add_node(node_id, make_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={self.stream})")
logger.debug(f"Added node: {node_id} (type={node_type}, stream={self.stream})")
def add_edges(self):
"""Add all edges (normal, waiting, and conditional) to the state graph.
This method handles:
- Connecting the START node to the workflow's start node.
- Collecting waiting edges for nodes with multiple sources.
- Collecting conditional edges for routing to NOP nodes.
- Adding NOP nodes for conditional branches to allow later merging.
- Wrapping routing logic in a router function that evaluates conditions.
- Connecting End nodes to the global END node.
Notes:
- NOP nodes are used to ensure that multiple branches can merge
correctly without modifying the workflow state.
- Waiting edges are automatically handled by LangGraph to schedule
nodes only after all sources are activated.
Returns:
None
"""
# Connect the START node to the workflow's start node
if self.start_node_id:
self.graph.add_edge(START, self.start_node_id)
logger.debug(f"添加边: START -> {self.start_node_id}")
logger.debug(f"Added edge: START -> {self.start_node_id}")
# Collect all sources for each target node for normal/waiting edges
waiting_edges = defaultdict(list)
# Collect all conditional edges for each source node to construct routing
conditional_edges = defaultdict(list)
for edge in self.edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
edge_type = edge.get("type")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == self.start_node_id:
# 但要连接 start 到下一个节点
self.graph.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# # 处理到 end 节点的边
# if target in end_node_ids:
# # 连接到 end 节点
# workflow.add_edge(source, target)
# logger.debug(f"添加边: {source} -> {target}")
# continue
# 跳过错误边(在节点内部处理)
# Skip error edges (handled within nodes)
if edge_type == "error":
continue
if condition:
# 条件边
def make_router(cond, tgt):
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
# Conditional edges: group by source node
conditional_edges[source].append({
"target": target,
"condition": condition,
"label": edge.get("label")
})
else:
# Normal edges: group by target node (used for waiting edges)
waiting_edges[target].append(source)
def router_fn(state: WorkflowState):
# Add conditional edges
for source_node, branches in conditional_edges.items():
def make_router(src, branch_list):
"""reate a router function for each source node that routes to a NOP node for later merging."""
def make_branch_node(node_name, targets):
def node(s):
# NOTE: NOP NODE MUST NOT MODIFY STATE
return {
"activate": {
node_id: s["activate"][node_name]
for node_id in targets
}
}
return node
unique_branch = {}
for branch in branch_list:
if branch.get("label") not in unique_branch.keys():
nop_node_name = f"nop_{uuid.uuid4().hex[:8]}"
logger.info(f"Binding NOP: {source_node} {branch.get('label')} -> {nop_node_name}")
unique_branch[branch["label"]] = {
"condition": branch["condition"],
"node": {
"name": nop_node_name,
},
"target": [branch["target"]]
}
else:
unique_branch[branch["label"]]["target"].append(branch["target"])
# Add NOP nodes and connect them to downstream nodes
for label, branch_info in unique_branch.items():
self.graph.add_node(
branch_info["node"]["name"],
make_branch_node(
branch_info["node"]["name"],
branch_info["target"]
)
)
for target in branch_info["target"]:
waiting_edges[target].append(branch_info["node"]["name"])
def router_fn(state: WorkflowState) -> list[Send]:
branch_activate = []
new_state = state.copy()
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
for label, branch in unique_branch.items():
if evaluate_condition(
cond,
branch["condition"],
state.get("variables", {}),
state.get("runtime_vars", {}),
{
@@ -225,30 +316,45 @@ class GraphBuilder:
"user_id": state.get("user_id")
}
):
return tgt
return END
logger.debug(f"Conditional routing {src}: selected branch {label}")
new_state["activate"][branch["node"]["name"]] = True
continue
new_state["activate"][branch["node"]["name"]] = False
for label, branch in unique_branch.items():
branch_activate.append(
Send(
branch['node']['name'],
new_state
)
)
return branch_activate
# 动态修改函数名,避免重复
router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{tgt}"
return router_fn
# Dynamically set function name
router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{src}"
return router_fn
router_fn = make_router(condition, target)
self.graph.add_conditional_edges(source, router_fn)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
router_fn = make_router(source_node, branches)
self.graph.add_conditional_edges(source_node, router_fn)
logger.debug(f"Added conditional edges: {source_node} -> {[b['target'] for b in branches]}")
# Add normal/waiting edges
for target, sources in waiting_edges.items():
if len(sources) == 1:
# Single source: normal edge
self.graph.add_edge(sources[0], target)
logger.debug(f"Added edge: {sources[0]} -> {target}")
else:
# 普通边
self.graph.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# Multiple sources: waiting edge
self.graph.add_edge(sources, target)
logger.debug(f"Added waiting edge: {sources} -> {target}")
# 从 end 节点连接到 END
# Connect End nodes to the global END node
for end_node_id in self.end_node_ids:
self.graph.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
logger.debug(f"Added edge: {end_node_id} -> END")
return
def build(self) -> CompiledStateGraph:
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges() # 添加边必须在添加节点之后
checkpointer = InMemorySaver()
return self.graph.compile(checkpointer=checkpointer)
self.graph = self.graph.compile(checkpointer=checkpointer)
return self.graph

View File

@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:

View File

@@ -7,18 +7,26 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, AsyncGenerator
from langchain_core.messages import AIMessage
from langgraph.config import get_stream_writer
from typing_extensions import TypedDict, Annotated
from app.core.config import settings
from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
def merget_activate_state(x, y):
return {
k: x.get(k, False) or y.get(k, False)
for k in set(x) | set(y)
}
class WorkflowState(TypedDict):
"""Workflow state
@@ -60,6 +68,9 @@ class WorkflowState(TypedDict):
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# node activate status
activate: Annotated[dict[str, bool], merget_activate_state]
class BaseNode(ABC):
"""节点基类
@@ -84,6 +95,47 @@ class BaseNode(ABC):
self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {}
self.variable_updater = False
def check_activate(self, state: WorkflowState):
"""Check if the current node is activated in the workflow state.
Args:
state (WorkflowState): The current workflow state containing the 'activate' dict.
Returns:
bool: True if the node is activated, False otherwise.
"""
return state["activate"][self.node_id]
def trans_activate(self, state: WorkflowState):
"""Transform the activation state for downstream nodes.
This method collects all downstream nodes (excluding branch nodes)
connected to the current node and returns a dict indicating whether
each of these nodes should be activated based on the current node's state.
The current node itself is also included in the returned activation dict.
Args:
state (WorkflowState): The current workflow state.
Returns:
dict: A dict with a single key 'activate', mapping node IDs to
their activation status (True/False).
"""
edges = self.workflow_config.get("edges")
under_stream_nodes = [
edge.get("target")
for edge in edges
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES
]
return {
"activate": {
node_id: self.check_activate(state)
for node_id in under_stream_nodes
} | {self.node_id: self.check_activate(state)}
}
@abstractmethod
async def execute(self, state: WorkflowState) -> Any:
"""执行节点业务逻辑(非流式)
@@ -99,13 +151,13 @@ class BaseNode(ABC):
Examples:
>>> # LLM 节点
>>> return "这是 AI 的回复"
>>> "这是 AI 的回复"
>>> # Transform 节点
>>> return {"processed_data": [...]}
>>> {"processed_data": [...]}
>>> # Start/End 节点
>>> return {"message": "开始", "conversation_id": "xxx"}
>>> {"message": "开始", "conversation_id": "xxx"}
"""
pass
@@ -126,14 +178,14 @@ class BaseNode(ABC):
业务数据chunk或完成标记
Examples:
>>> # 流式 LLM 节点
>>> full_response = ""
>>> async for chunk in llm.astream(prompt):
... full_response += chunk
... yield chunk # yield 文本片段
>>>
>>> # 最后 yield 完成标记
>>> yield {"__final__": True, "result": AIMessage(content=full_response)}
# 流式 LLM 节点
full_response = ""
async for chunk in llm.astream(prompt):
full_response += chunk
yield chunk # yield 文本片段
# 最后 yield 完成标记
yield {"__final__": True, "result": AIMessage(content=full_response)}
"""
result = await self.execute(state)
# 默认实现:直接 yield 完成标记
@@ -146,7 +198,7 @@ class BaseNode(ABC):
是否支持流式输出
"""
# 检查子类是否重写了 execute_stream 方法
return self.execute_stream.__func__ != BaseNode.execute_stream.__func__
return self.__class__.execute_stream is not BaseNode.execute_stream
def get_timeout(self) -> int:
"""获取超时时间(秒)
@@ -172,6 +224,9 @@ class BaseNode(ABC):
Returns:
标准化的状态更新字典
"""
if not self.check_activate(state):
return self.trans_activate(state)
import time
start_time = time.time()
@@ -204,12 +259,11 @@ class BaseNode(ABC):
return {
**wrapped_output,
"messages": state["messages"],
"variables": state["variables"],
"runtime_vars": {
self.node_id: runtime_var
},
"looping": state["looping"]
}
} | self.trans_activate(state)
except TimeoutError:
elapsed_time = time.time() - start_time
@@ -220,7 +274,7 @@ class BaseNode(ABC):
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
return self._wrap_error(str(e), elapsed_time, state)
async def run_stream(self, state: WorkflowState):
async def run_stream(self, state: WorkflowState) -> AsyncGenerator[dict[str, Any], Any]:
"""Execute node with error handling and output wrapping (streaming)
This method is called by the Executor and is responsible for:
@@ -241,6 +295,11 @@ class BaseNode(ABC):
Yields:
State updates with streaming buffer and final result
"""
if not self.check_activate(state):
yield self.trans_activate(state)
logger.info(f"跳过节点{self.node_id}")
return
import time
start_time = time.time()
@@ -358,7 +417,6 @@ class BaseNode(ABC):
state_update = {
**final_output,
"messages": state["messages"],
"variables": state["variables"],
"runtime_vars": {
self.node_id: runtime_var
},
@@ -377,7 +435,7 @@ class BaseNode(ABC):
# Finally yield state update
# LangGraph will merge this into state
yield state_update
yield state_update | self.trans_activate(state)
except TimeoutError:
elapsed_time = time.time() - start_time
@@ -427,12 +485,13 @@ class BaseNode(ABC):
"token_usage": token_usage,
"error": None
}
return {
"node_outputs": {
self.node_id: node_output
}
final_output = {
"node_outputs": {self.node_id: node_output},
}
if self.variable_updater:
final_output = final_output | {"variables": state["variables"]}
return final_output
def _wrap_error(
self,

View File

@@ -26,6 +26,9 @@ class NodeType(StrEnum):
MEMORY_WRITE = "memory-write"
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
class ComparisonOperator(StrEnum):
EMPTY = "empty"
NOT_EMPTY = "not_empty"

View File

@@ -11,6 +11,7 @@ class EmotionTagsRequest(BaseModel):
start_date: Optional[str] = Field(None, description="开始日期ISO格式2024-01-01")
end_date: Optional[str] = Field(None, description="结束日期ISO格式2024-12-31")
limit: int = Field(10, ge=1, le=100, description="返回数量限制")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")
class EmotionWordcloudRequest(BaseModel):
@@ -18,20 +19,24 @@ class EmotionWordcloudRequest(BaseModel):
group_id: str = Field(..., description="组ID")
emotion_type: Optional[str] = Field(None, description="情绪类型过滤joy/sadness/anger/fear/surprise/neutral")
limit: int = Field(50, ge=1, le=200, description="返回词语数量")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")
class EmotionHealthRequest(BaseModel):
"""获取情绪健康指数请求"""
group_id: str = Field(..., description="组ID")
time_range: str = Field("30d", description="时间范围7d/30d/90d")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")
class EmotionSuggestionsRequest(BaseModel):
"""获取个性化情绪建议请求"""
group_id: str = Field(..., description="组ID")
config_id: Optional[int] = Field(None, description="配置ID用于指定LLM模型")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")
class EmotionGenerateSuggestionsRequest(BaseModel):
"""生成个性化情绪建议请求"""
end_user_id: str = Field(..., description="终端用户ID")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")

View File

@@ -44,6 +44,7 @@ class EndUserProfileResponse(BaseModel):
updatetime_profile: Optional[datetime.datetime] = Field(description="核心档案信息最后更新时间", default=None)
class EndUserProfileUpdate(BaseModel):
"""终端用户基本信息更新请求模型"""
end_user_id: str = Field(description="终端用户ID")

View File

@@ -51,6 +51,7 @@ class EpisodicMemoryOverviewRequest(BaseModel):
"""情景记忆总览查询请求"""
end_user_id: str = Field(..., description="终端用户ID")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")
time_range: str = Field(
default="all",
description="时间范围筛选可选值all, today, this_week, this_month"
@@ -70,3 +71,4 @@ class EpisodicMemoryDetailsRequest(BaseModel):
end_user_id: str = Field(..., description="终端用户ID")
summary_id: str = Field(..., description="情景记忆摘要ID")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")

View File

@@ -1,15 +1,19 @@
"""
显性记忆的请求和响应模型
"""
from typing import Optional
from pydantic import BaseModel, Field
class ExplicitMemoryOverviewRequest(BaseModel):
"""显性记忆总览查询请求"""
end_user_id: str = Field(..., description="终端用户ID")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")
class ExplicitMemoryDetailsRequest(BaseModel):
"""显性记忆详情查询请求"""
end_user_id: str = Field(..., description="终端用户ID")
memory_id: str = Field(..., description="记忆ID情景记忆或语义记忆的ID")
language_type: Optional[str] = Field("zh", description="语言类型zh/en")

View File

@@ -1445,7 +1445,7 @@ class AppService:
target_workspace_ids: List[uuid.UUID],
user_id: uuid.UUID,
workspace_id: Optional[uuid.UUID] = None
) -> AppShare:
) -> list[AppShare]:
"""分享应用到其他工作空间
Args:

View File

@@ -26,6 +26,7 @@ from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_base_service import Translation_English
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
@@ -692,7 +693,9 @@ class MemoryAgentService:
async def get_hot_memory_tags_by_user(
self,
end_user_id: Optional[str] = None,
limit: int = 20
limit: int = 20,
model_id: Optional[str] = None,
language_type: Optional[str] = "zh"
) -> List[Dict[str, Any]]:
"""
获取指定用户的热门记忆标签
@@ -710,7 +713,13 @@ class MemoryAgentService:
try:
# by_user=False 表示按 group_id 查询在Neo4j中group_id就是用户维度
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
payload = [{"name": t, "frequency": f} for t, f in tags]
payload=[]
for tag, freq in tags:
if language_type!="zh":
tag=await Translation_English(model_id, tag)
payload.append({"name": tag, "frequency": freq})
else:
payload.append({"name": tag, "frequency": freq})
return payload
except Exception as e:
logger.error(f"热门记忆标签查询失败: {e}")

View File

@@ -3,17 +3,268 @@ Memory Base Service
提供记忆服务的基础功能和共享辅助方法。
"""
import asyncio
import re
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
from app.core.logging_config import get_logger
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
from app.services.memory_config_service import MemoryConfigService
from app.db import get_db_context
logger = get_logger(__name__)
class TranslationResponse(BaseModel):
"""翻译响应模型"""
data: str
class MemoryTransService:
"""记忆翻译服务,提供中英文翻译功能"""
def __init__(self, llm_client=None, model_id: Optional[str] = None):
"""
初始化翻译服务
Args:
llm_client: LLM客户端实例或模型ID字符串可选
model_id: 模型ID用于初始化LLM客户端可选
Note:
- 如果llm_client是字符串会被当作model_id使用
- 如果同时提供llm_client和model_id优先使用llm_client
"""
# 处理llm_client参数如果是字符串当作model_id
if isinstance(llm_client, str):
self.model_id = llm_client
self.llm_client = None
else:
self.llm_client = llm_client
self.model_id = model_id
self._initialized = False
def _ensure_llm_client(self):
"""确保LLM客户端已初始化"""
if self._initialized:
return
if self.llm_client is None:
if self.model_id:
with get_db_context() as db:
config_service = MemoryConfigService(db)
model_config = config_service.get_model_config(self.model_id)
extra_params = {
"temperature": 0.2,
"max_tokens": 400,
"top_p": 0.8,
"stream": False,
}
self.llm_client = OpenAIClient(
RedBearModelConfig(
model_name=model_config.get("model_name"),
provider=model_config.get("provider"),
api_key=model_config.get("api_key"),
base_url=model_config.get("base_url"),
timeout=model_config.get("timeout", 30),
max_retries=model_config.get("max_retries", 3),
extra_params=extra_params
),
type_=model_config.get("type")
)
else:
raise ValueError("必须提供 llm_client 或 model_id 之一")
self._initialized = True
async def translate_to_english(self, text: str) -> str:
"""
将中文翻译为英文
Args:
text: 要翻译的中文文本
Returns:
翻译后的英文文本
"""
self._ensure_llm_client()
translation_messages = [
{
"role": "user",
"content": f"{text}\n\n中文翻译为英文,输出格式为{{\"data\":\"翻译后的内容\"}}"
}
]
try:
response = await self.llm_client.response_structured(
messages=translation_messages,
response_model=TranslationResponse
)
return response.data
except Exception as e:
logger.error(f"翻译失败: {str(e)}")
return text # 翻译失败时返回原文
async def is_english(self, text: str) -> bool:
"""
检查文本是否为英文
Args:
text: 要检查的文本(必须是字符串)
Returns:
True 如果文本主要是英文False 否则
Note:
- 只接受字符串类型
- 检查是否主要由英文字母和常见标点组成
- 允许数字、空格和常见标点符号
"""
if not isinstance(text, str):
raise TypeError(f"is_english 只接受字符串类型,收到: {type(text).__name__}")
if not text.strip():
return True # 空字符串视为英文
# 更宽松的英文检查:允许字母、数字、空格和常见标点
# 如果文本中英文字符占比超过 80%,认为是英文
english_chars = sum(1 for c in text if c.isascii() and (c.isalnum() or c.isspace() or c in '.,!?;:\'"()-'))
total_chars = len(text)
if total_chars == 0:
return True
return (english_chars / total_chars) >= 0.8
async def Translate(self, text: str, target_language: str = "en") -> str:
"""
通用翻译方法(保持向后兼容)
Args:
text: 要翻译的文本
target_language: 目标语言,"en"表示英文,"zh"表示中文
Returns:
翻译后的文本
"""
if target_language == "en":
return await self.translate_to_english(text)
else:
logger.warning(f"不支持的目标语言: {target_language},返回原文")
return text
# 测试翻译服务
async def Translation_English(modid, text, fields=None):
"""
将数据翻译为英文(支持字段级翻译)
Args:
modid: 模型ID
text: 要翻译的数据(可以是字符串、字典或列表)
fields: 需要翻译的字段列表(可选)
如果为None默认翻译: ['content', 'summary', 'statement', 'description',
'name', 'aliases', 'caption', 'emotion_keywords']
Returns:
翻译后的数据,保持原有结构
Note:
- 对于字符串:直接翻译
- 对于列表:递归处理每个元素,保持列表长度和索引不变
- 对于字典只翻译指定字段fields参数
- 对于其他类型:原样返回
"""
trans_service = MemoryTransService(modid)
# 处理字符串类型
if isinstance(text, str):
# 空字符串直接返回
if not text.strip():
return text
try:
is_eng = await trans_service.is_english(text)
if not is_eng:
english_result = await trans_service.Translate(text)
return english_result
return text
except Exception as e:
logger.warning(f"翻译字符串失败: {e}")
return text
# 处理列表类型
elif isinstance(text, list):
english_result = []
for item in text:
# 递归处理列表中的每个元素
if isinstance(item, str):
# 字符串元素:检查是否需要翻译
if not item.strip():
english_result.append(item)
continue
try:
is_eng = await trans_service.is_english(item)
if not is_eng:
translated = await trans_service.Translate(item)
english_result.append(translated)
else:
# 保留英文项,不改变列表长度
english_result.append(item)
except Exception as e:
logger.warning(f"翻译列表项失败: {e}")
english_result.append(item)
elif isinstance(item, dict):
# 字典元素:递归调用自己处理字典
translated_dict = await Translation_English(modid, item, fields)
english_result.append(translated_dict)
elif isinstance(item, list):
# 嵌套列表:递归处理
translated_list = await Translation_English(modid, item, fields)
english_result.append(translated_list)
else:
# 其他类型(数字、布尔值等):原样保留
english_result.append(item)
return english_result
# 处理字典类型
elif isinstance(text, dict):
# 确定要翻译的字段
if fields is None:
# 默认翻译字段
fields = [
'content', 'summary', 'statement', 'description',
'name', 'aliases', 'caption', 'emotion_keywords',
'text', 'title', 'label', 'type' # 添加常用字段
]
# 创建副本,避免修改原始数据
result = text.copy()
for field in fields:
if field in result and result[field] is not None:
# 递归翻译字段值(可能是字符串、列表或嵌套字典)
try:
result[field] = await Translation_English(modid, result[field], fields)
except Exception as e:
logger.warning(f"翻译字段 {field} 失败: {e}")
# 翻译失败时保留原值
continue
return result
# 其他类型数字、布尔值、None等原样返回
else:
return text
class MemoryBaseService:
"""记忆服务基类,提供共享的辅助方法"""
@@ -294,4 +545,4 @@ class MemoryBaseService:
except Exception as e:
logger.error(f"获取遗忘记忆数量时出错: {str(e)}", exc_info=True)
return 0
return 0

View File

@@ -16,6 +16,7 @@ import json
from datetime import datetime
from app.schemas.memory_episodic_schema import EmotionType
from app.services.memory_base_service import Translation_English
logger = logging.getLogger(__name__)
@@ -24,7 +25,7 @@ class MemoryEntityService:
self.id = id
self.table = table
self.connector = Neo4jConnector()
async def get_timeline_memories_server(self):
async def get_timeline_memories_server(self,model_id, language_type):
"""
获取时间线记忆数据
@@ -48,10 +49,10 @@ class MemoryEntityService:
logger.info(f"获取时间线记忆数据 - ID: {self.id}, Table: {self.table}")
# 根据表类型选择查询
if self.table == 'Statement':
if self.table == 'Statement':
# Statement只需要输入ID使用简化查询
results = await self.connector.execute_query(Memory_Timeline_Statement, id=self.id)
elif self.table == 'ExtractedEntity':
elif self.table == 'ExtractedEntity':
# ExtractedEntity类型查询
results = await self.connector.execute_query(Memory_Timeline_ExtractedEntity, id=self.id)
else:
@@ -62,7 +63,7 @@ class MemoryEntityService:
logger.info(f"时间线查询结果类型: {type(results)}, 长度: {len(results) if isinstance(results, list) else 'N/A'}")
# 处理查询结果
timeline_data = self._process_timeline_results(results)
timeline_data =await self._process_timeline_results(results, model_id, language_type)
logger.info(f"成功获取时间线记忆数据: 总计 {len(timeline_data.get('timelines_memory', []))}")
@@ -71,12 +72,14 @@ class MemoryEntityService:
except Exception as e:
logger.error(f"获取时间线记忆数据失败: {str(e)}", exc_info=True)
return str(e)
def _process_timeline_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
async def _process_timeline_results(self, results: List[Dict[str, Any]], model_id: str, language_type: str) -> Dict[str, Any]:
"""
处理时间线查询结果
Args:
results: Neo4j查询结果
model_id: 模型ID用于翻译
language_type: 语言类型 ('zh' 或其他)
Returns:
处理后的时间线数据字典
@@ -104,19 +107,19 @@ class MemoryEntityService:
# 处理MemorySummary
summary = data.get('MemorySummary')
if summary is not None:
processed_summary = self._process_field_value(summary, "MemorySummary")
processed_summary = await self._process_field_value(summary, "MemorySummary")
memory_summary_list.extend(processed_summary)
# 处理Statement
statement = data.get('statement')
if statement is not None:
processed_statement = self._process_field_value(statement, "Statement")
processed_statement = await self._process_field_value(statement, "Statement")
statement_list.extend(processed_statement)
# 处理ExtractedEntity
extracted_entity = data.get('ExtractedEntity')
if extracted_entity is not None:
processed_entity = self._process_field_value(extracted_entity, "ExtractedEntity")
processed_entity = await self._process_field_value(extracted_entity, "ExtractedEntity")
extracted_entity_list.extend(processed_entity)
# 去重 - 现在处理的是字典列表,需要更智能的去重
@@ -128,6 +131,8 @@ class MemoryEntityService:
all_timeline_data = memory_summary_list + statement_list
all_timeline_data = self._merge_same_text_items(all_timeline_data)
# 如果需要翻译(非中文),对整个结果进行翻译
result = {
"MemorySummary": memory_summary_list,
"Statement": statement_list,
@@ -233,7 +238,7 @@ class MemoryEntityService:
except Exception:
return False
def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]:
async def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]:
"""
处理字段值,支持字符串、列表等类型
@@ -251,13 +256,13 @@ class MemoryEntityService:
# 如果是列表,处理每个元素
for item in value:
if self._is_valid_item(item):
processed_item = self._process_single_item(item)
processed_item = await self._process_single_item(item)
if processed_item:
processed_values.append(processed_item)
elif isinstance(value, dict):
# 如果是字典,直接处理
if self._is_valid_item(value):
processed_item = self._process_single_item(value)
processed_item = await self._process_single_item(value)
if processed_item:
processed_values.append(processed_item)
elif isinstance(value, str):
@@ -304,7 +309,7 @@ class MemoryEntityService:
return (str(item).strip() != '' and
"MemorySummaryChunk" not in str(item))
def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
async def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
处理单个项目
@@ -369,6 +374,117 @@ class MemoryEntityService:
logger.warning(f"转换时间格式失败: {e}, 原始值: {dt}")
return str(dt) if dt is not None else None
async def _translate_list(
self,
data_list: List[Dict[str, Any]],
model_id: str,
fields: List[str]
) -> List[Dict[str, Any]]:
"""
翻译列表中每个字典的指定字段(并发有限度以降低整体延迟)
Args:
data_list: 要翻译的字典列表
model_id: 模型ID
fields: 需要翻译的字段列表
Returns:
翻译后的字典列表
"""
# 空列表或无字段时直接返回
if not data_list or not fields:
return data_list
import asyncio
# 并发限制,避免一次性发起过多请求
# 可根据实际情况调整(建议 5-10
concurrency_limit = 5
semaphore = asyncio.Semaphore(concurrency_limit)
async def translate_single_field(
index: int,
field: str,
value: Any,
) -> Optional[tuple]:
"""
翻译单个字段并返回 (索引, 字段名, 翻译结果)
Returns:
(index, field, translated_value) 或 None如果跳过
"""
# 跳过空值
if value is None or value == "":
return None
# 统一转成字符串再翻译,防止非字符串类型导致错误
text = str(value)
try:
async with semaphore:
# 调用 Translation_English 进行翻译
# 注意Translation_English 的参数顺序是 (model_id, text)
translated = await Translation_English(model_id, text)
# 如果翻译结果为空,保留原值
if translated is None or translated == "":
return None
return index, field, translated
except Exception as e:
logger.warning(f"翻译字段 {field} (索引 {index}) 失败: {e}")
return None
# 构造所有需要翻译的任务
tasks = []
for idx, item in enumerate(data_list):
# 防御性检查:确保 item 是字典
if not isinstance(item, dict):
continue
for field in fields:
if field not in item:
continue
value = item.get(field)
# 对于 None 或空字符串的值,直接跳过,不创建任务
if value is None or value == "":
continue
tasks.append(
asyncio.create_task(
translate_single_field(idx, field, value)
)
)
# 如果没有需要翻译的任务,直接返回原列表
if not tasks:
return data_list
# 使用 gather 并发执行翻译任务(受 semaphore 限制)
# return_exceptions=True 可以防止单个任务失败导致整体失败
results = await asyncio.gather(*tasks, return_exceptions=True)
# 创建深拷贝以避免修改原始数据
translated_list = [item.copy() if isinstance(item, dict) else item for item in data_list]
# 将翻译结果回填到列表
for result in results:
# 跳过 None 结果和异常
if result is None or isinstance(result, Exception):
if isinstance(result, Exception):
logger.warning(f"翻译任务异常: {result}")
continue
idx, field, translated = result
# 防御性检查索引范围
if 0 <= idx < len(translated_list) and isinstance(translated_list[idx], dict):
translated_list[idx][field] = translated
return translated_list
@@ -426,15 +542,19 @@ class MemoryEmotion:
# 如果解析失败,返回原始字符串
return iso_string
async def get_emotion(self) -> Dict[str, Any]:
async def get_emotion(self, model_id: str = None, language_type: str = 'zh') -> Dict[str, Any]:
"""
获取情绪随时间变化数据
Args:
model_id: 模型ID用于翻译
language_type: 语言类型 ('zh' 或其他)
Returns:
包含情绪数据的字典
"""
try:
logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}")
logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}, language_type={language_type}")
if self.table == 'Statement':
results = await self.connector.execute_query(Memory_Space_Emotion_Statement, id=self.id)
@@ -450,6 +570,10 @@ class MemoryEmotion:
# 转换Neo4j类型
final_data = self._convert_neo4j_types(emotion_data)
# 如果需要翻译(非中文)
if language_type != 'zh' and model_id and final_data:
final_data = await self._translate_emotion_data(final_data, model_id)
logger.info(f"成功获取 {len(final_data)} 条情绪数据")
return final_data
@@ -590,16 +714,14 @@ class MemoryInteraction:
"""
try:
logger.info(f"获取交互数据 - ID: {self.id}, Table: {self.table}")
ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id)
if ori_data!=[]:
# name = ori_data[0]['name']
group_id = ori_data[0]['group_id']
group_id = [i['group_id'] for i in ori_data][0]
Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id)
if not Space_User:
return []
user_id=Space_User[0]['id']
results = await self.connector.execute_query(Memory_Space_Associative, id=self.id,user_id=user_id)

View File

@@ -18,7 +18,7 @@ from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService
from app.services.memory_base_service import MemoryBaseService
from app.services.memory_base_service import MemoryBaseService, MemoryTransService, Translation_English
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService
@@ -360,7 +360,9 @@ class UserMemoryService:
async def get_cached_memory_insight(
self,
db: Session,
end_user_id: str
end_user_id: str,
model_id: str,
language_type: str
) -> Dict[str, Any]:
"""
从数据库获取缓存的记忆洞察(四个维度)
@@ -419,11 +421,18 @@ class UserMemoryService:
key_findings_array = []
logger.info(f"成功获取 end_user_id {end_user_id} 的缓存记忆洞察(四维度)")
memory_insight=end_user.memory_insight
behavior_pattern=end_user.behavior_pattern
growth_trajectory=end_user.growth_trajectory
if language_type!='zh':
memory_insight=await Translation_English(model_id,memory_insight)
behavior_pattern=await Translation_English(model_id,behavior_pattern)
growth_trajectory=await Translation_English(model_id,growth_trajectory)
return {
"memory_insight": end_user.memory_insight, # 总体概述存储在 memory_insight
"behavior_pattern": end_user.behavior_pattern,
"memory_insight":memory_insight, # 总体概述存储在 memory_insight
"behavior_pattern":behavior_pattern,
"key_findings": key_findings_array, # 返回数组
"growth_trajectory": end_user.growth_trajectory,
"growth_trajectory": growth_trajectory,
"updated_at": self._datetime_to_timestamp(end_user.memory_insight_updated_at),
"is_cached": True
}
@@ -457,7 +466,9 @@ class UserMemoryService:
async def get_cached_user_summary(
self,
db: Session,
end_user_id: str
end_user_id: str,
model_id:str,
language_type:str="zh"
) -> Dict[str, Any]:
"""
从数据库获取缓存的用户摘要(四个部分)
@@ -481,7 +492,6 @@ class UserMemoryService:
user_uuid = uuid.UUID(end_user_id)
repo = EndUserRepository(db)
end_user = repo.get_by_id(user_uuid)
if not end_user:
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
return {
@@ -495,20 +505,29 @@ class UserMemoryService:
}
# 检查是否有缓存数据(至少有一个字段不为空)
user_summary=end_user.user_summary
personality_traits=end_user.personality_traits
core_values=end_user.core_values
one_sentence_summary=end_user.one_sentence_summary
if language_type!='zh':
user_summary=await Translation_English(model_id, user_summary)
personality_traits = await Translation_English(model_id, personality_traits)
core_values = await Translation_English(model_id, core_values)
one_sentence_summary = await Translation_English(model_id, one_sentence_summary)
has_cache = any([
end_user.user_summary,
end_user.personality_traits,
end_user.core_values,
end_user.one_sentence_summary
user_summary,
personality_traits,
core_values,
one_sentence_summary
])
if has_cache:
logger.info(f"成功获取 end_user_id {end_user_id} 的缓存用户摘要")
return {
"user_summary": end_user.user_summary,
"personality": end_user.personality_traits,
"core_values": end_user.core_values,
"one_sentence": end_user.one_sentence_summary,
"user_summary": user_summary,
"personality": personality_traits,
"core_values":core_values,
"one_sentence": one_sentence_summary,
"updated_at": self._datetime_to_timestamp(end_user.user_summary_updated_at),
"is_cached": True
}
@@ -1367,7 +1386,6 @@ async def analytics_memory_types(
return memory_types
async def analytics_graph_data(
db: Session,
end_user_id: str,
@@ -1557,7 +1575,7 @@ async def analytics_graph_data(
f"成功获取图数据: end_user_id={end_user_id}, "
f"nodes={len(nodes)}, edges={len(edges)}"
)
return {
"nodes": nodes,
"edges": edges,
@@ -1606,11 +1624,7 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
# 获取该节点类型的白名单字段
allowed_fields = field_whitelist.get(label, [])
# 如果没有定义白名单,返回空字典(或者可以返回所有字段)
# if not allowed_fields:
# # 对于未定义的节点类型,只返回基本字段
# allowed_fields = ["name", "created_at", "caption"]
count_neo4j=f"""MATCH (n)-[r]-(m) WHERE elementId(n) ="{node_id}" RETURN count(r) AS rel_count;"""
node_results = await (_neo4j_connector.execute_query(count_neo4j))
# 提取白名单中的字段
@@ -1618,13 +1632,12 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
for field in allowed_fields:
if field in properties:
value = properties[field]
if str(field) == 'entity_type':
if str(field) == 'entity_type':
value=type_mapping.get(value,'')
if str(field)=="emotion_type":
value=EmotionType.EMOTION_MAPPING.get(value)
if str(field)=="emotion_subject":
if str(field)=="emotion_subject":
value=EmotionSubject.SUBJECT_MAPPING.get(value)
# 清理 Neo4j 特殊类型
filtered_props[field] = _clean_neo4j_value(value)
filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0]
return filtered_props