Merge remote-tracking branch 'origin/develop' into develop
This commit is contained in:
@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models import User
|
||||
@@ -661,6 +661,11 @@ async def draft_run(
|
||||
data=result,
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
else:
|
||||
return fail(
|
||||
msg="未知应用类型",
|
||||
code=422
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
|
||||
|
||||
@@ -9,7 +9,7 @@ from app.db import get_db
|
||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||
from app.models import ModelApiKey
|
||||
from app.models.user_model import User
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories import knowledge_repository, WorkspaceRepository
|
||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import task_service, workspace_service
|
||||
@@ -616,8 +616,10 @@ async def get_knowledge_type_stats_api(
|
||||
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_by_user_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
language_type: Optional[str] ="zh",
|
||||
limit: int = Query(20, description="返回标签数量限制"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session=Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
@@ -628,10 +630,22 @@ async def get_hot_memory_tags_by_user_api(
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
|
||||
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
||||
try:
|
||||
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
||||
end_user_id=end_user_id,
|
||||
language_type=language_type,
|
||||
model_id=model_id,
|
||||
limit=limit
|
||||
)
|
||||
return success(data=result, msg="获取热门记忆标签成功")
|
||||
|
||||
@@ -20,6 +20,7 @@ router = APIRouter(
|
||||
@router.get("/short_term")
|
||||
async def short_term_configs(
|
||||
end_user_id: str,
|
||||
language_type:Optional[str] = "zh",
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
|
||||
@@ -12,6 +12,7 @@ from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.api_key_utils import timestamp_to_datetime
|
||||
from app.services.memory_base_service import Translation_English
|
||||
from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_memory_types,
|
||||
@@ -20,7 +21,7 @@ from app.services.user_memory_service import (
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
@@ -44,6 +45,7 @@ router = APIRouter(
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str,
|
||||
language_type: str = "zh",
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -53,10 +55,18 @@ async def get_memory_insight_report_api(
|
||||
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
|
||||
如需生成新的洞察报告,请使用专门的生成接口。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id,model_id,language_type)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
||||
@@ -72,6 +82,7 @@ async def get_memory_insight_report_api(
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str,
|
||||
language_type: str="zh",
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -81,10 +92,18 @@ async def get_user_summary_api(
|
||||
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
|
||||
如需生成新的用户摘要,请使用专门的生成接口。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language_type)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
@@ -253,7 +272,6 @@ async def get_graph_data_api(
|
||||
depth=depth,
|
||||
center_node_id=center_node_id
|
||||
)
|
||||
|
||||
# 检查是否有错误消息
|
||||
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
||||
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}")
|
||||
@@ -278,7 +296,13 @@ async def get_end_user_profile(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
|
||||
@@ -296,7 +320,6 @@ async def get_end_user_profile(
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
@@ -396,12 +419,21 @@ async def update_end_user_profile(
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str="zh",
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
workspace_id=current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
MemoryEntity = MemoryEntityService(id, label)
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server()
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language_type)
|
||||
|
||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||
async def memory_space_relationship_evolution(id: str, label: str,
|
||||
|
||||
@@ -54,6 +54,8 @@ class WorkflowExecutor:
|
||||
self.edges = workflow_config.get("edges", [])
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
self.start_node_id = None
|
||||
|
||||
self.checkpoint_config = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4(),
|
||||
@@ -131,77 +133,12 @@ class WorkflowExecutor:
|
||||
for node in self.workflow_config.get("nodes")
|
||||
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||
], # loop, iteration node id
|
||||
"looping": False # loop runing flag, only use in loop node,not use in main loop
|
||||
"looping": False, # loop runing flag, only use in loop node,not use in main loop
|
||||
"activate": {
|
||||
self.start_node_id: True
|
||||
}
|
||||
}
|
||||
|
||||
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
|
||||
"""分析 End 节点的前缀配置
|
||||
|
||||
检查每个 End 节点的模板,找到直接上游节点的引用,
|
||||
提取该引用之前的前缀部分。
|
||||
|
||||
Returns:
|
||||
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
|
||||
"""
|
||||
import re
|
||||
|
||||
prefixes = {}
|
||||
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
|
||||
|
||||
# 找到所有 End 节点
|
||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
|
||||
|
||||
for end_node in end_nodes:
|
||||
end_node_id = end_node.get("id")
|
||||
output_template = end_node.get("config", {}).get("output")
|
||||
|
||||
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
|
||||
|
||||
if not output_template:
|
||||
continue
|
||||
|
||||
# 找到所有直接连接到 End 节点的上游节点
|
||||
direct_upstream_nodes = []
|
||||
for edge in self.edges:
|
||||
if edge.get("target") == end_node_id:
|
||||
source_node_id = edge.get("source")
|
||||
direct_upstream_nodes.append(source_node_id)
|
||||
|
||||
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
|
||||
|
||||
# 查找模板中引用了哪些节点
|
||||
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
|
||||
pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
matches = list(re.finditer(pattern, output_template))
|
||||
|
||||
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
|
||||
|
||||
# 找到第一个直接上游节点的引用
|
||||
for match in matches:
|
||||
referenced_node_id = match.group(1)
|
||||
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
|
||||
|
||||
if referenced_node_id in direct_upstream_nodes:
|
||||
# 这是直接上游节点的引用,提取前缀
|
||||
prefix = output_template[:match.start()]
|
||||
|
||||
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
|
||||
|
||||
# 标记这个节点为"相邻且被引用"
|
||||
adjacent_and_referenced.add(referenced_node_id)
|
||||
|
||||
if prefix:
|
||||
prefixes[referenced_node_id] = prefix
|
||||
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
|
||||
|
||||
# 只处理第一个直接上游节点的引用
|
||||
break
|
||||
|
||||
logger.info(f"[前缀分析] 最终配置: {prefixes}")
|
||||
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
|
||||
return prefixes, adjacent_and_referenced
|
||||
|
||||
def _build_final_output(self, result, elapsed_time):
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
final_output = self._extract_final_output(node_outputs)
|
||||
@@ -231,10 +168,12 @@ class WorkflowExecutor:
|
||||
编译后的状态图
|
||||
"""
|
||||
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||
graph = GraphBuilder(
|
||||
builder = GraphBuilder(
|
||||
self.workflow_config,
|
||||
stream=stream,
|
||||
).build()
|
||||
)
|
||||
self.start_node_id = builder.start_node_id
|
||||
graph = builder.build()
|
||||
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
||||
|
||||
return graph
|
||||
@@ -375,13 +314,15 @@ class WorkflowExecutor:
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
if node_name and node_name.startswith("nop"):
|
||||
continue
|
||||
|
||||
if event_type == "task":
|
||||
# Node starts execution
|
||||
inputv = payload.get("input", {})
|
||||
variables = inputv.get("variables", {})
|
||||
variables_sys = variables.get("sys", {})
|
||||
if not inputv.get("activate", {}).get(node_name):
|
||||
continue
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
execution_id = variables_sys.get("execution_id")
|
||||
logger.info(f"[NODE-START] Node starts execution: {node_name} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
@@ -390,18 +331,17 @@ class WorkflowExecutor:
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": execution_id,
|
||||
"timestamp": data.get("timestamp")
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": data.get("timestamp"),
|
||||
}
|
||||
}
|
||||
elif event_type == "task_result":
|
||||
# Node execution completed
|
||||
result = payload.get("result", {})
|
||||
inputv = result.get("input", {})
|
||||
variables = inputv.get("variables", {})
|
||||
variables_sys = variables.get("sys", {})
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
continue
|
||||
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
execution_id = variables_sys.get("execution_id")
|
||||
logger.info(f"[NODE-END] Node execution completed: {node_name} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
@@ -410,7 +350,7 @@ class WorkflowExecutor:
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": execution_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": data.get("timestamp"),
|
||||
"state": result.get("node_outputs", {}).get(node_name),
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||
from langgraph.graph import START, END
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import START, END
|
||||
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||
from langgraph.types import Send
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,7 +30,10 @@ class GraphBuilder:
|
||||
self.start_node_id = None
|
||||
self.end_node_ids = []
|
||||
|
||||
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.add_edges()
|
||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||
|
||||
@property
|
||||
def nodes(self) -> list[dict[str, Any]]:
|
||||
@@ -39,74 +44,98 @@ class GraphBuilder:
|
||||
return self.workflow_config.get("edges", [])
|
||||
|
||||
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
|
||||
"""分析 End 节点的前缀配置
|
||||
"""
|
||||
Analyze the prefix configuration for End nodes.
|
||||
|
||||
检查每个 End 节点的模板,找到直接上游节点的引用,
|
||||
提取该引用之前的前缀部分。
|
||||
This function scans each End node's output template, identifies
|
||||
references to its direct upstream nodes, and extracts the prefix
|
||||
string appearing before the first reference.
|
||||
|
||||
Returns:
|
||||
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
|
||||
tuple:
|
||||
- dict[str, str]: Mapping from upstream node ID to its End node prefix
|
||||
- set[str]: Set of node IDs that are directly adjacent to End nodes and referenced
|
||||
"""
|
||||
import re
|
||||
|
||||
prefixes = {}
|
||||
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
|
||||
adjacent_and_referenced = set() # Record nodes directly adjacent to End and referenced
|
||||
|
||||
# 找到所有 End 节点
|
||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
|
||||
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
||||
|
||||
for end_node in end_nodes:
|
||||
end_node_id = end_node.get("id")
|
||||
output_template = end_node.get("config", {}).get("output")
|
||||
|
||||
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
|
||||
logger.info(f"[Prefix Analysis] End node {end_node_id} template: {output_template}")
|
||||
|
||||
if not output_template:
|
||||
continue
|
||||
|
||||
# 查找模板中引用了哪些节点
|
||||
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
|
||||
# Find all node references in the template
|
||||
# Matches {{node_id.xxx}} or {{ node_id.xxx }} format (allowing spaces)
|
||||
pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
matches = list(re.finditer(pattern, output_template))
|
||||
|
||||
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
|
||||
logger.info(f"[Prefix Analysis] 模板中找到 {len(matches)} 个节点引用")
|
||||
|
||||
# 找到所有直接连接到 End 节点的上游节点
|
||||
# Identify all direct upstream nodes connected to the End node
|
||||
direct_upstream_nodes = []
|
||||
for edge in self.edges:
|
||||
if edge.get("target") == end_node_id:
|
||||
source_node_id = edge.get("source")
|
||||
direct_upstream_nodes.append(source_node_id)
|
||||
|
||||
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
|
||||
logger.info(f"[Prefix Analysis] Direct upstream nodes of End node: {direct_upstream_nodes}")
|
||||
|
||||
# 找到第一个直接上游节点的引用
|
||||
for match in matches:
|
||||
referenced_node_id = match.group(1)
|
||||
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
|
||||
logger.info(f"[Prefix Analysis] Checking reference: {referenced_node_id}")
|
||||
|
||||
if referenced_node_id in direct_upstream_nodes:
|
||||
# 这是直接上游节点的引用,提取前缀
|
||||
prefix = output_template[:match.start()]
|
||||
|
||||
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
|
||||
logger.info(f"[Prefix Analysis] "
|
||||
f"✅ Found reference to direct upstream node {referenced_node_id}, prefix: '{prefix}'")
|
||||
|
||||
# 标记这个节点为"相邻且被引用"
|
||||
adjacent_and_referenced.add(referenced_node_id)
|
||||
|
||||
if prefix:
|
||||
prefixes[referenced_node_id] = prefix
|
||||
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
|
||||
logger.info(f"[Prefix Analysis] "
|
||||
f"✅ Assign prefix for node {referenced_node_id}: '{prefix[:50]}...'")
|
||||
|
||||
# 只处理第一个直接上游节点的引用
|
||||
break
|
||||
|
||||
logger.info(f"[前缀分析] 最终配置: {prefixes}")
|
||||
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
|
||||
logger.info(f"[Prefix Analysis] Final prefixes: {prefixes}")
|
||||
logger.info(f"[Prefix Analysis] Nodes adjacent to End and referenced: {adjacent_and_referenced}")
|
||||
return prefixes, adjacent_and_referenced
|
||||
|
||||
def add_nodes(self):
|
||||
"""Add all nodes from the workflow configuration to the state graph.
|
||||
|
||||
This method handles:
|
||||
- Creation of node instances using NodeFactory.
|
||||
- Special handling for start, end, and cycle nodes.
|
||||
- Injection of End node prefixes for streaming mode.
|
||||
- Marking nodes as adjacent to End nodes if referenced.
|
||||
- Wrapping node run methods as async functions or async generators
|
||||
depending on streaming mode.
|
||||
|
||||
Notes:
|
||||
Loop nodes (nodes with `cycle` property) are handled separately
|
||||
via CycleGraphNode when building subgraphs.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Analyze End node prefixes if in stream mode
|
||||
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())
|
||||
|
||||
for node in self.nodes:
|
||||
@@ -114,21 +143,21 @@ class GraphBuilder:
|
||||
node_id = node.get("id")
|
||||
cycle_node = node.get("cycle")
|
||||
if cycle_node:
|
||||
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
|
||||
# Nodes within a loop subgraph are constructed by CycleGraphNode
|
||||
if not self.subgraph:
|
||||
continue
|
||||
|
||||
# 记录 start 和 end 节点 ID
|
||||
# Record start and end node IDs
|
||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||
self.start_node_id = node_id
|
||||
elif node_type == NodeType.END:
|
||||
self.end_node_ids.append(node_id)
|
||||
|
||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||
# Create node instance (start and end nodes are also created)
|
||||
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||
|
||||
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]:
|
||||
if node_type in BRANCH_NODES:
|
||||
|
||||
# Find all edges whose source is the current node
|
||||
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
|
||||
@@ -142,26 +171,23 @@ class GraphBuilder:
|
||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||
|
||||
if node_instance:
|
||||
# 如果是流式模式,且节点有 End 前缀配置,注入配置
|
||||
# Inject End node prefix configuration if in stream mode
|
||||
if self.stream and node_id in end_prefixes:
|
||||
# 将 End 前缀配置注入到节点实例
|
||||
node_instance._end_node_prefix = end_prefixes[node_id]
|
||||
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
|
||||
logger.info(f"Injected End prefix for node {node_id}")
|
||||
|
||||
# 如果是流式模式,标记节点是否与 End 相邻且被引用
|
||||
# Mark nodes as adjacent and referenced to End node in stream mode
|
||||
if self.stream:
|
||||
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
||||
if node_id in adjacent_and_referenced:
|
||||
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
|
||||
logger.info(f"Node {node_id} marked as adjacent and referenced to End node")
|
||||
|
||||
# 包装节点的 run 方法
|
||||
# 使用函数工厂避免闭包问题
|
||||
# Wrap node's run method to avoid closure issues
|
||||
if self.stream:
|
||||
# 流式模式:创建 async generator 函数
|
||||
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
|
||||
# Stream mode: create an async generator function
|
||||
# LangGraph collects all yielded values; the last yielded dictionary is merged into the state
|
||||
def make_stream_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
|
||||
async for item in inst.run_stream(state):
|
||||
yield item
|
||||
|
||||
@@ -169,7 +195,7 @@ class GraphBuilder:
|
||||
|
||||
self.graph.add_node(node_id, make_stream_func(node_instance))
|
||||
else:
|
||||
# 非流式模式:创建 async function
|
||||
# Non-stream mode: create an async function
|
||||
def make_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
return await inst.run(state)
|
||||
@@ -178,45 +204,110 @@ class GraphBuilder:
|
||||
|
||||
self.graph.add_node(node_id, make_func(node_instance))
|
||||
|
||||
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={self.stream})")
|
||||
logger.debug(f"Added node: {node_id} (type={node_type}, stream={self.stream})")
|
||||
|
||||
def add_edges(self):
|
||||
"""Add all edges (normal, waiting, and conditional) to the state graph.
|
||||
|
||||
This method handles:
|
||||
- Connecting the START node to the workflow's start node.
|
||||
- Collecting waiting edges for nodes with multiple sources.
|
||||
- Collecting conditional edges for routing to NOP nodes.
|
||||
- Adding NOP nodes for conditional branches to allow later merging.
|
||||
- Wrapping routing logic in a router function that evaluates conditions.
|
||||
- Connecting End nodes to the global END node.
|
||||
|
||||
Notes:
|
||||
- NOP nodes are used to ensure that multiple branches can merge
|
||||
correctly without modifying the workflow state.
|
||||
- Waiting edges are automatically handled by LangGraph to schedule
|
||||
nodes only after all sources are activated.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Connect the START node to the workflow's start node
|
||||
if self.start_node_id:
|
||||
self.graph.add_edge(START, self.start_node_id)
|
||||
logger.debug(f"添加边: START -> {self.start_node_id}")
|
||||
logger.debug(f"Added edge: START -> {self.start_node_id}")
|
||||
|
||||
# Collect all sources for each target node for normal/waiting edges
|
||||
waiting_edges = defaultdict(list)
|
||||
# Collect all conditional edges for each source node to construct routing
|
||||
conditional_edges = defaultdict(list)
|
||||
|
||||
for edge in self.edges:
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
edge_type = edge.get("type")
|
||||
condition = edge.get("condition")
|
||||
edge_type = edge.get("type")
|
||||
|
||||
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
||||
if source == self.start_node_id:
|
||||
# 但要连接 start 到下一个节点
|
||||
self.graph.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
|
||||
# # 处理到 end 节点的边
|
||||
# if target in end_node_ids:
|
||||
# # 连接到 end 节点
|
||||
# workflow.add_edge(source, target)
|
||||
# logger.debug(f"添加边: {source} -> {target}")
|
||||
# continue
|
||||
|
||||
# 跳过错误边(在节点内部处理)
|
||||
# Skip error edges (handled within nodes)
|
||||
if edge_type == "error":
|
||||
continue
|
||||
|
||||
if condition:
|
||||
# 条件边
|
||||
def make_router(cond, tgt):
|
||||
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
|
||||
# Conditional edges: group by source node
|
||||
conditional_edges[source].append({
|
||||
"target": target,
|
||||
"condition": condition,
|
||||
"label": edge.get("label")
|
||||
})
|
||||
else:
|
||||
# Normal edges: group by target node (used for waiting edges)
|
||||
waiting_edges[target].append(source)
|
||||
|
||||
def router_fn(state: WorkflowState):
|
||||
# Add conditional edges
|
||||
for source_node, branches in conditional_edges.items():
|
||||
def make_router(src, branch_list):
|
||||
"""reate a router function for each source node that routes to a NOP node for later merging."""
|
||||
def make_branch_node(node_name, targets):
|
||||
def node(s):
|
||||
# NOTE: NOP NODE MUST NOT MODIFY STATE
|
||||
return {
|
||||
"activate": {
|
||||
node_id: s["activate"][node_name]
|
||||
for node_id in targets
|
||||
}
|
||||
}
|
||||
|
||||
return node
|
||||
|
||||
unique_branch = {}
|
||||
for branch in branch_list:
|
||||
if branch.get("label") not in unique_branch.keys():
|
||||
nop_node_name = f"nop_{uuid.uuid4().hex[:8]}"
|
||||
logger.info(f"Binding NOP: {source_node} {branch.get('label')} -> {nop_node_name}")
|
||||
unique_branch[branch["label"]] = {
|
||||
"condition": branch["condition"],
|
||||
"node": {
|
||||
"name": nop_node_name,
|
||||
},
|
||||
"target": [branch["target"]]
|
||||
}
|
||||
else:
|
||||
unique_branch[branch["label"]]["target"].append(branch["target"])
|
||||
|
||||
# Add NOP nodes and connect them to downstream nodes
|
||||
for label, branch_info in unique_branch.items():
|
||||
self.graph.add_node(
|
||||
branch_info["node"]["name"],
|
||||
make_branch_node(
|
||||
branch_info["node"]["name"],
|
||||
branch_info["target"]
|
||||
)
|
||||
)
|
||||
for target in branch_info["target"]:
|
||||
waiting_edges[target].append(branch_info["node"]["name"])
|
||||
|
||||
def router_fn(state: WorkflowState) -> list[Send]:
|
||||
branch_activate = []
|
||||
new_state = state.copy()
|
||||
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
|
||||
|
||||
for label, branch in unique_branch.items():
|
||||
if evaluate_condition(
|
||||
cond,
|
||||
branch["condition"],
|
||||
state.get("variables", {}),
|
||||
state.get("runtime_vars", {}),
|
||||
{
|
||||
@@ -225,30 +316,45 @@ class GraphBuilder:
|
||||
"user_id": state.get("user_id")
|
||||
}
|
||||
):
|
||||
return tgt
|
||||
return END
|
||||
logger.debug(f"Conditional routing {src}: selected branch {label}")
|
||||
new_state["activate"][branch["node"]["name"]] = True
|
||||
continue
|
||||
new_state["activate"][branch["node"]["name"]] = False
|
||||
for label, branch in unique_branch.items():
|
||||
branch_activate.append(
|
||||
Send(
|
||||
branch['node']['name'],
|
||||
new_state
|
||||
)
|
||||
)
|
||||
return branch_activate
|
||||
|
||||
# 动态修改函数名,避免重复
|
||||
router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{tgt}"
|
||||
return router_fn
|
||||
# Dynamically set function name
|
||||
router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{src}"
|
||||
return router_fn
|
||||
|
||||
router_fn = make_router(condition, target)
|
||||
self.graph.add_conditional_edges(source, router_fn)
|
||||
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
||||
router_fn = make_router(source_node, branches)
|
||||
self.graph.add_conditional_edges(source_node, router_fn)
|
||||
logger.debug(f"Added conditional edges: {source_node} -> {[b['target'] for b in branches]}")
|
||||
|
||||
# Add normal/waiting edges
|
||||
for target, sources in waiting_edges.items():
|
||||
if len(sources) == 1:
|
||||
# Single source: normal edge
|
||||
self.graph.add_edge(sources[0], target)
|
||||
logger.debug(f"Added edge: {sources[0]} -> {target}")
|
||||
else:
|
||||
# 普通边
|
||||
self.graph.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
# Multiple sources: waiting edge
|
||||
self.graph.add_edge(sources, target)
|
||||
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
||||
|
||||
# 从 end 节点连接到 END
|
||||
# Connect End nodes to the global END node
|
||||
for end_node_id in self.end_node_ids:
|
||||
self.graph.add_edge(end_node_id, END)
|
||||
logger.debug(f"添加边: {end_node_id} -> END")
|
||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
||||
return
|
||||
|
||||
def build(self) -> CompiledStateGraph:
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.add_edges() # 添加边必须在添加节点之后
|
||||
checkpointer = InMemorySaver()
|
||||
return self.graph.compile(checkpointer=checkpointer)
|
||||
self.graph = self.graph.compile(checkpointer=checkpointer)
|
||||
return self.graph
|
||||
|
||||
@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
|
||||
class AssignerNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.variable_updater = True
|
||||
self.typed_config: AssignerNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
|
||||
@@ -7,18 +7,26 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.config import get_stream_writer
|
||||
from typing_extensions import TypedDict, Annotated
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def merget_activate_state(x, y):
|
||||
return {
|
||||
k: x.get(k, False) or y.get(k, False)
|
||||
for k in set(x) | set(y)
|
||||
}
|
||||
|
||||
|
||||
class WorkflowState(TypedDict):
|
||||
"""Workflow state
|
||||
|
||||
@@ -60,6 +68,9 @@ class WorkflowState(TypedDict):
|
||||
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
|
||||
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merget_activate_state]
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
"""节点基类
|
||||
@@ -84,6 +95,47 @@ class BaseNode(ABC):
|
||||
self.config = node_config.get("config") or {}
|
||||
self.error_handling = node_config.get("error_handling") or {}
|
||||
|
||||
self.variable_updater = False
|
||||
|
||||
def check_activate(self, state: WorkflowState):
|
||||
"""Check if the current node is activated in the workflow state.
|
||||
|
||||
Args:
|
||||
state (WorkflowState): The current workflow state containing the 'activate' dict.
|
||||
|
||||
Returns:
|
||||
bool: True if the node is activated, False otherwise.
|
||||
"""
|
||||
return state["activate"][self.node_id]
|
||||
|
||||
def trans_activate(self, state: WorkflowState):
|
||||
"""Transform the activation state for downstream nodes.
|
||||
|
||||
This method collects all downstream nodes (excluding branch nodes)
|
||||
connected to the current node and returns a dict indicating whether
|
||||
each of these nodes should be activated based on the current node's state.
|
||||
The current node itself is also included in the returned activation dict.
|
||||
|
||||
Args:
|
||||
state (WorkflowState): The current workflow state.
|
||||
|
||||
Returns:
|
||||
dict: A dict with a single key 'activate', mapping node IDs to
|
||||
their activation status (True/False).
|
||||
"""
|
||||
edges = self.workflow_config.get("edges")
|
||||
under_stream_nodes = [
|
||||
edge.get("target")
|
||||
for edge in edges
|
||||
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES
|
||||
]
|
||||
return {
|
||||
"activate": {
|
||||
node_id: self.check_activate(state)
|
||||
for node_id in under_stream_nodes
|
||||
} | {self.node_id: self.check_activate(state)}
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""执行节点业务逻辑(非流式)
|
||||
@@ -99,13 +151,13 @@ class BaseNode(ABC):
|
||||
|
||||
Examples:
|
||||
>>> # LLM 节点
|
||||
>>> return "这是 AI 的回复"
|
||||
>>> "这是 AI 的回复"
|
||||
|
||||
>>> # Transform 节点
|
||||
>>> return {"processed_data": [...]}
|
||||
>>> {"processed_data": [...]}
|
||||
|
||||
>>> # Start/End 节点
|
||||
>>> return {"message": "开始", "conversation_id": "xxx"}
|
||||
>>> {"message": "开始", "conversation_id": "xxx"}
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -126,14 +178,14 @@ class BaseNode(ABC):
|
||||
业务数据(chunk)或完成标记
|
||||
|
||||
Examples:
|
||||
>>> # 流式 LLM 节点
|
||||
>>> full_response = ""
|
||||
>>> async for chunk in llm.astream(prompt):
|
||||
... full_response += chunk
|
||||
... yield chunk # yield 文本片段
|
||||
>>>
|
||||
>>> # 最后 yield 完成标记
|
||||
>>> yield {"__final__": True, "result": AIMessage(content=full_response)}
|
||||
# 流式 LLM 节点
|
||||
full_response = ""
|
||||
async for chunk in llm.astream(prompt):
|
||||
full_response += chunk
|
||||
yield chunk # yield 文本片段
|
||||
|
||||
# 最后 yield 完成标记
|
||||
yield {"__final__": True, "result": AIMessage(content=full_response)}
|
||||
"""
|
||||
result = await self.execute(state)
|
||||
# 默认实现:直接 yield 完成标记
|
||||
@@ -146,7 +198,7 @@ class BaseNode(ABC):
|
||||
是否支持流式输出
|
||||
"""
|
||||
# 检查子类是否重写了 execute_stream 方法
|
||||
return self.execute_stream.__func__ != BaseNode.execute_stream.__func__
|
||||
return self.__class__.execute_stream is not BaseNode.execute_stream
|
||||
|
||||
def get_timeout(self) -> int:
|
||||
"""获取超时时间(秒)
|
||||
@@ -172,6 +224,9 @@ class BaseNode(ABC):
|
||||
Returns:
|
||||
标准化的状态更新字典
|
||||
"""
|
||||
if not self.check_activate(state):
|
||||
return self.trans_activate(state)
|
||||
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
@@ -204,12 +259,11 @@ class BaseNode(ABC):
|
||||
return {
|
||||
**wrapped_output,
|
||||
"messages": state["messages"],
|
||||
"variables": state["variables"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
},
|
||||
"looping": state["looping"]
|
||||
}
|
||||
} | self.trans_activate(state)
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -220,7 +274,7 @@ class BaseNode(ABC):
|
||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||
return self._wrap_error(str(e), elapsed_time, state)
|
||||
|
||||
async def run_stream(self, state: WorkflowState):
|
||||
async def run_stream(self, state: WorkflowState) -> AsyncGenerator[dict[str, Any], Any]:
|
||||
"""Execute node with error handling and output wrapping (streaming)
|
||||
|
||||
This method is called by the Executor and is responsible for:
|
||||
@@ -241,6 +295,11 @@ class BaseNode(ABC):
|
||||
Yields:
|
||||
State updates with streaming buffer and final result
|
||||
"""
|
||||
if not self.check_activate(state):
|
||||
yield self.trans_activate(state)
|
||||
logger.info(f"跳过节点{self.node_id}")
|
||||
return
|
||||
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
@@ -358,7 +417,6 @@ class BaseNode(ABC):
|
||||
state_update = {
|
||||
**final_output,
|
||||
"messages": state["messages"],
|
||||
"variables": state["variables"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
},
|
||||
@@ -377,7 +435,7 @@ class BaseNode(ABC):
|
||||
|
||||
# Finally yield state update
|
||||
# LangGraph will merge this into state
|
||||
yield state_update
|
||||
yield state_update | self.trans_activate(state)
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -427,12 +485,13 @@ class BaseNode(ABC):
|
||||
"token_usage": token_usage,
|
||||
"error": None
|
||||
}
|
||||
|
||||
return {
|
||||
"node_outputs": {
|
||||
self.node_id: node_output
|
||||
}
|
||||
final_output = {
|
||||
"node_outputs": {self.node_id: node_output},
|
||||
}
|
||||
if self.variable_updater:
|
||||
final_output = final_output | {"variables": state["variables"]}
|
||||
|
||||
return final_output
|
||||
|
||||
def _wrap_error(
|
||||
self,
|
||||
|
||||
@@ -26,6 +26,9 @@ class NodeType(StrEnum):
|
||||
MEMORY_WRITE = "memory-write"
|
||||
|
||||
|
||||
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
|
||||
EMPTY = "empty"
|
||||
NOT_EMPTY = "not_empty"
|
||||
|
||||
@@ -11,6 +11,7 @@ class EmotionTagsRequest(BaseModel):
|
||||
start_date: Optional[str] = Field(None, description="开始日期(ISO格式,如:2024-01-01)")
|
||||
end_date: Optional[str] = Field(None, description="结束日期(ISO格式,如:2024-12-31)")
|
||||
limit: int = Field(10, ge=1, le=100, description="返回数量限制")
|
||||
language_type: Optional[str] = Field("zh", description="语言类型(zh/en)")
|
||||
|
||||
|
||||
class EmotionWordcloudRequest(BaseModel):
|
||||
@@ -18,20 +19,24 @@ class EmotionWordcloudRequest(BaseModel):
|
||||
group_id: str = Field(..., description="组ID")
|
||||
emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)")
|
||||
limit: int = Field(50, ge=1, le=200, description="返回词语数量")
|
||||
language_type: Optional[str] = Field("zh", description="语言类型(zh/en)")
|
||||
|
||||
|
||||
class EmotionHealthRequest(BaseModel):
|
||||
"""获取情绪健康指数请求"""
|
||||
group_id: str = Field(..., description="组ID")
|
||||
time_range: str = Field("30d", description="时间范围(7d/30d/90d)")
|
||||
language_type: Optional[str] = Field("zh", description="语言类型(zh/en)")
|
||||
|
||||
|
||||
class EmotionSuggestionsRequest(BaseModel):
|
||||
"""获取个性化情绪建议请求"""
|
||||
group_id: str = Field(..., description="组ID")
|
||||
config_id: Optional[int] = Field(None, description="配置ID(用于指定LLM模型)")
|
||||
language_type: Optional[str] = Field("zh", description="语言类型(zh/en)")
|
||||
|
||||
|
||||
class EmotionGenerateSuggestionsRequest(BaseModel):
|
||||
"""生成个性化情绪建议请求"""
|
||||
end_user_id: str = Field(..., description="终端用户ID")
|
||||
language_type: Optional[str] = Field("zh", description="语言类型(zh/en)")
|
||||
|
||||
@@ -44,6 +44,7 @@ class EndUserProfileResponse(BaseModel):
|
||||
updatetime_profile: Optional[datetime.datetime] = Field(description="核心档案信息最后更新时间", default=None)
|
||||
|
||||
|
||||
|
||||
class EndUserProfileUpdate(BaseModel):
|
||||
"""终端用户基本信息更新请求模型"""
|
||||
end_user_id: str = Field(description="终端用户ID")
|
||||
|
||||
@@ -51,6 +51,7 @@ class EpisodicMemoryOverviewRequest(BaseModel):
|
||||
"""情景记忆总览查询请求"""
|
||||
|
||||
end_user_id: str = Field(..., description="终端用户ID")
|
||||
language_type: Optional[str] = Field("zh", description="语言类型(zh/en)")
|
||||
time_range: str = Field(
|
||||
default="all",
|
||||
description="时间范围筛选,可选值:all, today, this_week, this_month"
|
||||
@@ -70,3 +71,4 @@ class EpisodicMemoryDetailsRequest(BaseModel):
|
||||
|
||||
end_user_id: str = Field(..., description="终端用户ID")
|
||||
summary_id: str = Field(..., description="情景记忆摘要ID")
|
||||
language_type: Optional[str] = Field("zh", description="语言类型(zh/en)")
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
"""
|
||||
显性记忆的请求和响应模型
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class ExplicitMemoryOverviewRequest(BaseModel):
|
||||
"""显性记忆总览查询请求"""
|
||||
|
||||
end_user_id: str = Field(..., description="终端用户ID")
|
||||
language_type: Optional[str] = Field("zh", description="语言类型(zh/en)")
|
||||
|
||||
class ExplicitMemoryDetailsRequest(BaseModel):
|
||||
"""显性记忆详情查询请求"""
|
||||
|
||||
end_user_id: str = Field(..., description="终端用户ID")
|
||||
memory_id: str = Field(..., description="记忆ID(情景记忆或语义记忆的ID)")
|
||||
language_type: Optional[str] = Field("zh", description="语言类型(zh/en)")
|
||||
|
||||
@@ -1445,7 +1445,7 @@ class AppService:
|
||||
target_workspace_ids: List[uuid.UUID],
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
) -> AppShare:
|
||||
) -> list[AppShare]:
|
||||
"""分享应用到其他工作空间
|
||||
|
||||
Args:
|
||||
|
||||
@@ -26,6 +26,7 @@ from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
from app.services.memory_base_service import Translation_English
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
@@ -692,7 +693,9 @@ class MemoryAgentService:
|
||||
async def get_hot_memory_tags_by_user(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 20
|
||||
limit: int = 20,
|
||||
model_id: Optional[str] = None,
|
||||
language_type: Optional[str] = "zh"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
@@ -710,7 +713,13 @@ class MemoryAgentService:
|
||||
try:
|
||||
# by_user=False 表示按 group_id 查询(在Neo4j中,group_id就是用户维度)
|
||||
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
|
||||
payload = [{"name": t, "frequency": f} for t, f in tags]
|
||||
payload=[]
|
||||
for tag, freq in tags:
|
||||
if language_type!="zh":
|
||||
tag=await Translation_English(model_id, tag)
|
||||
payload.append({"name": tag, "frequency": freq})
|
||||
else:
|
||||
payload.append({"name": tag, "frequency": freq})
|
||||
return payload
|
||||
except Exception as e:
|
||||
logger.error(f"热门记忆标签查询失败: {e}")
|
||||
|
||||
@@ -3,17 +3,268 @@ Memory Base Service
|
||||
|
||||
提供记忆服务的基础功能和共享辅助方法。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.db import get_db_context
|
||||
logger = get_logger(__name__)
|
||||
class TranslationResponse(BaseModel):
|
||||
"""翻译响应模型"""
|
||||
data: str
|
||||
|
||||
class MemoryTransService:
|
||||
"""记忆翻译服务,提供中英文翻译功能"""
|
||||
|
||||
def __init__(self, llm_client=None, model_id: Optional[str] = None):
|
||||
"""
|
||||
初始化翻译服务
|
||||
|
||||
Args:
|
||||
llm_client: LLM客户端实例或模型ID字符串(可选)
|
||||
model_id: 模型ID,用于初始化LLM客户端(可选)
|
||||
|
||||
Note:
|
||||
- 如果llm_client是字符串,会被当作model_id使用
|
||||
- 如果同时提供llm_client和model_id,优先使用llm_client
|
||||
"""
|
||||
# 处理llm_client参数:如果是字符串,当作model_id
|
||||
if isinstance(llm_client, str):
|
||||
self.model_id = llm_client
|
||||
self.llm_client = None
|
||||
else:
|
||||
self.llm_client = llm_client
|
||||
self.model_id = model_id
|
||||
|
||||
self._initialized = False
|
||||
|
||||
def _ensure_llm_client(self):
|
||||
"""确保LLM客户端已初始化"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
if self.llm_client is None:
|
||||
if self.model_id:
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
model_config = config_service.get_model_config(self.model_id)
|
||||
|
||||
extra_params = {
|
||||
"temperature": 0.2,
|
||||
"max_tokens": 400,
|
||||
"top_p": 0.8,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
self.llm_client = OpenAIClient(
|
||||
RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url"),
|
||||
timeout=model_config.get("timeout", 30),
|
||||
max_retries=model_config.get("max_retries", 3),
|
||||
extra_params=extra_params
|
||||
),
|
||||
type_=model_config.get("type")
|
||||
)
|
||||
else:
|
||||
raise ValueError("必须提供 llm_client 或 model_id 之一")
|
||||
|
||||
self._initialized = True
|
||||
|
||||
async def translate_to_english(self, text: str) -> str:
|
||||
"""
|
||||
将中文翻译为英文
|
||||
|
||||
Args:
|
||||
text: 要翻译的中文文本
|
||||
|
||||
Returns:
|
||||
翻译后的英文文本
|
||||
"""
|
||||
self._ensure_llm_client()
|
||||
|
||||
translation_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{text}\n\n中文翻译为英文,输出格式为{{\"data\":\"翻译后的内容\"}}"
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
response = await self.llm_client.response_structured(
|
||||
messages=translation_messages,
|
||||
response_model=TranslationResponse
|
||||
)
|
||||
return response.data
|
||||
except Exception as e:
|
||||
logger.error(f"翻译失败: {str(e)}")
|
||||
return text # 翻译失败时返回原文
|
||||
|
||||
async def is_english(self, text: str) -> bool:
|
||||
"""
|
||||
检查文本是否为英文
|
||||
|
||||
Args:
|
||||
text: 要检查的文本(必须是字符串)
|
||||
|
||||
Returns:
|
||||
True 如果文本主要是英文,False 否则
|
||||
|
||||
Note:
|
||||
- 只接受字符串类型
|
||||
- 检查是否主要由英文字母和常见标点组成
|
||||
- 允许数字、空格和常见标点符号
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
raise TypeError(f"is_english 只接受字符串类型,收到: {type(text).__name__}")
|
||||
|
||||
if not text.strip():
|
||||
return True # 空字符串视为英文
|
||||
|
||||
# 更宽松的英文检查:允许字母、数字、空格和常见标点
|
||||
# 如果文本中英文字符占比超过 80%,认为是英文
|
||||
english_chars = sum(1 for c in text if c.isascii() and (c.isalnum() or c.isspace() or c in '.,!?;:\'"()-'))
|
||||
total_chars = len(text)
|
||||
|
||||
if total_chars == 0:
|
||||
return True
|
||||
|
||||
return (english_chars / total_chars) >= 0.8
|
||||
async def Translate(self, text: str, target_language: str = "en") -> str:
|
||||
"""
|
||||
通用翻译方法(保持向后兼容)
|
||||
|
||||
Args:
|
||||
text: 要翻译的文本
|
||||
target_language: 目标语言,"en"表示英文,"zh"表示中文
|
||||
|
||||
Returns:
|
||||
翻译后的文本
|
||||
"""
|
||||
if target_language == "en":
|
||||
return await self.translate_to_english(text)
|
||||
else:
|
||||
logger.warning(f"不支持的目标语言: {target_language},返回原文")
|
||||
return text
|
||||
|
||||
|
||||
|
||||
# 测试翻译服务
|
||||
async def Translation_English(modid, text, fields=None):
|
||||
"""
|
||||
将数据翻译为英文(支持字段级翻译)
|
||||
|
||||
Args:
|
||||
modid: 模型ID
|
||||
text: 要翻译的数据(可以是字符串、字典或列表)
|
||||
fields: 需要翻译的字段列表(可选)
|
||||
如果为None,默认翻译: ['content', 'summary', 'statement', 'description',
|
||||
'name', 'aliases', 'caption', 'emotion_keywords']
|
||||
|
||||
Returns:
|
||||
翻译后的数据,保持原有结构
|
||||
|
||||
Note:
|
||||
- 对于字符串:直接翻译
|
||||
- 对于列表:递归处理每个元素,保持列表长度和索引不变
|
||||
- 对于字典:只翻译指定字段(fields参数)
|
||||
- 对于其他类型:原样返回
|
||||
"""
|
||||
trans_service = MemoryTransService(modid)
|
||||
|
||||
# 处理字符串类型
|
||||
if isinstance(text, str):
|
||||
# 空字符串直接返回
|
||||
if not text.strip():
|
||||
return text
|
||||
|
||||
try:
|
||||
is_eng = await trans_service.is_english(text)
|
||||
if not is_eng:
|
||||
english_result = await trans_service.Translate(text)
|
||||
return english_result
|
||||
return text
|
||||
except Exception as e:
|
||||
logger.warning(f"翻译字符串失败: {e}")
|
||||
return text
|
||||
|
||||
# 处理列表类型
|
||||
elif isinstance(text, list):
|
||||
english_result = []
|
||||
for item in text:
|
||||
# 递归处理列表中的每个元素
|
||||
if isinstance(item, str):
|
||||
# 字符串元素:检查是否需要翻译
|
||||
if not item.strip():
|
||||
english_result.append(item)
|
||||
continue
|
||||
|
||||
try:
|
||||
is_eng = await trans_service.is_english(item)
|
||||
if not is_eng:
|
||||
translated = await trans_service.Translate(item)
|
||||
english_result.append(translated)
|
||||
else:
|
||||
# 保留英文项,不改变列表长度
|
||||
english_result.append(item)
|
||||
except Exception as e:
|
||||
logger.warning(f"翻译列表项失败: {e}")
|
||||
english_result.append(item)
|
||||
|
||||
elif isinstance(item, dict):
|
||||
# 字典元素:递归调用自己处理字典
|
||||
translated_dict = await Translation_English(modid, item, fields)
|
||||
english_result.append(translated_dict)
|
||||
|
||||
elif isinstance(item, list):
|
||||
# 嵌套列表:递归处理
|
||||
translated_list = await Translation_English(modid, item, fields)
|
||||
english_result.append(translated_list)
|
||||
|
||||
else:
|
||||
# 其他类型(数字、布尔值等):原样保留
|
||||
english_result.append(item)
|
||||
|
||||
return english_result
|
||||
|
||||
# 处理字典类型
|
||||
elif isinstance(text, dict):
|
||||
# 确定要翻译的字段
|
||||
if fields is None:
|
||||
# 默认翻译字段
|
||||
fields = [
|
||||
'content', 'summary', 'statement', 'description',
|
||||
'name', 'aliases', 'caption', 'emotion_keywords',
|
||||
'text', 'title', 'label', 'type' # 添加常用字段
|
||||
]
|
||||
|
||||
# 创建副本,避免修改原始数据
|
||||
result = text.copy()
|
||||
|
||||
for field in fields:
|
||||
if field in result and result[field] is not None:
|
||||
# 递归翻译字段值(可能是字符串、列表或嵌套字典)
|
||||
try:
|
||||
result[field] = await Translation_English(modid, result[field], fields)
|
||||
except Exception as e:
|
||||
logger.warning(f"翻译字段 {field} 失败: {e}")
|
||||
# 翻译失败时保留原值
|
||||
continue
|
||||
|
||||
return result
|
||||
|
||||
# 其他类型(数字、布尔值、None等):原样返回
|
||||
else:
|
||||
return text
|
||||
class MemoryBaseService:
|
||||
"""记忆服务基类,提供共享的辅助方法"""
|
||||
|
||||
@@ -294,4 +545,4 @@ class MemoryBaseService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取遗忘记忆数量时出错: {str(e)}", exc_info=True)
|
||||
return 0
|
||||
return 0
|
||||
@@ -16,6 +16,7 @@ import json
|
||||
from datetime import datetime
|
||||
|
||||
from app.schemas.memory_episodic_schema import EmotionType
|
||||
from app.services.memory_base_service import Translation_English
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +25,7 @@ class MemoryEntityService:
|
||||
self.id = id
|
||||
self.table = table
|
||||
self.connector = Neo4jConnector()
|
||||
async def get_timeline_memories_server(self):
|
||||
async def get_timeline_memories_server(self,model_id, language_type):
|
||||
"""
|
||||
获取时间线记忆数据
|
||||
|
||||
@@ -48,10 +49,10 @@ class MemoryEntityService:
|
||||
logger.info(f"获取时间线记忆数据 - ID: {self.id}, Table: {self.table}")
|
||||
|
||||
# 根据表类型选择查询
|
||||
if self.table == 'Statement':
|
||||
if self.table == 'Statement':
|
||||
# Statement只需要输入ID,使用简化查询
|
||||
results = await self.connector.execute_query(Memory_Timeline_Statement, id=self.id)
|
||||
elif self.table == 'ExtractedEntity':
|
||||
elif self.table == 'ExtractedEntity':
|
||||
# ExtractedEntity类型查询
|
||||
results = await self.connector.execute_query(Memory_Timeline_ExtractedEntity, id=self.id)
|
||||
else:
|
||||
@@ -62,7 +63,7 @@ class MemoryEntityService:
|
||||
logger.info(f"时间线查询结果类型: {type(results)}, 长度: {len(results) if isinstance(results, list) else 'N/A'}")
|
||||
|
||||
# 处理查询结果
|
||||
timeline_data = self._process_timeline_results(results)
|
||||
timeline_data =await self._process_timeline_results(results, model_id, language_type)
|
||||
|
||||
logger.info(f"成功获取时间线记忆数据: 总计 {len(timeline_data.get('timelines_memory', []))} 条")
|
||||
|
||||
@@ -71,12 +72,14 @@ class MemoryEntityService:
|
||||
except Exception as e:
|
||||
logger.error(f"获取时间线记忆数据失败: {str(e)}", exc_info=True)
|
||||
return str(e)
|
||||
def _process_timeline_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
async def _process_timeline_results(self, results: List[Dict[str, Any]], model_id: str, language_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
处理时间线查询结果
|
||||
|
||||
Args:
|
||||
results: Neo4j查询结果
|
||||
model_id: 模型ID用于翻译
|
||||
language_type: 语言类型 ('zh' 或其他)
|
||||
|
||||
Returns:
|
||||
处理后的时间线数据字典
|
||||
@@ -104,19 +107,19 @@ class MemoryEntityService:
|
||||
# 处理MemorySummary
|
||||
summary = data.get('MemorySummary')
|
||||
if summary is not None:
|
||||
processed_summary = self._process_field_value(summary, "MemorySummary")
|
||||
processed_summary = await self._process_field_value(summary, "MemorySummary")
|
||||
memory_summary_list.extend(processed_summary)
|
||||
|
||||
# 处理Statement
|
||||
statement = data.get('statement')
|
||||
if statement is not None:
|
||||
processed_statement = self._process_field_value(statement, "Statement")
|
||||
processed_statement = await self._process_field_value(statement, "Statement")
|
||||
statement_list.extend(processed_statement)
|
||||
|
||||
# 处理ExtractedEntity
|
||||
extracted_entity = data.get('ExtractedEntity')
|
||||
if extracted_entity is not None:
|
||||
processed_entity = self._process_field_value(extracted_entity, "ExtractedEntity")
|
||||
processed_entity = await self._process_field_value(extracted_entity, "ExtractedEntity")
|
||||
extracted_entity_list.extend(processed_entity)
|
||||
|
||||
# 去重 - 现在处理的是字典列表,需要更智能的去重
|
||||
@@ -128,6 +131,8 @@ class MemoryEntityService:
|
||||
all_timeline_data = memory_summary_list + statement_list
|
||||
all_timeline_data = self._merge_same_text_items(all_timeline_data)
|
||||
|
||||
# 如果需要翻译(非中文),对整个结果进行翻译
|
||||
|
||||
result = {
|
||||
"MemorySummary": memory_summary_list,
|
||||
"Statement": statement_list,
|
||||
@@ -233,7 +238,7 @@ class MemoryEntityService:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]:
|
||||
async def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理字段值,支持字符串、列表等类型
|
||||
|
||||
@@ -251,13 +256,13 @@ class MemoryEntityService:
|
||||
# 如果是列表,处理每个元素
|
||||
for item in value:
|
||||
if self._is_valid_item(item):
|
||||
processed_item = self._process_single_item(item)
|
||||
processed_item = await self._process_single_item(item)
|
||||
if processed_item:
|
||||
processed_values.append(processed_item)
|
||||
elif isinstance(value, dict):
|
||||
# 如果是字典,直接处理
|
||||
if self._is_valid_item(value):
|
||||
processed_item = self._process_single_item(value)
|
||||
processed_item = await self._process_single_item(value)
|
||||
if processed_item:
|
||||
processed_values.append(processed_item)
|
||||
elif isinstance(value, str):
|
||||
@@ -304,7 +309,7 @@ class MemoryEntityService:
|
||||
return (str(item).strip() != '' and
|
||||
"MemorySummaryChunk" not in str(item))
|
||||
|
||||
def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
async def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
处理单个项目
|
||||
|
||||
@@ -369,6 +374,117 @@ class MemoryEntityService:
|
||||
logger.warning(f"转换时间格式失败: {e}, 原始值: {dt}")
|
||||
return str(dt) if dt is not None else None
|
||||
|
||||
async def _translate_list(
|
||||
self,
|
||||
data_list: List[Dict[str, Any]],
|
||||
model_id: str,
|
||||
fields: List[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
翻译列表中每个字典的指定字段(并发有限度以降低整体延迟)
|
||||
|
||||
Args:
|
||||
data_list: 要翻译的字典列表
|
||||
model_id: 模型ID
|
||||
fields: 需要翻译的字段列表
|
||||
|
||||
Returns:
|
||||
翻译后的字典列表
|
||||
"""
|
||||
# 空列表或无字段时直接返回
|
||||
if not data_list or not fields:
|
||||
return data_list
|
||||
|
||||
import asyncio
|
||||
|
||||
# 并发限制,避免一次性发起过多请求
|
||||
# 可根据实际情况调整(建议 5-10)
|
||||
concurrency_limit = 5
|
||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||
|
||||
async def translate_single_field(
|
||||
index: int,
|
||||
field: str,
|
||||
value: Any,
|
||||
) -> Optional[tuple]:
|
||||
"""
|
||||
翻译单个字段并返回 (索引, 字段名, 翻译结果)
|
||||
|
||||
Returns:
|
||||
(index, field, translated_value) 或 None(如果跳过)
|
||||
"""
|
||||
# 跳过空值
|
||||
if value is None or value == "":
|
||||
return None
|
||||
|
||||
# 统一转成字符串再翻译,防止非字符串类型导致错误
|
||||
text = str(value)
|
||||
|
||||
try:
|
||||
async with semaphore:
|
||||
# 调用 Translation_English 进行翻译
|
||||
# 注意:Translation_English 的参数顺序是 (model_id, text)
|
||||
translated = await Translation_English(model_id, text)
|
||||
|
||||
# 如果翻译结果为空,保留原值
|
||||
if translated is None or translated == "":
|
||||
return None
|
||||
|
||||
return index, field, translated
|
||||
except Exception as e:
|
||||
logger.warning(f"翻译字段 {field} (索引 {index}) 失败: {e}")
|
||||
return None
|
||||
|
||||
# 构造所有需要翻译的任务
|
||||
tasks = []
|
||||
for idx, item in enumerate(data_list):
|
||||
# 防御性检查:确保 item 是字典
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
for field in fields:
|
||||
if field not in item:
|
||||
continue
|
||||
|
||||
value = item.get(field)
|
||||
|
||||
# 对于 None 或空字符串的值,直接跳过,不创建任务
|
||||
if value is None or value == "":
|
||||
continue
|
||||
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
translate_single_field(idx, field, value)
|
||||
)
|
||||
)
|
||||
|
||||
# 如果没有需要翻译的任务,直接返回原列表
|
||||
if not tasks:
|
||||
return data_list
|
||||
|
||||
# 使用 gather 并发执行翻译任务(受 semaphore 限制)
|
||||
# return_exceptions=True 可以防止单个任务失败导致整体失败
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 创建深拷贝以避免修改原始数据
|
||||
translated_list = [item.copy() if isinstance(item, dict) else item for item in data_list]
|
||||
|
||||
# 将翻译结果回填到列表
|
||||
for result in results:
|
||||
# 跳过 None 结果和异常
|
||||
if result is None or isinstance(result, Exception):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"翻译任务异常: {result}")
|
||||
continue
|
||||
|
||||
idx, field, translated = result
|
||||
|
||||
# 防御性检查索引范围
|
||||
if 0 <= idx < len(translated_list) and isinstance(translated_list[idx], dict):
|
||||
translated_list[idx][field] = translated
|
||||
|
||||
return translated_list
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -426,15 +542,19 @@ class MemoryEmotion:
|
||||
# 如果解析失败,返回原始字符串
|
||||
return iso_string
|
||||
|
||||
async def get_emotion(self) -> Dict[str, Any]:
|
||||
async def get_emotion(self, model_id: str = None, language_type: str = 'zh') -> Dict[str, Any]:
|
||||
"""
|
||||
获取情绪随时间变化数据
|
||||
|
||||
Args:
|
||||
model_id: 模型ID用于翻译
|
||||
language_type: 语言类型 ('zh' 或其他)
|
||||
|
||||
Returns:
|
||||
包含情绪数据的字典
|
||||
"""
|
||||
try:
|
||||
logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}")
|
||||
logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}, language_type={language_type}")
|
||||
|
||||
if self.table == 'Statement':
|
||||
results = await self.connector.execute_query(Memory_Space_Emotion_Statement, id=self.id)
|
||||
@@ -450,6 +570,10 @@ class MemoryEmotion:
|
||||
# 转换Neo4j类型
|
||||
final_data = self._convert_neo4j_types(emotion_data)
|
||||
|
||||
# 如果需要翻译(非中文)
|
||||
if language_type != 'zh' and model_id and final_data:
|
||||
final_data = await self._translate_emotion_data(final_data, model_id)
|
||||
|
||||
logger.info(f"成功获取 {len(final_data)} 条情绪数据")
|
||||
|
||||
return final_data
|
||||
@@ -590,16 +714,14 @@ class MemoryInteraction:
|
||||
"""
|
||||
try:
|
||||
logger.info(f"获取交互数据 - ID: {self.id}, Table: {self.table}")
|
||||
|
||||
ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id)
|
||||
if ori_data!=[]:
|
||||
# name = ori_data[0]['name']
|
||||
group_id = ori_data[0]['group_id']
|
||||
group_id = [i['group_id'] for i in ori_data][0]
|
||||
Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id)
|
||||
if not Space_User:
|
||||
return []
|
||||
user_id=Space_User[0]['id']
|
||||
|
||||
results = await self.connector.execute_query(Memory_Space_Associative, id=self.id,user_id=user_id)
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from app.services.memory_base_service import MemoryBaseService
|
||||
from app.services.memory_base_service import MemoryBaseService, MemoryTransService, Translation_English
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
from app.services.memory_short_service import ShortService
|
||||
@@ -360,7 +360,9 @@ class UserMemoryService:
|
||||
async def get_cached_memory_insight(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: str
|
||||
end_user_id: str,
|
||||
model_id: str,
|
||||
language_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从数据库获取缓存的记忆洞察(四个维度)
|
||||
@@ -419,11 +421,18 @@ class UserMemoryService:
|
||||
key_findings_array = []
|
||||
|
||||
logger.info(f"成功获取 end_user_id {end_user_id} 的缓存记忆洞察(四维度)")
|
||||
memory_insight=end_user.memory_insight
|
||||
behavior_pattern=end_user.behavior_pattern
|
||||
growth_trajectory=end_user.growth_trajectory
|
||||
if language_type!='zh':
|
||||
memory_insight=await Translation_English(model_id,memory_insight)
|
||||
behavior_pattern=await Translation_English(model_id,behavior_pattern)
|
||||
growth_trajectory=await Translation_English(model_id,growth_trajectory)
|
||||
return {
|
||||
"memory_insight": end_user.memory_insight, # 总体概述存储在 memory_insight
|
||||
"behavior_pattern": end_user.behavior_pattern,
|
||||
"memory_insight":memory_insight, # 总体概述存储在 memory_insight
|
||||
"behavior_pattern":behavior_pattern,
|
||||
"key_findings": key_findings_array, # 返回数组
|
||||
"growth_trajectory": end_user.growth_trajectory,
|
||||
"growth_trajectory": growth_trajectory,
|
||||
"updated_at": self._datetime_to_timestamp(end_user.memory_insight_updated_at),
|
||||
"is_cached": True
|
||||
}
|
||||
@@ -457,7 +466,9 @@ class UserMemoryService:
|
||||
async def get_cached_user_summary(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: str
|
||||
end_user_id: str,
|
||||
model_id:str,
|
||||
language_type:str="zh"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从数据库获取缓存的用户摘要(四个部分)
|
||||
@@ -481,7 +492,6 @@ class UserMemoryService:
|
||||
user_uuid = uuid.UUID(end_user_id)
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(user_uuid)
|
||||
|
||||
if not end_user:
|
||||
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
|
||||
return {
|
||||
@@ -495,20 +505,29 @@ class UserMemoryService:
|
||||
}
|
||||
|
||||
# 检查是否有缓存数据(至少有一个字段不为空)
|
||||
user_summary=end_user.user_summary
|
||||
personality_traits=end_user.personality_traits
|
||||
core_values=end_user.core_values
|
||||
one_sentence_summary=end_user.one_sentence_summary
|
||||
if language_type!='zh':
|
||||
user_summary=await Translation_English(model_id, user_summary)
|
||||
personality_traits = await Translation_English(model_id, personality_traits)
|
||||
core_values = await Translation_English(model_id, core_values)
|
||||
one_sentence_summary = await Translation_English(model_id, one_sentence_summary)
|
||||
has_cache = any([
|
||||
end_user.user_summary,
|
||||
end_user.personality_traits,
|
||||
end_user.core_values,
|
||||
end_user.one_sentence_summary
|
||||
user_summary,
|
||||
personality_traits,
|
||||
core_values,
|
||||
one_sentence_summary
|
||||
])
|
||||
|
||||
if has_cache:
|
||||
logger.info(f"成功获取 end_user_id {end_user_id} 的缓存用户摘要")
|
||||
return {
|
||||
"user_summary": end_user.user_summary,
|
||||
"personality": end_user.personality_traits,
|
||||
"core_values": end_user.core_values,
|
||||
"one_sentence": end_user.one_sentence_summary,
|
||||
"user_summary": user_summary,
|
||||
"personality": personality_traits,
|
||||
"core_values":core_values,
|
||||
"one_sentence": one_sentence_summary,
|
||||
"updated_at": self._datetime_to_timestamp(end_user.user_summary_updated_at),
|
||||
"is_cached": True
|
||||
}
|
||||
@@ -1367,7 +1386,6 @@ async def analytics_memory_types(
|
||||
|
||||
return memory_types
|
||||
|
||||
|
||||
async def analytics_graph_data(
|
||||
db: Session,
|
||||
end_user_id: str,
|
||||
@@ -1557,7 +1575,7 @@ async def analytics_graph_data(
|
||||
f"成功获取图数据: end_user_id={end_user_id}, "
|
||||
f"nodes={len(nodes)}, edges={len(edges)}"
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
@@ -1606,11 +1624,7 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
|
||||
|
||||
# 获取该节点类型的白名单字段
|
||||
allowed_fields = field_whitelist.get(label, [])
|
||||
|
||||
# 如果没有定义白名单,返回空字典(或者可以返回所有字段)
|
||||
# if not allowed_fields:
|
||||
# # 对于未定义的节点类型,只返回基本字段
|
||||
# allowed_fields = ["name", "created_at", "caption"]
|
||||
|
||||
count_neo4j=f"""MATCH (n)-[r]-(m) WHERE elementId(n) ="{node_id}" RETURN count(r) AS rel_count;"""
|
||||
node_results = await (_neo4j_connector.execute_query(count_neo4j))
|
||||
# 提取白名单中的字段
|
||||
@@ -1618,13 +1632,12 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
|
||||
for field in allowed_fields:
|
||||
if field in properties:
|
||||
value = properties[field]
|
||||
if str(field) == 'entity_type':
|
||||
if str(field) == 'entity_type':
|
||||
value=type_mapping.get(value,'')
|
||||
if str(field)=="emotion_type":
|
||||
value=EmotionType.EMOTION_MAPPING.get(value)
|
||||
if str(field)=="emotion_subject":
|
||||
if str(field)=="emotion_subject":
|
||||
value=EmotionSubject.SUBJECT_MAPPING.get(value)
|
||||
# 清理 Neo4j 特殊类型
|
||||
filtered_props[field] = _clean_neo4j_value(value)
|
||||
filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0]
|
||||
return filtered_props
|
||||
|
||||
Reference in New Issue
Block a user