diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 698f061d..d8479f97 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -728,9 +728,23 @@ async def draft_run_compare( from app.core.exceptions import ResourceNotFoundException raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) + # 获取 agent_cfg.model_parameters,如果是 ModelParameters 对象则转为字典 + agent_model_params = agent_cfg.model_parameters + if hasattr(agent_model_params, 'model_dump'): + agent_model_params = agent_model_params.model_dump() + elif not isinstance(agent_model_params, dict): + agent_model_params = {} + + # 获取 model_item.model_parameters,如果是 ModelParameters 对象则转为字典 + item_model_params = model_item.model_parameters + if hasattr(item_model_params, 'model_dump'): + item_model_params = item_model_params.model_dump() + elif not isinstance(item_model_params, dict): + item_model_params = {} + merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) + **(agent_model_params or {}), + **(item_model_params or {}) } model_configs.append({ diff --git a/api/app/controllers/prompt_optimizer_controller.py b/api/app/controllers/prompt_optimizer_controller.py index 2069dd66..dba52d0b 100644 --- a/api/app/controllers/prompt_optimizer_controller.py +++ b/api/app/controllers/prompt_optimizer_controller.py @@ -108,16 +108,23 @@ async def get_prompt_opt( service = PromptOptimizerService(db) async def event_generator(): - async for chunk in service.optimize_prompt( - tenant_id=current_user.tenant_id, - model_id=data.model_id, - session_id=session_id, - user_id=current_user.id, - current_prompt=data.current_prompt, - user_require=data.message - ): - # chunk 是 prompt 的增量内容 - yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n" + yield "event:start\ndata: {}\n\n" + try: + async for chunk in service.optimize_prompt( + tenant_id=current_user.tenant_id, + model_id=data.model_id, + session_id=session_id, + user_id=current_user.id, + current_prompt=data.current_prompt, + user_require=data.message + ): + # chunk 是 prompt 的增量内容 + yield f"event:message\ndata: {json.dumps(chunk)}\n\n" + except Exception as e: + yield f"event:error\ndata: {json.dumps( + {"error": str(e)} + )}\n\n" + yield "event:end\ndata: {}\n\n" return StreamingResponse( event_generator(), diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index a7a6203d..adb199fb 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -1,4 +1,5 @@ import hashlib +import json import uuid from typing import Annotated from fastapi import APIRouter, Depends, Query, Request @@ -18,7 +19,7 @@ from app.services.conversation_service import ConversationService from app.services.release_share_service import ReleaseShareService from app.services.shared_chat_service import SharedChatService from app.services.app_chat_service import AppChatService, get_app_chat_service -from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release +from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release router = APIRouter(prefix="/public/share", tags=["Public Share"]) logger = get_business_logger() @@ -288,7 +289,7 @@ async def chat( password = None # Token 认证不需要密码 # end_user_id = user_id other_id = user_id - + # 提前验证和准备(在流式响应开始前完成) # 这样可以确保错误能正确返回,而不是在流式响应中间出错 from app.models.app_model import AppType @@ -364,6 +365,9 @@ async def chat( config = release.config or {} if not config.get("sub_agents"): raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING) + elif app_type == AppType.WORKFLOW: + # Multi-Agent 类型:验证多 Agent 配置 + pass else: raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) @@ -392,6 +396,8 @@ async def chat( if app_type == AppType.AGENT: # 流式返回 + agent_config = agent_config_4_app_release(release) + if payload.stream: # async def event_generator(): # async for event in service.chat_stream( @@ -424,7 +430,7 @@ async def chat( user_id= str(new_end_user.id), # 转换为字符串 variables=payload.variables, web_search=payload.web_search, - config=payload.agent_config, + config=agent_config, memory=payload.memory, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id @@ -467,6 +473,7 @@ async def chat( ) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) elif app_type == AppType.MULTI_AGENT: + # config = workflow_config_4_app_release(release) config = multi_agent_config_4_app_release(release) if payload.stream: async def event_generator(): @@ -551,8 +558,71 @@ async def chat( # ) # return success(data=conversation_schema.ChatResponse(**result)) + elif app_type == AppType.WORKFLOW: + + config = workflow_config_4_app_release(release) + if payload.stream: + async def event_generator(): + async for event in app_chat_service.workflow_chat_stream( + + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id=new_end_user.id, # 转换为字符串 + variables=payload.variables, + config=config, + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + app_id=release.app_id, + workspace_id=workspace_id + ): + event_type = event.get("event", "message") + event_data = event.get("data", {}) + + # 转换为标准 SSE 格式(字符串) + sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n" + yield sse_message + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) + + # 多 Agent 非流式返回 + result = await app_chat_service.workflow_chat( + + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id=new_end_user.id, # 转换为字符串 + variables=payload.variables, + config=config, + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + app_id=release.app_id, + workspace_id=workspace_id + ) + logger.debug( + "工作流试运行返回结果", + extra={ + "result_type": str(type(result)), + "has_response": "response" in result if isinstance(result, dict) else False + } + ) + return success( + data=result, + msg="工作流任务执行成功" + ) + # return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) + else: from app.core.exceptions import BusinessException from app.core.error_codes import BizCode raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) - pass diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index 5a78a28b..54af0b57 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -1,4 +1,5 @@ """App 服务接口 - 基于 API Key 认证""" +import json from typing import Annotated from fastapi import APIRouter, Depends, Request, Body @@ -21,7 +22,7 @@ from app.schemas.api_key_schema import ApiKeyAuth from app.services import workspace_service from app.services.app_chat_service import AppChatService, get_app_chat_service from app.services.conversation_service import ConversationService, get_conversation_service -from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release +from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release from app.services.app_service import get_app_service, AppService router = APIRouter(prefix="/app", tags=["V1 - App API"]) @@ -226,22 +227,29 @@ async def chat( return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) elif app_type == AppType.WORKFLOW: # 多 Agent 流式返回 - config = dict_to_workflow_config(app.current_release.config,app.id) + config = workflow_config_4_app_release(app.current_release) if payload.stream: async def event_generator(): async for event in app_chat_service.workflow_chat_stream( message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=end_user_id, # 转换为字符串 + user_id=new_end_user.id, # 转换为字符串 variables=payload.variables, config=config, - web_search=web_search, - memory=memory, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + app_id=app.app_id, + workspace_id=workspace_id ): - yield event + event_type = event.get("event", "message") + event_data = event.get("data", {}) + + # 转换为标准 SSE 格式(字符串) + sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n" + yield sse_message return StreamingResponse( event_generator(), @@ -253,21 +261,32 @@ async def chat( } ) - # 非流式返回 + # 多 Agent 非流式返回 result = await app_chat_service.workflow_chat( message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=end_user_id, # 转换为字符串 + user_id=new_end_user.id, # 转换为字符串 variables=payload.variables, config=config, - web_search=web_search, - memory=memory, + web_search=payload.web_search, + memory=payload.memory, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + app_id=app.app_id, + workspace_id=workspace_id + ) + logger.debug( + "工作流试运行返回结果", + extra={ + "result_type": str(type(result)), + "has_response": "response" in result if isinstance(result, dict) else False + } + ) + return success( + data=result, + msg="工作流任务执行成功" ) - - return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) else: from app.core.exceptions import BusinessException from app.core.error_codes import BizCode diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index 15c50601..da12cbf6 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -11,9 +11,9 @@ from app.db import get_db from app.core.logging_config import get_api_logger from app.core.response_utils import success, fail from app.core.error_codes import BizCode +from app.core.api_key_utils import timestamp_to_datetime from app.services.user_memory_service import ( UserMemoryService, - analytics_node_statistics, analytics_memory_types, analytics_graph_data, ) @@ -41,24 +41,27 @@ router = APIRouter( @router.get("/analytics/memory_insight/report", response_model=ApiResponse) async def get_memory_insight_report_api( - end_user_id: str, # 使用 end_user_id + end_user_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), - ) -> dict: - """获取缓存的记忆洞察报告""" - api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}") +) -> dict: + """ + 获取缓存的记忆洞察报告 + + 此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。 + 如需生成新的洞察报告,请使用专门的生成接口。 + """ + api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}") try: # 调用服务层获取缓存数据 result = await user_memory_service.get_cached_memory_insight(db, end_user_id) if result["is_cached"]: - # 缓存存在,返回缓存数据 api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}") return success(data=result, msg="查询成功") else: - # 缓存不存在,返回提示消息 api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}") - return success(data=result, msg="查询成功") + return success(data=result, msg="数据尚未生成") except Exception as e: api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e)) @@ -66,24 +69,27 @@ async def get_memory_insight_report_api( @router.get("/analytics/user_summary", response_model=ApiResponse) async def get_user_summary_api( - end_user_id: str, # 使用 end_user_id + end_user_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), - ) -> dict: - """获取缓存的用户摘要""" - api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, user={current_user.username}") +) -> dict: + """ + 获取缓存的用户摘要 + + 此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。 + 如需生成新的用户摘要,请使用专门的生成接口。 + """ + api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}") try: # 调用服务层获取缓存数据 result = await user_memory_service.get_cached_user_summary(db, end_user_id) if result["is_cached"]: - # 缓存存在,返回缓存数据 api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") return success(data=result, msg="查询成功") else: - # 缓存不存在,返回提示消息 api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}") - return success(data=result, msg="查询成功") + return success(data=result, msg="数据尚未生成") except Exception as e: api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e)) @@ -351,7 +357,7 @@ async def update_end_user_profile( if 'hire_date' in update_data: hire_date_timestamp = update_data['hire_date'] if hire_date_timestamp is not None: - update_data['hire_date'] = UserMemoryService.timestamp_to_datetime(hire_date_timestamp) + update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp) # 如果是 None,保持 None(允许清空) for field, value in update_data.items(): diff --git a/api/app/core/memory/analytics/__init__.py b/api/app/core/memory/analytics/__init__.py index 06aeaed3..6811ff8f 100644 --- a/api/app/core/memory/analytics/__init__.py +++ b/api/app/core/memory/analytics/__init__.py @@ -5,19 +5,16 @@ This module provides analytics and insights for the memory system. Available functions: - get_hot_memory_tags: Get hot memory tags by frequency -- MemoryInsight: Generate memory insight reports - get_recent_activity_stats: Get recent activity statistics -- generate_user_summary: Generate user summary + +Note: MemoryInsight and generate_user_summary have been moved to +app.services.user_memory_service for better architecture. """ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags -from app.core.memory.analytics.memory_insight import MemoryInsight from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats -from app.core.memory.analytics.user_summary import generate_user_summary __all__ = [ "get_hot_memory_tags", - "MemoryInsight", "get_recent_activity_stats", - "generate_user_summary", ] diff --git a/api/app/core/memory/analytics/memory_insight.py b/api/app/core/memory/analytics/memory_insight.py deleted file mode 100644 index 39746e58..00000000 --- a/api/app/core/memory/analytics/memory_insight.py +++ /dev/null @@ -1,327 +0,0 @@ -""" -This module provides the MemoryInsight class for analyzing user memory data. - -MemoryInsight 是一个工具类,提供基础的数据获取和分析功能: -- get_domain_distribution(): 获取记忆领域分布 -- get_active_periods(): 获取活跃时段 -- get_social_connections(): 获取社交关联 - -业务逻辑(如生成洞察报告)应该在服务层(user_memory_service.py)中实现。 - -This script can be executed directly to test the memory insight generation for a test user. -""" - -import asyncio -import json -import os -import sys -from collections import Counter -from datetime import datetime - -# To run this script directly, we need to add the src directory to the Python path -# to resolve the inconsistent imports in other modules. -src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -if src_path not in sys.path: - sys.path.insert(0, src_path) - -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.services.memory_config_service import MemoryConfigService -from pydantic import BaseModel, Field - -#TODO: Fix this - -# Default values (previously from definitions.py) -DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus") -DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123") - -# 定义用于LLM结构化输出的Pydantic模型 -class TagClassification(BaseModel): - """ - Represents the classification of a tag into a specific domain. - """ - - domain: str = Field( - ..., - description="The domain the tag belongs to, chosen from the predefined list.", - examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"], - ) - -class InsightReport(BaseModel): - """ - Represents the final insight report generated by the LLM. - """ - - report: str = Field( - ..., - description="A comprehensive insight report in Chinese, summarizing the user's memory patterns.", - ) - - -class MemoryInsight: - """ - Provides insights into user memories by analyzing various aspects of their data. - """ - - def __init__(self, user_id: str): - self.user_id = user_id - self.neo4j_connector = Neo4jConnector() - - # Get config_id using get_end_user_connected_config - with get_db_context() as db: - try: - from app.services.memory_agent_service import ( - get_end_user_connected_config, - ) - connected_config = get_end_user_connected_config(user_id, db) - config_id = connected_config.get("memory_config_id") - - if config_id: - # Use the config_id to get the proper LLM client - config_service = MemoryConfigService(db) - memory_config = config_service.load_memory_config(config_id) - factory = MemoryClientFactory(db) - self.llm_client = factory.get_llm_client(memory_config.llm_model_id) - else: - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM if no config found - factory = MemoryClientFactory(db) - self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID) - except Exception as e: - print(f"Failed to get user connected config, using default LLM: {e}") - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM - factory = MemoryClientFactory(db) - self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID) - - async def close(self): - """关闭数据库连接。""" - await self.neo4j_connector.close() - - async def get_domain_distribution(self) -> dict[str, float]: - """ - Calculates the distribution of memory domains based on hot tags. - """ - hot_tags = await get_hot_memory_tags(self.user_id) - if not hot_tags: - return {} - - domain_counts = Counter() - for tag, _ in hot_tags: - prompt = f"""请将以下标签归类到最合适的领域中。 - -可选领域及其关键词: -- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等 -- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等 -- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等 -- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等 -- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等 -- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等 -- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等 -- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等 -- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等 -- 其他:确实无法归入以上任何类别的内容 - -标签: {tag} - -分析步骤: -1. 仔细理解标签的核心含义和使用场景 -2. 对比各个领域的关键词,找到最匹配的领域 -3. 特别注意: - - 历史、科学、文化等知识性内容应归类为"学习" - - 学校、课程、考试等正式教育场景应归类为"教育" - - 只有在标签完全不属于上述9个具体领域时,才选择"其他" -4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他" - -请直接返回最合适的领域名称。""" - messages = [ - {"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景,优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"}, - {"role": "user", "content": prompt} - ] - # 直接调用并等待结果 - classification = await self.llm_client.response_structured( - messages=messages, - response_model=TagClassification, - ) - if classification and hasattr(classification, 'domain') and classification.domain: - domain_counts[classification.domain] += 1 - - total_tags = sum(domain_counts.values()) - if total_tags == 0: - return {} - - domain_distribution = { - domain: count / total_tags for domain, count in domain_counts.items() - } - return dict( - sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True) - ) - - async def get_active_periods(self) -> list[int]: - """ - Identifies the top 2 most active months for the user. - Only returns months if there is valid and diverse time data. - - This method checks if the time data represents real user memory timestamps - rather than auto-generated system timestamps by verifying: - 1. Time data exists and is parseable - 2. Time data is distributed across multiple months (not concentrated in 1-2 months) - """ - query = f""" - MATCH (d:Dialogue) - WHERE d.group_id = '{self.user_id}' AND d.created_at IS NOT NULL AND d.created_at <> '' - RETURN d.created_at AS creation_time - """ - records = await self.neo4j_connector.execute_query(query) - - if not records: - return [] - - month_counts = Counter() - valid_dates_count = 0 - for record in records: - creation_time_str = record.get("creation_time") - if not creation_time_str: - continue - try: - # 尝试解析时间字符串 - dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00")) - month_counts[dt_object.month] += 1 - valid_dates_count += 1 - except (ValueError, TypeError, AttributeError): - # 如果解析失败,跳过这条记录 - continue - - # 如果没有有效的时间数据,返回空列表 - if not month_counts or valid_dates_count == 0: - return [] - - # 检查时间分布是否过于集中(可能是批量导入的数据) - # 如果超过80%的数据集中在1-2个月,认为这是系统时间戳而非真实时间 - unique_months = len(month_counts) - if unique_months <= 2: - # 只有1-2个月有数据,很可能是批量导入 - most_common_count = month_counts.most_common(1)[0][1] - if most_common_count / valid_dates_count > 0.8: - # 超过80%集中在一个月,认为是系统时间戳 - return [] - - # 如果时间分布较为分散(3个月以上),认为是真实时间数据 - if unique_months >= 3: - most_common_months = month_counts.most_common(2) - return [month for month, _ in most_common_months] - - # 2个月的情况,检查是否分布均匀 - if unique_months == 2: - counts = list(month_counts.values()) - # 如果两个月的数据量相差不大(比例在0.3-3之间),认为是真实数据 - ratio = min(counts) / max(counts) - if ratio > 0.3: - most_common_months = month_counts.most_common(2) - return [month for month, _ in most_common_months] - - # 其他情况返回空列表 - return [] - - async def get_social_connections(self) -> dict | None: - """ - Finds the user with whom the most memories are shared. - 使用 Chunk-Statement 的 CONTAINS 关系,因为系统中不创建 Dialogue-Statement 的 MENTIONS 关系。 - """ - # 通过 Chunk 和 Statement 的 CONTAINS 关系来查找共同记忆 - query = f""" - MATCH (c1:Chunk {{group_id: '{self.user_id}'}}) - OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement) - OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk) - WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL - WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements - WHERE common_statements > 0 - RETURN other_user_id, common_statements - ORDER BY common_statements DESC - LIMIT 1 - """ - records = await self.neo4j_connector.execute_query(query) - if not records or not records[0].get("other_user_id"): - return None - - most_connected_user = records[0]["other_user_id"] - common_memories_count = records[0]["common_statements"] - - # 使用 Chunk 的时间范围 - time_range_query = f""" - MATCH (c:Chunk) - WHERE c.group_id IN ['{self.user_id}', '{most_connected_user}'] - RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time - """ - time_records = await self.neo4j_connector.execute_query(time_range_query) - start_year, end_year = "N/A", "N/A" - if time_records and time_records[0]["start_time"]: - start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year - end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year - - return { - "user_id": most_connected_user, - "common_memories_count": common_memories_count, - "time_range": f"{start_year}-{end_year}", - } - - async def close(self): - """ - Closes the database connection. - """ - await self.neo4j_connector.close() - - -async def main(): - """ - Initializes and runs the memory insight analysis for a test user. - """ - # 默认从环境变量读取 - test_user_id = DEFAULT_GROUP_ID - print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n") - - try: - # 使用服务层函数生成报告 - from app.services.user_memory_service import analytics_memory_insight_report - - result = await analytics_memory_insight_report(end_user_id=test_user_id) - report = result.get("report", "") - - print("--- 记忆洞察报告 ---") - print(report) - print("---------------------") - - # 将结果写入统一的 User-Dashboard.json,使用全局配置路径 - try: - from app.core.config import settings - settings.ensure_memory_output_dir() - output_dir = settings.MEMORY_OUTPUT_DIR - try: - os.makedirs(output_dir, exist_ok=True) - except Exception: - pass - dashboard_path = os.path.join(output_dir, "User-Dashboard.json") - existing = {} - if os.path.exists(dashboard_path): - with open(dashboard_path, "r", encoding="utf-8") as rf: - existing = json.load(rf) - existing["memory_insight"] = { - "group_id": test_user_id, - "report": report - } - with open(dashboard_path, "w", encoding="utf-8") as wf: - json.dump(existing, wf, ensure_ascii=False, indent=2) - print(f"已写入 {dashboard_path} -> memory_insight") - except Exception as e: - print(f"写入 User-Dashboard.json 失败: {e}") - except Exception as e: - print(f"生成报告时出错: {e}") - - -if __name__ == "__main__": - # This setup allows running the async main function - if sys.platform.startswith('win') and sys.version_info >= (3, 8): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - asyncio.run(main()) diff --git a/api/app/core/memory/analytics/user_summary.py b/api/app/core/memory/analytics/user_summary.py deleted file mode 100644 index f0283993..00000000 --- a/api/app/core/memory/analytics/user_summary.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Generate a concise "关于我" style user summary using data from Neo4j -and the existing LLM configuration (mirrors hot_memory_tags.py setup). - -Usage: - python -m analytics.user_summary --user_id -""" - -import asyncio -import json -import os -import sys -from dataclasses import dataclass -from typing import List, Tuple - -# Ensure absolute imports work whether executed directly or via module -try: - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) - src_path = os.path.join(project_root, 'src') - if src_path not in sys.path: - sys.path.insert(0, src_path) - if project_root not in sys.path: - sys.path.insert(0, project_root) -except Exception: - pass - -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.services.memory_config_service import MemoryConfigService - -#TODO: Fix this - -# Default values (previously from definitions.py) -DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus") -DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123") - - -@dataclass -class StatementRecord: - statement: str - created_at: str | None - - -class UserSummary: - """Builds a textual user summary for a given user/group id.""" - - def __init__(self, user_id: str): - self.user_id = user_id - self.connector = Neo4jConnector() - - # Get config_id using get_end_user_connected_config - with get_db_context() as db: - try: - from app.services.memory_agent_service import ( - get_end_user_connected_config, - ) - connected_config = get_end_user_connected_config(user_id, db) - config_id = connected_config.get("memory_config_id") - - if config_id: - # Use the config_id to get the proper LLM client - config_service = MemoryConfigService(db) - memory_config = config_service.load_memory_config(config_id) - factory = MemoryClientFactory(db) - self.llm = factory.get_llm_client(memory_config.llm_model_id) - else: - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM if no config found - factory = MemoryClientFactory(db) - self.llm = factory.get_llm_client(DEFAULT_LLM_ID) - except Exception as e: - print(f"Failed to get user connected config, using default LLM: {e}") - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM - factory = MemoryClientFactory(db) - self.llm = factory.get_llm_client(DEFAULT_LLM_ID) - - async def close(self): - await self.connector.close() - - async def _get_recent_statements(self, limit: int = 80) -> List[StatementRecord]: # TODO Used by user_memory_service - """Fetch recent statements authored by the user/group for context.""" - query = ( - "MATCH (s:Statement) " - "WHERE s.group_id = $group_id AND s.statement IS NOT NULL " - "RETURN s.statement AS statement, s.created_at AS created_at " - "ORDER BY created_at DESC LIMIT $limit" - ) - rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit) - records: List[StatementRecord] = [] - for r in rows: - try: - records.append(StatementRecord(statement=r.get("statement", ""), created_at=r.get("created_at"))) - except Exception: - continue - return records - - async def _get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]: - """Reuse hot tag logic to get meaningful entities and their frequencies.""" - # get_hot_memory_tags internally filters out non-meaningful nouns with LLM - return await get_hot_memory_tags(self.user_id, limit=limit) # TODO Used by user_memory_service - - -async def generate_user_summary(user_id: str | None = None) -> str: # TODO useless - """ - 生成用户摘要的便捷函数 - - Args: - user_id: 可选的用户ID - - Returns: - 用户摘要字符串 - """ - # 导入服务层函数 - from app.services.user_memory_service import analytics_user_summary - - # 调用服务层函数 - result = await analytics_user_summary(user_id) - return result.get("summary", "") - - -if __name__ == "__main__": - print("开始生成用户摘要…") - try: - # 直接使用 runtime.json 中的 group_id - summary = asyncio.run(generate_user_summary()) - print("\n— 用户摘要 —\n") - print(summary) - - # 将结果写入统一的 User-Dashboard.json - try: - from app.core.config import settings - settings.ensure_memory_output_dir() - output_dir = settings.MEMORY_OUTPUT_DIR - try: - os.makedirs(output_dir, exist_ok=True) - except Exception: - pass - dashboard_path = os.path.join(output_dir, "User-Dashboard.json") - existing = {} - if os.path.exists(dashboard_path): - with open(dashboard_path, "r", encoding="utf-8") as rf: - existing = json.load(rf) - existing["user_summary"] = { - "group_id": DEFAULT_GROUP_ID, - "summary": summary - } - with open(dashboard_path, "w", encoding="utf-8") as wf: - json.dump(existing, wf, ensure_ascii=False, indent=2) - print(f"已写入 {dashboard_path} -> user_summary") - except Exception as e: - print(f"写入 User-Dashboard.json 失败: {e}") - except Exception as e: - print(f"生成摘要失败: {e}") - print("请检查: 1) Neo4j 是否可用;2) config.json 与 .env 的 LLM/Neo4j 配置是否正确;3) 数据是否包含该用户的内容。") diff --git a/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 b/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 index 373ab31e..2f452c53 100644 --- a/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 @@ -85,33 +85,21 @@ Example Output: ===End of Example=== -===Reflection Process=== +===Internal Quality Checks (DO NOT OUTPUT)=== -After generating the profile, perform the following self-review steps: +Before generating your final output, internally verify: +1. All content is grounded in provided data (no fabrication) +2. Format follows the specified structure with correct headers +3. Tone is objective, third-person, and neutral +4. All four sections are complete and within character limits -**Step 1: Data Grounding Check** -- Verify all statements are supported by the provided entities and statements -- Ensure no fabricated or speculated information is included -- Confirm all claims can be traced back to the input data - -**Step 2: Format Compliance** -- Verify each section follows the specified format with section headers -- Check character count limits for each section -- Ensure proper use of section markers (【】) - -**Step 3: Tone and Style Review** -- Confirm objective third-person perspective is maintained -- Check for excessive adjectives or empty phrases -- Verify neutral and restrained tone throughout - -**Step 4: Completeness Check** -- Ensure all four sections are present and complete -- Verify each section addresses its specific focus area -- Confirm the one-sentence summary effectively captures the user's essence +**IMPORTANT: These checks are for your internal use only. DO NOT include them in your output.** ===Output Requirements=== +**CRITICAL: Your response must ONLY contain the four sections below. Do not include any reflection, self-review, or meta-commentary.** + **LANGUAGE REQUIREMENT:** - The output language should ALWAYS be Chinese (Simplified) - All section content must be in Chinese @@ -122,3 +110,5 @@ After generating the profile, perform the following self-review steps: - Content follows immediately after the header - Sections are separated by blank lines - Strictly adhere to character limits for each section +- **DO NOT include any text after the 【一句话总结】 section** +- **DO NOT output reflection steps, self-review, or verification notes** diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 8eb31fb4..a1ec2e1d 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -29,7 +29,7 @@ class WorkflowState(TypedDict): # Set of loop node IDs, used for assigning values in loop nodes cycle_nodes: list - looping: bool + looping: Annotated[bool, lambda x, y: x and y] # Input variables (passed from configured variables) # Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx) diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 55919998..4374d847 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode): retries -= 1 if retries > 0: await asyncio.sleep(self.typed_config.retry.retry_interval / 1000) + elif self.typed_config.error_handle.method == HttpErrorHandle.NONE: + raise e + except Exception as e: + raise RuntimeError(f"HTTP request node exception: {e}") else: match self.typed_config.error_handle.method: - case HttpErrorHandle.NONE: - logger.warning( - f"Node {self.node_id}: HTTP request failed, returning error response" - ) - return HttpRequestNodeOutput( - body="", - status_code=resp.status_code, - headers=resp.headers, - ).model_dump() case HttpErrorHandle.DEFAULT: logger.warning( f"Node {self.node_id}: HTTP request failed, returning default result" @@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode): f"Node {self.node_id}: HTTP request failed, switching to error handling branch" ) return "ERROR" + raise RuntimeError("http request failed") diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index e12c6224..d9caae7e 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -203,15 +203,16 @@ class KnowledgeRetrievalNode(BaseNode): rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, indices=indices, score_threshold=kb_config.similarity_threshold) - # Deduplicate hybrid retrieval results + # Deduplicate hy brid retrieval results unique_rs = self._deduplicate_docs(rs1, rs2) vector_service.reranker = self.get_reranker_model() rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) case _: raise RuntimeError("Unknown retrieval type") vector_service.reranker = self.get_reranker_model() + # TODO:其他重排序方式支持 final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) logger.info( f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}" ) - return [chunk.model_dump() for chunk in final_rs] + return [chunk.page_content for chunk in final_rs] diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index da94482b..8498fc38 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -1,5 +1,7 @@ """LLM 节点配置""" +from typing import Any + from pydantic import BaseModel, Field, field_validator from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType @@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti class MessageConfig(BaseModel): """消息配置""" - + role: str = Field( ..., description="消息角色:system, user, assistant" ) - + content: str = Field( ..., description="消息内容,支持模板变量,如:{{ sys.message }}" ) - + @field_validator("role") @classmethod def validate_role(cls, v: str) -> str: @@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig): 1. 简单模式:使用 prompt 字段 2. 消息模式:使用 messages 字段(推荐) """ - + model_id: str = Field( ..., description="模型配置 ID" ) - + + context: Any = Field( + default="", + description="上下文" + ) + # 简单模式 prompt: str | None = Field( default=None, description="提示词模板(简单模式),支持变量引用" ) - + # 消息模式(推荐) messages: list[MessageConfig] | None = Field( default=None, description="消息列表(消息模式),支持多轮对话" ) - + # 模型参数 temperature: float | None = Field( default=0.7, @@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig): le=2.0, description="温度参数,控制输出的随机性" ) - + max_tokens: int | None = Field( default=1000, ge=1, le=32000, description="最大生成 token 数" ) - + top_p: float | None = Field( default=None, ge=0.0, le=1.0, description="Top-p 采样参数" ) - + frequency_penalty: float | None = Field( default=None, ge=-2.0, le=2.0, description="频率惩罚" ) - + presence_penalty: float | None = Field( default=None, ge=-2.0, le=2.0, description="存在惩罚" ) - + # 输出变量定义 output_variables: list[VariableDefinition] = Field( default_factory=lambda: [ @@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig): ], description="输出变量定义(自动生成,通常不需要修改)" ) - + @field_validator("messages", "prompt") @classmethod def validate_input_mode(cls, v, info): """验证输入模式:prompt 和 messages 至少有一个""" # 这个验证在 model_validator 中更合适 return v - + class Config: json_schema_extra = { "examples": [ diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 65826d84..334229f7 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -5,15 +5,17 @@ LLM 节点实现 """ import logging +import re from typing import Any from langchain_core.messages import AIMessage, SystemMessage, HumanMessage from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.models import RedBearLLM, RedBearModelConfig +from app.core.workflow.nodes.llm.config import LLMNodeConfig from app.db import get_db_context from app.models import ModelType from app.services.model_service import ModelConfigService - + from app.core.exceptions import BusinessException from app.core.error_codes import BizCode @@ -63,8 +65,15 @@ class LLMNode(BaseNode): - user/human: 用户消息(HumanMessage) - ai/assistant: AI 消息(AIMessage) """ - - def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]: + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = LLMNodeConfig(**self.config) + + def _render_context(self, message,state): + context = f"{self._render_template(self.typed_config.context, state)}" + return re.sub(r"{{context}}", context, message) + + def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]: """准备 LLM 实例(公共逻辑) Args: @@ -76,15 +85,16 @@ class LLMNode(BaseNode): # 1. 处理消息格式(优先使用 messages) messages_config = self.config.get("messages") - + if messages_config: # 使用 LangChain 消息格式 messages = [] for msg_config in messages_config: role = msg_config.get("role", "user").lower() content_template = msg_config.get("content", "") + content_template = self._render_context(content_template, state) content = self._render_template(content_template, state) - + # 根据角色创建对应的消息对象 if role == "system": messages.append(SystemMessage(content=content)) @@ -95,7 +105,7 @@ class LLMNode(BaseNode): else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append(HumanMessage(content=content)) - + prompt_or_messages = messages else: # 使用简单的 prompt 格式(向后兼容) @@ -106,17 +116,17 @@ class LLMNode(BaseNode): model_id = self.config.get("model_id") if not model_id: raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置") - + # 3. 在 with 块内完成所有数据库操作和数据提取 with get_db_context() as db: config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) - - if not config: + + if not config: raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND) - + if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) - + # 在 Session 关闭前提取所有需要的数据 api_config = config.api_keys[0] model_name = api_config.model_name @@ -124,26 +134,26 @@ class LLMNode(BaseNode): api_key = api_config.api_key api_base = api_config.api_base model_type = config.type - + # 4. 创建 LLM 实例(使用已提取的数据) # 注意:对于流式输出,需要在模型初始化时设置 streaming=True extra_params = {"streaming": stream} if stream else {} - + llm = RedBearLLM( RedBearModelConfig( model_name=model_name, - provider=provider, + provider=provider, api_key=api_key, base_url=api_base, extra_params=extra_params - ), + ), type=ModelType(model_type) ) - + logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") - + return llm, prompt_or_messages - + async def execute(self, state: WorkflowState) -> AIMessage: """非流式执行 LLM 调用 @@ -153,10 +163,10 @@ class LLMNode(BaseNode): Returns: LLM 响应消息 """ - llm, prompt_or_messages = self._prepare_llm(state,True) - + llm, prompt_or_messages = self._prepare_llm(state, True) + logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") - + # 调用 LLM(支持字符串或消息列表) response = await llm.ainvoke(prompt_or_messages) # 提取内容 @@ -164,16 +174,16 @@ class LLMNode(BaseNode): content = response.content else: content = str(response) - + logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}") - + # 返回 AIMessage(包含响应元数据) return response if isinstance(response, AIMessage) else AIMessage(content=content) - + def _extract_input(self, state: WorkflowState) -> dict[str, Any]: """提取输入数据(用于记录)""" _, prompt_or_messages = self._prepare_llm(state) - + return { "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, "messages": [ @@ -186,13 +196,13 @@ class LLMNode(BaseNode): "max_tokens": self.config.get("max_tokens") } } - + def _extract_output(self, business_result: Any) -> str: """从 AIMessage 中提取文本内容""" if isinstance(business_result, AIMessage): return business_result.content return str(business_result) - + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: """从 AIMessage 中提取 token 使用情况""" if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'): @@ -204,7 +214,7 @@ class LLMNode(BaseNode): "total_tokens": usage.get('total_tokens', 0) } return None - + async def execute_stream(self, state: WorkflowState): """流式执行 LLM 调用 @@ -215,26 +225,26 @@ class LLMNode(BaseNode): 文本片段(chunk)或完成标记 """ from langgraph.config import get_stream_writer - + llm, prompt_or_messages = self._prepare_llm(state, True) - + logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") - + # 检查是否有注入的 End 节点前缀配置 writer = get_stream_writer() end_prefix = getattr(self, '_end_node_prefix', None) - + logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}") if end_prefix: logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'") - + if end_prefix: # 渲染前缀(可能包含其他变量) try: rendered_prefix = self._render_template(end_prefix, state) logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'") - + # 提前发送 End 节点的前缀(使用 "message" 类型) writer({ "type": "message", # End 相关的内容都是 message 类型 @@ -246,12 +256,12 @@ class LLMNode(BaseNode): }) except Exception as e: logger.warning(f"渲染/发送 End 节点前缀失败: {e}") - + # 累积完整响应 full_response = "" last_chunk = None chunk_count = 0 - + # 调用 LLM(流式,支持字符串或消息列表) async for chunk in llm.astream(prompt_or_messages): # 提取内容 @@ -259,18 +269,18 @@ class LLMNode(BaseNode): content = chunk.content else: content = str(chunk) - + # 只有当内容不为空时才处理 if content: full_response += content last_chunk = chunk chunk_count += 1 - + # 流式返回每个文本片段 yield content - + logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}") - + # 构建完整的 AIMessage(包含元数据) if isinstance(last_chunk, AIMessage): final_message = AIMessage( @@ -279,6 +289,6 @@ class LLMNode(BaseNode): ) else: final_message = AIMessage(content=full_response) - + # yield 完成标记 yield {"__final__": True, "result": final_message} diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 09c9fc68..bb2366f6 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode): return await MemoryAgentService().read_memory( group_id=end_user_id, - message=self.typed_config.message, + message=self._render_template(self.typed_config.message, state), config_id=self.typed_config.config_id, search_switch=self.typed_config.search_switch, history=[], @@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode): return await MemoryAgentService().write_memory( group_id=end_user_id, - message=self.typed_config.message, + message=self._render_template(self.typed_config.message, state), config_id=self.typed_config.config_id, db=db, storage_type="neo4j", diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 67f53801..b0f2c28d 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode): category_map[category_name] = case_tag return category_map - async def execute(self, state: WorkflowState) -> str: + async def execute(self, state: WorkflowState) -> dict: """执行问题分类""" question = self.typed_config.input_variable supplement_prompt = self.typed_config.user_supplement_prompt or "" @@ -79,7 +79,15 @@ class QuestionClassifierNode(BaseNode): f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})" ) # 若分类列表为空,返回默认unknown分支,否则返回CASE1 - return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown" + if category_count > 0: + return { + "class_name": category_names[0], + "output": DEFAULT_EMPTY_QUESTION_CASE + } + return { + "class_name": "unknown", + "output": DEFAULT_EMPTY_QUESTION_CASE + } try: llm = self._get_llm_instance() @@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode): log_supplement = supplement_prompt if supplement_prompt else "无" logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") - return f"CASE{category_names.index(category) + 1}" + return { + "class_name": category, + "output": f"CASE{category_names.index(category) + 1}", + } except Exception as e: logger.error( f"节点 {self.node_id} 分类执行异常:{str(e)}", @@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode): ) # 异常时返回默认分支,保证工作流容错性 if category_count > 0: - return DEFAULT_EMPTY_QUESTION_CASE - return "unknown" + return { + "class_name": category_names[0], + "output": DEFAULT_EMPTY_QUESTION_CASE + } + return { + "class_name": "unknown", + "output": DEFAULT_EMPTY_QUESTION_CASE + } diff --git a/api/app/core/workflow/nodes/tool/config.py b/api/app/core/workflow/nodes/tool/config.py index 487efae2..d3b1a644 100644 --- a/api/app/core/workflow/nodes/tool/config.py +++ b/api/app/core/workflow/nodes/tool/config.py @@ -1,4 +1,6 @@ from pydantic import Field +from typing import Any + from app.core.workflow.nodes.base_config import BaseNodeConfig @@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig): """工具节点配置""" tool_id: str = Field(..., description="工具ID") - tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量") + tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量") diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index 993a3804..e1b5f380 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -1,5 +1,5 @@ import logging -import uuid +import re from typing import Any from app.core.workflow.nodes.base_node import BaseNode, WorkflowState @@ -9,6 +9,8 @@ from app.db import get_db_read logger = logging.getLogger(__name__) +TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}") + class ToolNode(BaseNode): """工具节点""" @@ -25,25 +27,33 @@ class ToolNode(BaseNode): # 如果没有租户ID,尝试从工作流ID获取 if not tenant_id: - workflow_id = self.get_variable("sys.workflow_id", state) - if workflow_id: + workspace_id = self.get_variable("sys.workspace_id", state) + if workspace_id: from app.repositories.tool_repository import ToolRepository with get_db_read() as db: - tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id) if not tenant_id: - tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097") - # logger.error(f"节点 {self.node_id} 缺少租户ID") - # return {"error": "缺少租户ID"} + logger.error(f"节点 {self.node_id} 缺少租户ID") + return { + "success": False, + "data": "缺少租户ID" + } # 渲染工具参数 rendered_parameters = {} for param_name, param_template in self.typed_config.tool_parameters.items(): - rendered_value = self._render_template(param_template, state) + if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template): + try: + rendered_value = self._render_template(param_template, state) + except Exception as e: + raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e + else: + # 非模板参数(数字/布尔/普通字符串)直接保留原值 + rendered_value = param_template rendered_parameters[param_name] = rendered_value logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}") - print(self.typed_config.tool_id) # 执行工具 with get_db_read() as db: @@ -54,7 +64,7 @@ class ToolNode(BaseNode): tenant_id=tenant_id, user_id=user_id ) - print(result) + if result.success: logger.info(f"节点 {self.node_id} 工具执行成功") return { @@ -66,7 +76,7 @@ class ToolNode(BaseNode): logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}") return { "success": False, - "error": result.error, + "data": result.error, "error_code": result.error_code, "execution_time": result.execution_time } \ No newline at end of file diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 00358d91..6daf415d 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -87,10 +87,11 @@ class WorkflowValidator: return graphs @classmethod - def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]: + def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]: """验证工作流配置 Args: + publish: 发布验证标识 workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型 Returns: @@ -114,7 +115,7 @@ class WorkflowValidator: graphs = cls.get_subgraph(workflow_config) logger.info(graphs) - for graph in graphs: + for index, graph in enumerate(graphs): nodes = graph.get("nodes", []) edges = graph.get("edges", []) variables = graph.get("variables", []) @@ -125,10 +126,11 @@ class WorkflowValidator: elif len(start_nodes) > 1: errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个") - # 2. 验证 end 节点(至少一个) - end_nodes = [n for n in nodes if n.get("type") == NodeType.END] - if len(end_nodes) == 0: - errors.append("工作流必须至少有一个 end 节点") + if index == len(graphs) - 1: + # 2. 验证 主图end 节点(至少一个) + end_nodes = [n for n in nodes if n.get("type") == NodeType.END] + if len(end_nodes) == 0: + errors.append("工作流必须至少有一个 end 节点") # 3. 验证节点 ID 唯一性 node_ids = [n.get("id") for n in nodes] @@ -159,15 +161,17 @@ class WorkflowValidator: elif target not in node_id_set: errors.append(f"边 #{i} 的 target 节点不存在: {target}") - # 6. 验证所有节点可达(从 start 节点出发) - if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 - reachable = WorkflowValidator._get_reachable_nodes( - start_nodes[0]["id"], - edges - ) - unreachable = node_id_set - reachable - if unreachable: - errors.append(f"以下节点无法从 start 节点到达: {unreachable}") + if publish: + # 仅在发布时验证所有节点可达 + # 6. 验证所有节点可达(从 start 节点出发) + if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 + reachable = WorkflowValidator._get_reachable_nodes( + start_nodes[0]["id"], + edges + ) + unreachable = node_id_set - reachable + if unreachable: + errors.append(f"以下节点无法从 start 节点到达: {unreachable}") # 7. 检测循环依赖(非 loop 节点) if not errors: # 只有在前面验证通过时才检查循环 @@ -288,7 +292,7 @@ class WorkflowValidator: (is_valid, errors): 是否有效和错误列表 """ # 先执行基础验证 - is_valid, errors = WorkflowValidator.validate(workflow_config) + is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True) if not is_valid: return False, errors diff --git a/api/app/models/agent_app_config_model.py b/api/app/models/agent_app_config_model.py index 373de92c..a4645791 100644 --- a/api/app/models/agent_app_config_model.py +++ b/api/app/models/agent_app_config_model.py @@ -4,6 +4,8 @@ from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey from sqlalchemy.dialects.postgresql import UUID, JSON from sqlalchemy.orm import relationship from app.db import Base +from app.models.multi_agent_model import PydanticType +from app.schemas import ModelParameters class AgentConfig(Base): @@ -17,14 +19,17 @@ class AgentConfig(Base): # Agent 行为配置 system_prompt = Column(Text, nullable=True, comment="系统提示词") default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True, comment="默认模型配置ID") - + # 结构化配置(直接存储 JSON) - model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)") + # model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)") + model_parameters = Column(PydanticType(ModelParameters), nullable=True, + comment="模型参数配置(temperature、max_tokens等)") + knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置") memory = Column(JSON, nullable=True, comment="记忆配置") variables = Column(JSON, default=list, nullable=True, comment="变量配置") tools = Column(JSON, default=dict, nullable=True, comment="工具配置") - + # 多 Agent 相关字段 agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone") agent_domain = Column(String(50), comment="专业领域: customer_service|technical_support|sales 等") @@ -41,4 +46,4 @@ class AgentConfig(Base): parent_agent = relationship("AgentConfig", remote_side=[id], backref="sub_agents") def __repr__(self): - return f"" \ No newline at end of file + return f"" diff --git a/api/app/repositories/tool_repository.py b/api/app/repositories/tool_repository.py index 3aa7b16e..257910c3 100644 --- a/api/app/repositories/tool_repository.py +++ b/api/app/repositories/tool_repository.py @@ -38,6 +38,33 @@ class ToolRepository: return result[0] if result else None + @staticmethod + def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]: + """ + 根据空间ID获取tenant_id + + Args: + db: 数据库会话 + workspace_id: 空间ID + + Returns: + tenant_id或None + """ + from app.models.workspace_model import Workspace + + tenant_id = db.query(Workspace.tenant_id).filter( + Workspace.id == workspace_id + ).scalar() + + if tenant_id is not None and not isinstance(tenant_id, uuid.UUID): + # 兼容数据库中字段类型不匹配的情况(比如存储为字符串) + try: + tenant_id = uuid.UUID(tenant_id) + except (ValueError, TypeError): + return None + + return tenant_id + @staticmethod def find_by_tenant( db: Session, diff --git a/api/app/services/agent_config_converter.py b/api/app/services/agent_config_converter.py index 262c1c04..3ab14157 100644 --- a/api/app/services/agent_config_converter.py +++ b/api/app/services/agent_config_converter.py @@ -86,7 +86,12 @@ class AgentConfigConverter: # 1. 解析模型参数配置 if model_parameters: from app.schemas.app_schema import ModelParameters - result["model_parameters"] = ModelParameters(**model_parameters) + if isinstance(model_parameters, ModelParameters): + result["model_parameters"] = model_parameters + elif isinstance(model_parameters, dict): + result["model_parameters"] = ModelParameters(**model_parameters) + else: + result["model_parameters"] = ModelParameters() # 2. 解析知识库检索配置 if knowledge_retrieval: diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index efcf318d..6b7b3103 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -9,15 +9,18 @@ from fastapi import Depends from sqlalchemy.orm import Session from app.core.agent.langchain_agent import LangChainAgent +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger -from app.db import get_db -from app.models import MultiAgentConfig, AgentConfig +from app.db import get_db, get_db_context +from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.services.conversation_service import ConversationService from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.services.draft_run_service import create_web_search_tool from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator +from app.services.workflow_service import WorkflowService logger = get_business_logger() @@ -184,7 +187,7 @@ class AppChatService: model_config_id = config.default_model_config_id api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) # 处理系统提示词(支持变量替换) - system_prompt = config.get("system_prompt", "") + system_prompt = config.system_prompt if variables: system_prompt_rendered = render_prompt_message( system_prompt, @@ -197,7 +200,7 @@ class AppChatService: tools = [] # 添加知识库检索工具 - knowledge_retrieval = config.get("knowledge_retrieval") + knowledge_retrieval = config.knowledge_retrieval if knowledge_retrieval: knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] @@ -208,13 +211,13 @@ class AppChatService: # 添加长期记忆工具 memory_flag = False if memory: - memory_config = config.get("memory", {}) + memory_config = config.memory if memory_config.get("enabled") and user_id: memory_flag = True memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - web_tools = config.get("tools") + web_tools = config.tools web_search_choice = web_tools.get("web_search", {}) web_search_enable = web_search_choice.get("enabled", False) if web_search == True: @@ -230,7 +233,7 @@ class AppChatService: ) # 获取模型参数 - model_parameters = config.get("model_parameters", {}) + model_parameters = config.model_parameters # 创建 LangChain Agent agent = LangChainAgent( @@ -479,7 +482,9 @@ class AppChatService: self, message: str, conversation_id: uuid.UUID, - config: AgentConfig, + config: WorkflowConfig, + app_id: uuid.UUID, + workspace_id: uuid.UUID, user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, @@ -488,281 +493,159 @@ class AppChatService: user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """聊天(非流式)""" + workflow_service = WorkflowService(self.db) - start_time = time.time() - config_id = None + input_data = {"message":message, "variables": variables, + "conversation_id": str(conversation_id)} + inconfig = workflow_service.get_workflow_config(app_id) - if variables is None: - variables = {} + # 2. 创建执行记录 + execution = workflow_service.create_execution( + workflow_config_id=inconfig.id, + app_id=app_id, + trigger_type="manual", + triggered_by=None, + conversation_id=conversation_id, + input_data=input_data + ) - # 获取模型配置ID - model_config_id = config.default_model_config_id - api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) - # 处理系统提示词(支持变量替换) - system_prompt = config.get("system_prompt", "") - if variables: - system_prompt_rendered = render_prompt_message( - system_prompt, - PromptMessageRole.USER, - variables + # 3. 构建工作流配置字典 + workflow_config_dict = { + "nodes": config.nodes, + "edges": config.edges, + "variables": config.variables, + "execution_config": config.execution_config + } + + # 4. 获取工作空间 ID(从 app 获取) + + # 5. 执行工作流 + from app.core.workflow.executor import execute_workflow + + try: + # 更新状态为运行中 + workflow_service.update_execution_status(execution.execution_id, "running") + + result = await execute_workflow( + workflow_config=workflow_config_dict, + input_data=input_data, + execution_id=execution.execution_id, + workspace_id=str(workspace_id), + user_id=user_id ) - system_prompt = system_prompt_rendered.get_text_content() or system_prompt - # 准备工具列表 - tools = [] - - # 添加知识库检索工具 - knowledge_retrieval = config.get("knowledge_retrieval") - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) - tools.append(kb_tool) - - # 添加长期记忆工具 - memory_flag = False - if memory == True: - memory_config = config.get("memory", {}) - if memory_config.get("enabled") and user_id: - memory_flag = True - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) - - web_tools = config.get("tools") - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search == True: - if web_search_enable == True: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } + # 更新执行结果 + if result.get("status") == "completed": + workflow_service.update_execution_status( + execution.execution_id, + "completed", + output_data=result.get("node_outputs", {}) + ) + else: + workflow_service.update_execution_status( + execution.execution_id, + "failed", + error_message=result.get("error") ) - # 获取模型参数 - model_parameters = config.get("model_parameters", {}) + # 返回增强的响应结构 + return { + "execution_id": execution.execution_id, + "status": result.get("status"), + "output": result.get("output"), # 最终输出(字符串) + "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) + "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID + "error_message": result.get("error"), + "elapsed_time": result.get("elapsed_time"), + "token_usage": result.get("token_usage") + } - # 创建 LangChain Agent - agent = LangChainAgent( - model_name=api_key_obj.model_name, - api_key=api_key_obj.api_key, - provider=api_key_obj.provider, - api_base=api_key_obj.api_base, - temperature=model_parameters.get("temperature", 0.7), - max_tokens=model_parameters.get("max_tokens", 2000), - system_prompt=system_prompt, - tools=tools, - - ) - - # 加载历史消息 - history = [] - memory_config = {"enabled": True, 'max_history': 10} - if memory_config.get("enabled"): - messages = self.conversation_service.get_messages( - conversation_id=conversation_id, - limit=memory_config.get("max_history", 10) + except Exception as e: + logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) + workflow_service.update_execution_status( + execution.execution_id, + "failed", + error_message=str(e) + ) + raise BusinessException( + code=BizCode.INTERNAL_ERROR, + message=f"工作流执行失败: {str(e)}" ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] - - # 调用 Agent - result = await agent.chat( - message=message, - history=history, - context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag - ) - - # 保存消息 - self.conversation_service.save_conversation_messages( - conversation_id=conversation_id, - user_message=message, - assistant_message=result["content"] - ) - - elapsed_time = time.time() - start_time - - return { - "conversation_id": conversation_id, - "message": result["content"], - "usage": result.get("usage", { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - }), - "elapsed_time": elapsed_time - } async def workflow_chat_stream( self, message: str, conversation_id: uuid.UUID, - config: AgentConfig, + config: WorkflowConfig, + app_id: uuid.UUID, + workspace_id: uuid.UUID, user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, + ) -> AsyncGenerator[str, None]: """聊天(流式)""" + workflow_service = WorkflowService(self.db) + input_data = {"message": message, "variables": variables, + "conversation_id": str(conversation_id)} + inconfig = workflow_service.get_workflow_config(app_id) + # 2. 创建执行记录 + execution = workflow_service.create_execution( + workflow_config_id=inconfig.id, + app_id=app_id, + trigger_type="manual", + triggered_by=None, + conversation_id=conversation_id, + input_data=input_data + ) + + # 3. 构建工作流配置字典 + workflow_config_dict = { + "nodes": config.nodes, + "edges": config.edges, + "variables": config.variables, + "execution_config": config.execution_config + } + + # 4. 获取工作空间 ID(从 app 获取) + + # 5. 流式执行工作流 try: - start_time = time.time() - config_id = None + # 更新状态为运行中 + workflow_service.update_execution_status(execution.execution_id, "running") - if variables is None: - variables = {} - # 获取模型配置ID - model_config_id = config.default_model_config_id - api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) - # 处理系统提示词(支持变量替换) - system_prompt = config.get("system_prompt", "") - if variables: - system_prompt_rendered = render_prompt_message( - system_prompt, - PromptMessageRole.USER, - variables - ) - system_prompt = system_prompt_rendered.get_text_content() or system_prompt - - # 准备工具列表 - tools = [] - - # 添加知识库检索工具 - knowledge_retrieval = config.get("knowledge_retrieval") - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) - tools.append(kb_tool) - - # 添加长期记忆工具 - memory_flag = False - if memory: - memory_config = config.get("memory", {}) - if memory_config.get("enabled") and user_id: - memory_flag = True - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) - - web_tools = config.get("tools") - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search == True: - if web_search_enable == True: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 获取模型参数 - model_parameters = config.get("model_parameters", {}) - - # 创建 LangChain Agent - agent = LangChainAgent( - model_name=api_key_obj.model_name, - api_key=api_key_obj.api_key, - provider=api_key_obj.provider, - api_base=api_key_obj.api_base, - temperature=model_parameters.get("temperature", 0.7), - max_tokens=model_parameters.get("max_tokens", 2000), - system_prompt=system_prompt, - tools=tools, - streaming=True - ) - - # 加载历史消息 - history = [] - memory_config = {"enabled": True, 'max_history': 10} - if memory_config.get("enabled"): - messages = self.conversation_service.get_messages( - conversation_id=conversation_id, - limit=memory_config.get("max_history", 10) - ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] - - # 发送开始事件 - yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" - - # 流式调用 Agent - full_content = "" - async for chunk in agent.chat_stream( - message=message, - history=history, - context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag + # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) + async for event in workflow_service._run_workflow_stream( + workflow_config=workflow_config_dict, + input_data=input_data, + execution_id=execution.execution_id, + workspace_id=str(workspace_id), + user_id=user_id ): - full_content += chunk - # 发送消息块事件 - yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" + # 直接转发 executor 的事件(已经是正确的格式) + yield event - elapsed_time = time.time() - start_time - - # 保存消息 - self.conversation_service.add_message( - conversation_id=conversation_id, - role="user", - content=message - ) - - self.conversation_service.add_message( - conversation_id=conversation_id, - role="assistant", - content=full_content, - meta_data={ - "model": api_key_obj.model_name, - "usage": {} - } - ) - - # 发送结束事件 - end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} - yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" - - logger.info( - "流式聊天完成", - extra={ - "conversation_id": str(conversation_id), - "elapsed_time": elapsed_time, - "message_length": len(full_content) - } - ) - - except (GeneratorExit, asyncio.CancelledError): - # 生成器被关闭或任务被取消,正常退出 - logger.debug("流式聊天被中断") - raise except Exception as e: - logger.error(f"流式聊天失败: {str(e)}", exc_info=True) + logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) + workflow_service.update_execution_status( + execution.execution_id, + "failed", + error_message=str(e) + ) # 发送错误事件 - yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" + yield { + "event": "error", + "data": { + "execution_id": execution.execution_id, + "error": str(e) + } + } + # ==================== 依赖注入函数 ==================== def get_app_chat_service( diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 95bcc07a..38097c4e 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -21,6 +21,7 @@ from app.core.exceptions import ( BusinessException, ) from app.core.logging_config import get_business_logger +from app.core.workflow.validator import WorkflowValidator from app.db import get_db from app.models import App, AgentConfig, AppRelease, MultiAgentConfig, WorkflowConfig from app.models.app_model import AppStatus, AppType @@ -31,6 +32,7 @@ from app.schemas.workflow_schema import WorkflowConfigUpdate from app.services.agent_config_converter import AgentConfigConverter from app.models import AppShare, Workspace from app.services.model_service import ModelApiKeyService +from app.services.workflow_service import WorkflowService # 获取业务日志器 logger = get_business_logger() @@ -1225,6 +1227,26 @@ class AppService: "orchestration_mode": multi_agent_cfg.orchestration_mode } ) + elif app.type == AppType.WORKFLOW: + service = WorkflowService(self.db) + workflow_cfg = service.get_workflow_config(app_id) + if not workflow_cfg: + raise BusinessException("应用缺少有效配置,无法发布", BizCode.CONFIG_MISSING) + + config = { + "nodes": workflow_cfg.nodes, + "edges": workflow_cfg.edges, + "variables": workflow_cfg.variables, + "execution_config": workflow_cfg.execution_config, + "triggers": workflow_cfg.triggers + } + + is_valid, errors = WorkflowValidator.validate_for_publish(config) + if not is_valid: + raise BusinessException("应用缺少有效配置,无法发布", BizCode.CONFIG_MISSING) + logger.info( + "应用发布配置准备完成" + ) now = datetime.datetime.now() version = self._get_next_version(app_id) diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index 08ae7e57..fd3ce229 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -1293,6 +1293,7 @@ class MultiAgentOrchestrator: conversation_id: 会话 ID user_id: 用户 ID + Returns: 执行结果 """ diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 482e8213..b3ac1b79 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -231,9 +231,9 @@ class PromptOptimizerService: if m: prompt_index = m.start() prompt_finished = True - yield {"type": "delta", "content": buffer[idx:prompt_index]} + yield {"content": buffer[idx:prompt_index]} else: - yield {"type": "delta", "content": cache[idx:]} + yield {"content": cache[idx:]} if len(cache) != 0: idx = len(cache) @@ -249,8 +249,8 @@ class PromptOptimizerService: role=RoleType.ASSISTANT, content=desc ) - - yield {"type": "done", "desc": optim_result.get("desc")} + variables = self.parser_prompt_variables(optim_result.get("prompt")) + yield {"desc": optim_result.get("desc"), "variables": variables} @staticmethod def parser_prompt_variables(prompt: str): diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 50cca957..ab5128fd 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -344,14 +344,16 @@ class ToolService: break if operation_param: - # 有多个操作 + # 有多个操作,为每个操作生成具体参数 methods = [] for operation in operation_param.enum: + # 获取该操作的具体参数 + operation_params = self._get_operation_specific_params(tool_instance, operation) methods.append({ "method_id": f"{config.name}_{operation}", "name": operation, "description": f"{config.description} - {operation}", - "parameters": [p for p in tool_instance.parameters if p.name != "operation"] + "parameters": operation_params }) return methods else: @@ -362,6 +364,243 @@ class ToolService: "description": config.description, "parameters": [p for p in tool_instance.parameters if p.name != "operation"] }] + + def _get_operation_specific_params(self, tool_instance: BaseTool, operation: str) -> List[Dict[str, Any]]: + """获取特定操作的参数列表""" + # 对于datetime_tool,根据操作类型返回相关参数 + if hasattr(tool_instance, 'name') and tool_instance.name == 'datetime_tool': + return self._get_datetime_tool_params(operation) + # 对于json_tool,根据操作类型返回相关参数 + elif hasattr(tool_instance, 'name') and tool_instance.name == 'json_tool': + return self._get_json_tool_params(operation) + + # 其他工具的默认处理:返回除operation外的所有参数 + return [{ + "name": param.name, + "type": param.type.value, + "description": param.description, + "required": param.required, + "default": param.default, + "enum": param.enum, + "minimum": param.minimum, + "maximum": param.maximum, + "pattern": param.pattern + } for param in tool_instance.parameters if param.name != "operation"] + + def _get_datetime_tool_params(self, operation: str) -> List[Dict[str, Any]]: + """获取datetime_tool特定操作的参数""" + if operation == "now": + return [ + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + } + ] + elif operation == "format": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": True + }, + { + "name": "input_format", + "type": "string", + "description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + } + ] + elif operation == "convert_timezone": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": True + }, + { + "name": "input_format", + "type": "string", + "description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "from_timezone", + "type": "string", + "description": "源时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + } + ] + elif operation == "timestamp_to_datetime": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": True + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + } + ] + else: + # 默认返回所有参数(除了operation) + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": False + }, + { + "name": "input_format", + "type": "string", + "description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "from_timezone", + "type": "string", + "description": "源时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "calculation", + "type": "string", + "description": "时间计算表达式(如:+1d, -2h, +30m)", + "required": False + } + ] + + def _get_json_tool_params(self, operation: str) -> List[Dict[str, Any]]: + """获取json_tool特定操作的参数""" + base_params = [ + { + "name": "input_data", + "type": "string", + "description": "输入数据(JSON字符串、YAML字符串或XML字符串)", + "required": True + } + ] + + if operation == "insert": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + }, + { + "name": "new_value", + "type": "string", + "description": "新值(用于insert操作)", + "required": True + } + ] + elif operation == "replace": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + }, + { + "name": "old_text", + "type": "string", + "description": "要替换的原文本(用于replace操作)", + "required": True + }, + { + "name": "new_text", + "type": "string", + "description": "替换后的新文本(用于replace操作)", + "required": True + } + ] + elif operation == "delete": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + } + ] + elif operation == "parse": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + } + ] + + return base_params async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: """获取自定义工具的方法""" diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index bf0375fb..40851835 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -7,7 +7,6 @@ User Memory Service import os import uuid from collections import Counter -from dataclasses import dataclass from datetime import datetime from typing import Any, Dict, List, Optional, Tuple @@ -22,7 +21,269 @@ from sqlalchemy.orm import Session logger = get_logger(__name__) -# Neo4j connector instan +# Neo4j connector instance for analytics functions +_neo4j_connector = Neo4jConnector() + +# Default LLM ID for fallback +DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus") + + +# ============================================================================ +# Internal Helper Classes +# ============================================================================ + +class TagClassification(BaseModel): + """Represents the classification of a tag into a specific domain.""" + domain: str = Field( + ..., + description="The domain the tag belongs to, chosen from the predefined list.", + examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"], + ) + + +def _get_llm_client_for_user(user_id: str): + """ + Get LLM client for a specific user based on their config. + + Args: + user_id: User ID to get config for + + Returns: + LLM client instance + """ + with get_db_context() as db: + try: + from app.services.memory_agent_service import get_end_user_connected_config + connected_config = get_end_user_connected_config(user_id, db) + config_id = connected_config.get("memory_config_id") + + if config_id: + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config(config_id) + factory = MemoryClientFactory(db) + return factory.get_llm_client(memory_config.llm_model_id) + else: + factory = MemoryClientFactory(db) + return factory.get_llm_client(DEFAULT_LLM_ID) + except Exception as e: + logger.warning(f"Failed to get user connected config, using default LLM: {e}") + factory = MemoryClientFactory(db) + return factory.get_llm_client(DEFAULT_LLM_ID) + + +class MemoryInsightHelper: + """ + Internal helper class for memory insight analysis. + Provides basic data retrieval and analysis functionality. + """ + + def __init__(self, user_id: str): + self.user_id = user_id + self.neo4j_connector = Neo4jConnector() + self.llm_client = _get_llm_client_for_user(user_id) + + async def close(self): + """Close database connection.""" + await self.neo4j_connector.close() + + async def get_domain_distribution(self) -> dict[str, float]: + """Calculate the distribution of memory domains based on hot tags.""" + from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags + + hot_tags = await get_hot_memory_tags(self.user_id) + if not hot_tags: + return {} + + domain_counts = Counter() + for tag, _ in hot_tags: + prompt = f"""请将以下标签归类到最合适的领域中。 + +可选领域及其关键词: +- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等 +- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等 +- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等 +- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等 +- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等 +- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等 +- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等 +- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等 +- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等 +- 其他:确实无法归入以上任何类别的内容 + +标签: {tag} + +分析步骤: +1. 仔细理解标签的核心含义和使用场景 +2. 对比各个领域的关键词,找到最匹配的领域 +3. 特别注意: + - 历史、科学、文化等知识性内容应归类为"学习" + - 学校、课程、考试等正式教育场景应归类为"教育" + - 只有在标签完全不属于上述9个具体领域时,才选择"其他" +4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他" + +请直接返回最合适的领域名称。""" + messages = [ + {"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景,优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"}, + {"role": "user", "content": prompt} + ] + classification = await self.llm_client.response_structured( + messages=messages, + response_model=TagClassification, + ) + if classification and hasattr(classification, 'domain') and classification.domain: + domain_counts[classification.domain] += 1 + + total_tags = sum(domain_counts.values()) + if total_tags == 0: + return {} + + domain_distribution = { + domain: count / total_tags for domain, count in domain_counts.items() + } + return dict(sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True)) + + async def get_active_periods(self) -> list[int]: + """ + Identify the top 2 most active months for the user. + Only returns months if there is valid and diverse time data. + """ + query = """ + MATCH (d:Dialogue) + WHERE d.group_id = $group_id AND d.created_at IS NOT NULL AND d.created_at <> '' + RETURN d.created_at AS creation_time + """ + records = await self.neo4j_connector.execute_query(query, group_id=self.user_id) + + if not records: + return [] + + month_counts = Counter() + valid_dates_count = 0 + for record in records: + creation_time_str = record.get("creation_time") + if not creation_time_str: + continue + try: + dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00")) + month_counts[dt_object.month] += 1 + valid_dates_count += 1 + except (ValueError, TypeError, AttributeError): + continue + + if not month_counts or valid_dates_count == 0: + return [] + + # Check if time distribution is too concentrated (likely batch imported data) + unique_months = len(month_counts) + if unique_months <= 2: + most_common_count = month_counts.most_common(1)[0][1] + if most_common_count / valid_dates_count > 0.8: + return [] + + if unique_months >= 3: + most_common_months = month_counts.most_common(2) + return [month for month, _ in most_common_months] + + if unique_months == 2: + counts = list(month_counts.values()) + ratio = min(counts) / max(counts) + if ratio > 0.3: + most_common_months = month_counts.most_common(2) + return [month for month, _ in most_common_months] + + return [] + + async def get_social_connections(self) -> dict | None: + """Find the user with whom the most memories are shared.""" + query = """ + MATCH (c1:Chunk {group_id: $group_id}) + OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement) + OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk) + WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL + WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements + WHERE common_statements > 0 + RETURN other_user_id, common_statements + ORDER BY common_statements DESC + LIMIT 1 + """ + records = await self.neo4j_connector.execute_query(query, group_id=self.user_id) + if not records or not records[0].get("other_user_id"): + return None + + most_connected_user = records[0]["other_user_id"] + common_memories_count = records[0]["common_statements"] + + time_range_query = """ + MATCH (c:Chunk) + WHERE c.group_id IN [$user_id, $other_user_id] + RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time + """ + time_records = await self.neo4j_connector.execute_query( + time_range_query, + user_id=self.user_id, + other_user_id=most_connected_user + ) + start_year, end_year = "N/A", "N/A" + if time_records and time_records[0]["start_time"]: + start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year + end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year + + return { + "user_id": most_connected_user, + "common_memories_count": common_memories_count, + "time_range": f"{start_year}-{end_year}", + } + + +class UserSummaryHelper: + """ + Internal helper class for user summary generation. + Provides data retrieval functionality for user summary analysis. + """ + + def __init__(self, user_id: str): + self.user_id = user_id + self.connector = Neo4jConnector() + self.llm = _get_llm_client_for_user(user_id) + + async def close(self): + """Close database connection.""" + await self.connector.close() + + async def get_recent_statements(self, limit: int = 80) -> List[Dict[str, Any]]: + """Fetch recent statements authored by the user/group for context.""" + query = ( + "MATCH (s:Statement) " + "WHERE s.group_id = $group_id AND s.statement IS NOT NULL " + "RETURN s.statement AS statement, s.created_at AS created_at " + "ORDER BY created_at DESC LIMIT $limit" + ) + rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit) + records = [] + for r in rows: + try: + records.append({ + "statement": r.get("statement", ""), + "created_at": r.get("created_at") + }) + except Exception: + continue + return records + + async def get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]: + """Get meaningful entities and their frequencies using hot tag logic.""" + from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags + return await get_hot_memory_tags(self.user_id, limit=limit) + + +# ============================================================================ +# Service Class +# ============================================================================ + + +# ============================================================================ +# Service Class +# ============================================================================ class UserMemoryService: @@ -601,7 +862,7 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> 生成记忆洞察报告(四个维度) 这个函数包含完整的业务逻辑: - 1. 使用 MemoryInsight 工具类获取基础数据(领域分布、活跃时段、社交关联) + 1. 使用 MemoryInsightHelper 工具类获取基础数据(领域分布、活跃时段、社交关联) 2. 使用 Jinja2 模板渲染提示词 3. 调用 LLM 生成四个维度的自然语言报告 4. 解析并返回四个部分 @@ -620,7 +881,7 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt import re - insight = MemoryInsight(end_user_id) + insight = MemoryInsightHelper(end_user_id) try: # 1. 并行获取三个维度的数据 @@ -722,7 +983,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, 生成用户摘要(包含四个部分) 这个函数包含完整的业务逻辑: - 1. 使用 UserSummary 工具类获取基础数据(实体、语句) + 1. 使用 UserSummaryHelper 工具类获取基础数据(实体、语句) 2. 使用 prompt_utils 渲染提示词 3. 调用 LLM 生成四部分内容:基本介绍、性格特点、核心价值观、一句话总结 @@ -737,20 +998,19 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, "one_sentence": str } """ - from app.core.memory.analytics.user_summary import UserSummary from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt import re - # 创建 UserSummary 实例 - user_summary_tool = UserSummary(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123")) + # 创建 UserSummaryHelper 实例 + user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123")) try: # 1) 收集上下文数据 - entities = await user_summary_tool._get_top_entities(limit=40) - statements = await user_summary_tool._get_recent_statements(limit=100) + entities = await user_summary_tool.get_top_entities(limit=40) + statements = await user_summary_tool.get_recent_statements(limit=100) entity_lines = [f"{name} ({freq})" for name, freq in entities][:20] - statement_samples = [s.statement.strip() for s in statements if (s.statement or '').strip()][:20] + statement_samples = [s["statement"].strip() for s in statements if s.get("statement", "").strip()][:20] # 2) 使用 prompt_utils 渲染提示词 user_prompt = await render_user_summary_prompt( @@ -794,6 +1054,28 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, core_values = core_values_match.group(1).strip() if core_values_match else "" one_sentence = one_sentence_match.group(1).strip() if one_sentence_match else "" + # 6) 清理可能包含的反思内容(防御性编程) + # 如果 LLM 仍然输出了反思内容,在这里过滤掉 + def clean_reflection_content(text: str) -> str: + """移除可能包含的反思内容""" + if not text: + return text + # 移除 "---" 之后的所有内容(通常是反思部分的开始) + if '---' in text: + text = text.split('---')[0].strip() + # 移除 "**Step" 开头的内容 + if '**Step' in text: + text = text.split('**Step')[0].strip() + # 移除 "Self-Review" 相关内容 + if 'Self-Review' in text or 'self-review' in text: + text = re.sub(r'[\-\*]*\s*Self-Review.*$', '', text, flags=re.IGNORECASE | re.DOTALL).strip() + return text + + user_summary = clean_reflection_content(user_summary) + personality = clean_reflection_content(personality) + core_values = clean_reflection_content(core_values) + one_sentence = clean_reflection_content(one_sentence) + return { "user_summary": user_summary, "personality": personality, diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index d96efdf7..68d6279b 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -17,6 +17,7 @@ from app.core.workflow.validator import validate_workflow_config from app.db import get_db, get_db_context from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.repositories.end_user_repository import EndUserRepository +from app.services.multi_agent_service import convert_uuids_to_str from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, @@ -364,7 +365,7 @@ class WorkflowService: execution.status = status if output_data is not None: - execution.output_data = output_data + execution.output_data = convert_uuids_to_str(output_data) if error_message is not None: execution.error_message = error_message if error_node_id is not None: diff --git a/api/app/utils/app_config_utils.py b/api/app/utils/app_config_utils.py index 4fe692c1..97e64214 100644 --- a/api/app/utils/app_config_utils.py +++ b/api/app/utils/app_config_utils.py @@ -8,7 +8,7 @@ import uuid from typing import Dict, Any, Optional from datetime import datetime -from app.models import AppRelease +from app.models import AppRelease, WorkflowConfig from app.models.agent_app_config_model import AgentConfig from app.models.multi_agent_model import MultiAgentConfig @@ -28,7 +28,7 @@ class AgentConfigProxy: def agent_config_4_app_release(release: AppRelease ) -> AgentConfig: config_dict = release.config - + agent_config = AgentConfig( app_id=release.app_id, system_prompt=config_dict.get("system_prompt"), @@ -45,10 +45,10 @@ def agent_config_4_app_release(release: AppRelease ) -> AgentConfig: def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig: config_dict = release.config - + agent_config = MultiAgentConfig( - app_id=release.app_id, + app_id=release.app_id, default_model_config_id=release.default_model_config_id, model_parameters=config_dict.get("model_parameters"), master_agent_id=config_dict.get("master_agent_id"), @@ -58,11 +58,29 @@ def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig: routing_rules=config_dict.get("routing_rules"), execution_config=config_dict.get("execution_config", {}), aggregation_strategy=config_dict.get("aggregation_strategy", "merge"), - + ) return agent_config +def workflow_config_4_app_release(release: AppRelease ) -> WorkflowConfig: + + config_dict = release.config + + + config = WorkflowConfig( + id=release.id, + app_id=release.app_id, + nodes=config_dict.get("nodes", []), + edges=config_dict.get("edges", []), + variables=config_dict.get("variables", []), + execution_config=config_dict.get("execution_config", {}), + triggers=config_dict.get("triggers", []) + + ) + + return config + def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uuid.UUID] = None): """Convert dict to MultiAgentConfig model object diff --git a/web/src/api/prompt.ts b/web/src/api/prompt.ts index 77ea1271..526f50ac 100644 --- a/web/src/api/prompt.ts +++ b/web/src/api/prompt.ts @@ -1,5 +1,6 @@ import { request } from '@/utils/request' import type { AiPromptForm } from '@/views/ApplicationConfig/types' +import { handleSSE, type SSEMessage } from '@/utils/stream' export const createPromptSessions = () => { return request.post(`/prompt/sessions`) @@ -7,6 +8,6 @@ export const createPromptSessions = () => { export const getPrompt = (session_id: string) => { return request.get(`/prompt/sessions/${session_id}`) } -export const updatePromptMessages = (session_id: string, data: AiPromptForm) => { - return request.post(`/prompt/sessions/${session_id}/messages`, data) +export const updatePromptMessages = (session_id: string, data: AiPromptForm, onMessage?: (data: SSEMessage[]) => void) => { + return handleSSE(`/prompt/sessions/${session_id}/messages`, data, onMessage) } \ No newline at end of file diff --git a/web/src/assets/images/workflow/memory-read.png b/web/src/assets/images/workflow/memory-read.png new file mode 100644 index 00000000..4b0cdc1d Binary files /dev/null and b/web/src/assets/images/workflow/memory-read.png differ diff --git a/web/src/assets/images/workflow/memory-write.png b/web/src/assets/images/workflow/memory-write.png new file mode 100644 index 00000000..83a50fd4 Binary files /dev/null and b/web/src/assets/images/workflow/memory-write.png differ diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index b53ff2bc..a96d986c 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1224,6 +1224,8 @@ export const en = { key_findings: 'Key Findings', behavior_pattern: 'Behavior Pattern', growth_trajectory: 'Growth Trajectory', + personality: 'Personality Traits', + core_values: 'Core Values', }, space: { createSpace: 'Create Space', @@ -1799,12 +1801,20 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re "not_contains": 'Does Not Contain', "startwith": 'Starts With', "endwith": 'Ends With', - "eq": '==', - "ne": '!=', - "lt": '<', - "le": '<=', - "gt": '>', - "ge": '>=', + "eq": 'Equals', + "ne": 'Not Equals', + num: { + "eq": '=', + "ne": '≠', + "lt": '<', + "le": '≤', + "gt": '>', + "ge": '≥', + }, + boolean: { + "eq": 'Is', + "ne": 'Is Not', + }, else_desc: 'Used to define the logic that should be executed when the if condition is not met.' }, 'http-request': { @@ -1845,12 +1855,17 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re loop: { cycle_vars: 'Loop Variables', condition: 'Loop Termination Condition', + max_loop: 'Maximum Loop Count', }, assigner: { assignments: 'Variables', - cover: 'Overwrite', + cover: 'Override', assign: 'Set', - clear: 'Clear' + clear: 'Clear', + add: '+=', + subtract: '-=', + multiply: '*=', + divide: '/=', }, iteration: { input: 'Input Variable', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 141d4600..2f38cf8e 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1305,6 +1305,8 @@ export const zh = { key_findings: '关键发现', behavior_pattern: '行为模式', growth_trajectory: '成长轨迹', + personality: '性格特点', + core_values: '核心价值观', }, space: { createSpace: '创建空间', @@ -1899,12 +1901,20 @@ export const zh = { "not_contains": '不包含', "startwith": '开始是', "endwith": '结束是', - "eq": '==', - "ne": '!=', - "lt": '<', - "le": '<=', - "gt": '>', - "ge": '>=', + "eq": '是', + "ne": '不是', + num: { + "eq": '=', + "ne": '≠', + "lt": '<', + "le": '≤', + "gt": '>', + "ge": '≥', + }, + boolean: { + "eq": '是', + "ne": '不是', + }, else_desc: '用于定义当 if 条件不满足时应执行的逻辑。' }, 'http-request': { @@ -1945,12 +1955,17 @@ export const zh = { loop: { cycle_vars: '循环变量', condition: '循环终止条件', + max_loop: '最大循环次数', }, assigner: { assignments: '变量', cover: '覆盖', assign: '设置', - clear: '清空' + clear: '清空', + add: '+=', + subtract: '-=', + multiply: '*=', + divide: '/=', }, iteration: { input: '输入变量', diff --git a/web/src/views/ApplicationConfig/components/AiPromptModal.tsx b/web/src/views/ApplicationConfig/components/AiPromptModal.tsx index a85f5cf1..f52c0675 100644 --- a/web/src/views/ApplicationConfig/components/AiPromptModal.tsx +++ b/web/src/views/ApplicationConfig/components/AiPromptModal.tsx @@ -16,6 +16,8 @@ import ConversationEmptyIcon from '@/assets/images/conversation/conversationEmpt import type { ChatItem } from '@/components/Chat/types' import CustomSelect from '@/components/CustomSelect' import AiPromptVariableModal from './AiPromptVariableModal' +import { type SSEMessage } from '@/utils/stream' +import Editor from './Editor' interface AiPromptModalProps { refresh: (value: string) => void; @@ -35,7 +37,8 @@ const AiPromptModal = forwardRef(({ const [variables, setVariables] = useState([]) const [promptSession, setPromptSession] = useState(null) const aiPromptVariableModalRef = useRef(null) - const currentPromptRef = useRef(null) + const editorRef = useRef(null) + const currentPromptValueRef = useRef('') const values = Form.useWatch([], form) @@ -78,16 +81,45 @@ const AiPromptModal = forwardRef(({ setChatList(prev => { return [...prev, { role: 'user', content: messageContent}] }) - form.setFieldsValue({ message: undefined }) - updatePromptMessages(promptSession, values) - .then(res => { - const response = res as { prompt: string; desc: string; variables: string[] } - form.setFieldsValue({ current_prompt: response.prompt }) - setChatList(prev => { - return [...prev, { role: 'assistant', content: response.desc }] - }) - setVariables(response.variables) + form.setFieldsValue({ message: undefined, current_prompt: undefined }) + + const handleStreamMessage = (data: SSEMessage[]) => { + data.map(item => { + const { content, desc, variables } = item.data as { content: string; desc: string; variables: string[] }; + + switch (item.event) { + case 'start': + currentPromptValueRef.current = '' + break; + case 'message': + if (content) { + currentPromptValueRef.current += content; + form.setFieldsValue({ current_prompt: currentPromptValueRef.current }) + } + if (desc) { + setChatList(prev => { + return [...prev, { role: 'assistant', content: desc }] + }) + } + if (variables) { + setVariables(variables) + } + break; + case 'end': + setLoading(false) + break + } }) + }; + updatePromptMessages(promptSession, values, handleStreamMessage) + // .then(res => { + // const response = res as { prompt: string; desc: string; variables: string[] } + // form.setFieldsValue({ current_prompt: response.prompt }) + // setChatList(prev => { + // return [...prev, { role: 'assistant', content: response.desc }] + // }) + // setVariables(response.variables) + // }) .finally(() => { setLoading(false) }) @@ -101,18 +133,8 @@ const AiPromptModal = forwardRef(({ aiPromptVariableModalRef.current?.handleOpen() } const handleVariableApply = (value: string) => { - const textArea = currentPromptRef.current?.resizableTextArea?.textArea - if (textArea) { - const cursorPosition = textArea.selectionStart - const currentValue = values.current_prompt || '' - const newValue = currentValue.slice(0, cursorPosition) + value + currentValue.slice(cursorPosition) - form.setFieldValue('current_prompt', newValue) - - // 设置新的光标位置 - setTimeout(() => { - textArea.focus() - textArea.setSelectionRange(cursorPosition + value.length, cursorPosition + value.length) - }, 0) + if (editorRef.current?.insertText) { + editorRef.current.insertText(value) } else { form.setFieldValue('current_prompt', (values.current_prompt || '') + value) } @@ -191,7 +213,11 @@ const AiPromptModal = forwardRef(({ - + form.setFieldValue('current_prompt', value)} + />
diff --git a/web/src/views/ApplicationConfig/components/Editor/index.tsx b/web/src/views/ApplicationConfig/components/Editor/index.tsx new file mode 100644 index 00000000..d381e003 --- /dev/null +++ b/web/src/views/ApplicationConfig/components/Editor/index.tsx @@ -0,0 +1,91 @@ +import {forwardRef, useImperativeHandle } from 'react'; +import clsx from 'clsx'; +import { LexicalComposer } from '@lexical/react/LexicalComposer'; +import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'; +import { ContentEditable } from '@lexical/react/LexicalContentEditable'; +import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'; +import { $getSelection } from 'lexical'; +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; +import InitialValuePlugin from './plugin/InitialValuePlugin' +import LineBreakPlugin from './plugin/LineBreakPlugin'; +import InsertTextPlugin from './plugin/InsertTextPlugin'; + +export interface EditorRef { + insertText: (text: string) => void; +} + +interface LexicalEditorProps { + className?: string; + placeholder?: string; + value?: string; + onChange?: (value: string) => void; + height?: number; +} + +const theme = { + paragraph: 'editor-paragraph', + text: { + bold: 'editor-text-bold', + italic: 'editor-text-italic', + }, +}; + +const EditorContent = forwardRef(({ + className = '', + value, + placeholder = "请输入内容...", + onChange, +}, ref) => { + const [editor] = useLexicalComposerContext(); + + useImperativeHandle(ref, () => ({ + insertText: (text: string) => { + editor.update(() => { + const selection = $getSelection(); + if (selection) { + selection.insertText(text); + } + }); + } + }), [editor]); + + return ( +
+ + } + placeholder={ +
+ {placeholder} +
+ } + ErrorBoundary={LexicalErrorBoundary} + /> + + + +
+ ); +}); + +const Editor = forwardRef((props, ref) => { + const initialConfig = { + namespace: 'Editor', + theme, + nodes: [], + onError: (error: Error) => { + console.error(error); + }, + }; + + return ( + + + + ); +}); + +export default Editor; \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/Editor/plugin/InitialValuePlugin.tsx b/web/src/views/ApplicationConfig/components/Editor/plugin/InitialValuePlugin.tsx new file mode 100644 index 00000000..b1054055 --- /dev/null +++ b/web/src/views/ApplicationConfig/components/Editor/plugin/InitialValuePlugin.tsx @@ -0,0 +1,25 @@ +import { type FC, useEffect } from 'react'; +import { $getRoot, $createParagraphNode, $createTextNode } from 'lexical'; +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; + +// 设置初始值的插件 +const InitialValuePlugin: FC<{ value?: string }> = ({ value }) => { + const [editor] = useLexicalComposerContext(); + + useEffect(() => { + if (value) { + editor.update(() => { + const root = $getRoot(); + root.clear(); + const paragraph = $createParagraphNode(); + const textNode = $createTextNode(value); + paragraph.append(textNode); + root.append(paragraph); + }); + } + }, [editor, value]); + + return null; +}; + +export default InitialValuePlugin \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/Editor/plugin/InsertTextPlugin.tsx b/web/src/views/ApplicationConfig/components/Editor/plugin/InsertTextPlugin.tsx new file mode 100644 index 00000000..ca75c393 --- /dev/null +++ b/web/src/views/ApplicationConfig/components/Editor/plugin/InsertTextPlugin.tsx @@ -0,0 +1,24 @@ +import { forwardRef, useImperativeHandle } from 'react'; +import { $getSelection } from 'lexical'; +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; +import type { EditorRef } from '../index' + +// 插入文本的插件 +const InsertTextPlugin = forwardRef((_, ref) => { + const [editor] = useLexicalComposerContext(); + + useImperativeHandle(ref, () => ({ + insertText: (text: string) => { + editor.update(() => { + const selection = $getSelection(); + if (selection) { + selection.insertText(text); + } + }); + } + }), [editor]); + + return null; +}); + +export default InsertTextPlugin; \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/Editor/plugin/LineBreakPlugin.tsx b/web/src/views/ApplicationConfig/components/Editor/plugin/LineBreakPlugin.tsx new file mode 100644 index 00000000..63d1ffc4 --- /dev/null +++ b/web/src/views/ApplicationConfig/components/Editor/plugin/LineBreakPlugin.tsx @@ -0,0 +1,24 @@ +import { type FC, useEffect } from 'react'; +import { $getRoot } from 'lexical'; +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; + +// 处理换行的插件 +const LineBreakPlugin: FC<{ onChange?: (value: string) => void }> = ({ onChange }) => { + const [editor] = useLexicalComposerContext(); + + useEffect(() => { + return editor.registerUpdateListener(({ editorState }) => { + editorState.read(() => { + const root = $getRoot(); + const textContent = root.getTextContent(); + // 将\n转换为实际换行 + const processedContent = textContent.replace(/\\n/g, '\n'); + onChange?.(processedContent); + }); + }); + }, [editor, onChange]); + + return null; +}; + +export default LineBreakPlugin; \ No newline at end of file diff --git a/web/src/views/ToolManagement/Inner.tsx b/web/src/views/ToolManagement/Inner.tsx index d256d6c7..6f85e1f7 100644 --- a/web/src/views/ToolManagement/Inner.tsx +++ b/web/src/views/ToolManagement/Inner.tsx @@ -4,10 +4,9 @@ import { Col, Tag, List, - Space + Flex } from 'antd'; import { EyeOutlined } from '@ant-design/icons'; -import clsx from 'clsx' import { useTranslation } from 'react-i18next'; import dayjs, { type Dayjs } from 'dayjs' @@ -103,9 +102,9 @@ const Inner: React.FC<{ getStatusTag: (status: string) => ReactNode }> = ({ getS
{t(`tool.${item.config_data.tool_class}_features`)}
- + {InnerConfigData[item.config_data.tool_class].features.map(vo => { t(`tool.${vo}`) }) } - + {item.config_data.tool_class === 'DateTimeTool' ?
diff --git a/web/src/views/UserMemoryDetail/components/AboutMe.tsx b/web/src/views/UserMemoryDetail/components/AboutMe.tsx index ba7e68fe..f2c94814 100644 --- a/web/src/views/UserMemoryDetail/components/AboutMe.tsx +++ b/web/src/views/UserMemoryDetail/components/AboutMe.tsx @@ -5,16 +5,25 @@ import { Skeleton } from 'antd'; import RbCard from '@/components/RbCard/Card' import Empty from '@/components/Empty'; +import RbAlert from '@/components/RbAlert'; import { getUserSummary, } from '@/api/memory' import type { AboutMeRef } from '../types' + +interface Data { + user_summary: string; + personality: string; + core_values: string; + one_sentence: string; + [key: string]: string; +} const AboutMe = forwardRef((_props, ref) => { const { t } = useTranslation() const { id } = useParams() const [loading, setLoading] = useState(false) - const [data, setData] = useState(null) + const [data, setData] = useState({} as Data) useEffect(() => { if (!id) return @@ -27,7 +36,7 @@ const AboutMe = forwardRef((_props, ref) => { setLoading(true) getUserSummary(id) .then((res) => { - setData((res as { summary?: string }).summary || null) + setData((res as Data) || null) }) .finally(() => { setLoading(false) @@ -44,10 +53,29 @@ const AboutMe = forwardRef((_props, ref) => { > {loading ? - : data - ?
- {data || '-'} -
+ : Object.keys(data).filter(key => data[key] !== null).length > 0 + ? <> + {data.user_summary && +
+ {data.user_summary} +
+ } + {data.personality && <> +
{t('userMemory.personality')}
+
+ {data.personality} +
+ } + {data.core_values && <> +
{t('userMemory.core_values')}
+
+ {data.core_values} +
+ } + {data.one_sentence && + {data.one_sentence} + } + : } diff --git a/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx b/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx index 571f1e4e..fabe45ba 100644 --- a/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx +++ b/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx @@ -26,7 +26,6 @@ const ChatVariableModal = forwardRef(); const [loading, setLoading] = useState(false) const [editIndex, setEditIndex] = useState(undefined) - const typeValue = Form.useWatch('type', form); // 封装取消方法,添加关闭弹窗逻辑 const handleClose = () => { diff --git a/web/src/views/Workflow/components/Editor/plugin/CharacterCountPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/CharacterCountPlugin.tsx index 963f824b..ed07392d 100644 --- a/web/src/views/Workflow/components/Editor/plugin/CharacterCountPlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/CharacterCountPlugin.tsx @@ -14,18 +14,23 @@ const CharacterCountPlugin = ({ setCount, onChange }: { setCount: (count: number let serializedContent = ''; // Traverse all nodes and serialize properly + const paragraphs: string[] = []; root.getChildren().forEach(child => { if ($isParagraphNode(child)) { + let paragraphContent = ''; child.getChildren().forEach(node => { if ($isVariableNode(node)) { - serializedContent += node.getTextContent(); + paragraphContent += node.getTextContent(); } else { - serializedContent += node.getTextContent(); + paragraphContent += node.getTextContent(); } }); + paragraphs.push(paragraphContent); } }); + serializedContent = paragraphs.join('\n'); + setCount(serializedContent.length); onChange?.(serializedContent); }); diff --git a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx index 4059b300..93197150 100644 --- a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx @@ -26,6 +26,7 @@ const InitialValuePlugin: React.FC = ({ value, options parts.forEach(part => { const match = part.match(/^\{\{([^.]+)\.([^}]+)\}\}$/); const contextMatch = part.match(/^\{\{context\}\}$/); + const conversationMatch = part.match(/^\{\{conv\.([^}]+)\}\}$/); // 匹配{{context}}格式 if (contextMatch) { @@ -38,6 +39,20 @@ const InitialValuePlugin: React.FC = ({ value, options return } + // 匹配{{conv.xx}}格式 + if (conversationMatch) { + const [_, variableName] = conversationMatch; + const conversationSuggestion = options.find(s => + s.group === 'CONVERSATION' && s.label === variableName + ); + if (conversationSuggestion) { + paragraph.append($createVariableNode(conversationSuggestion)); + } else { + paragraph.append($createTextNode(part)); + } + return + } + // 匹配普通变量{{nodeId.label}}格式 if (match) { const [_, nodeId, label] = match; diff --git a/web/src/views/Workflow/components/Nodes/AddNode.tsx b/web/src/views/Workflow/components/Nodes/AddNode.tsx index a2f6d930..973a503c 100644 --- a/web/src/views/Workflow/components/Nodes/AddNode.tsx +++ b/web/src/views/Workflow/components/Nodes/AddNode.tsx @@ -13,13 +13,15 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => { const handleNodeSelect = (selectedNodeType: any) => { const parentBBox = node.getBBox(); const cycleId = data.cycle; - + + const id = `${selectedNodeType.type.replace(/-/g, '_') }_${Date.now()}_${Math.random().toString(36).substr(2, 9)}` const newNode = graph.addNode({ ...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default), x: parentBBox.x, y: parentBBox.y, + id, data: { - id: `${selectedNodeType.type}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`, + id, type: selectedNodeType.type, icon: selectedNodeType.icon, name: t(`workflow.${selectedNodeType.type}`), diff --git a/web/src/views/Workflow/components/Nodes/LoopNode.tsx b/web/src/views/Workflow/components/Nodes/LoopNode.tsx index b0b8d4ce..37feb2dc 100644 --- a/web/src/views/Workflow/components/Nodes/LoopNode.tsx +++ b/web/src/views/Workflow/components/Nodes/LoopNode.tsx @@ -75,12 +75,15 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => { const parentBBox = node.getBBox(); const centerX = parentBBox.x + 24; // 默认节点宽度的一半 const centerY = parentBBox.y + 50; // 默认节点高度的一半 - + + const cycleStartNodeId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}` const cycleStartNode = graph.addNode({ ...graphNodeLibrary.cycleStart, x: centerX, y: centerY, + id: cycleStartNodeId, data: { + id: cycleStartNodeId, type: 'cycle-start', parentId: node.id, isDefault: true, // 标记为默认节点,不可删除 diff --git a/web/src/views/Workflow/components/PortClickHandler.tsx b/web/src/views/Workflow/components/PortClickHandler.tsx index 0be6fba1..9a644438 100644 --- a/web/src/views/Workflow/components/PortClickHandler.tsx +++ b/web/src/views/Workflow/components/PortClickHandler.tsx @@ -43,12 +43,14 @@ const PortClickHandler: React.FC = ({ graph }) => { const newY = sourceBBox.y; // 创建新节点 + const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}` const newNode = graph.addNode({ ...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default), x: newX, y: newY, + id, data: { - id: `${selectedNodeType.type}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`, + id, type: selectedNodeType.type, icon: selectedNodeType.icon, name: t(`workflow.${selectedNodeType.type}`), diff --git a/web/src/views/Workflow/components/Properties/AssignmentList/index.tsx b/web/src/views/Workflow/components/Properties/AssignmentList/index.tsx index 34c133c7..eac3775f 100644 --- a/web/src/views/Workflow/components/Properties/AssignmentList/index.tsx +++ b/web/src/views/Workflow/components/Properties/AssignmentList/index.tsx @@ -1,6 +1,6 @@ import { type FC } from 'react' import { useTranslation } from 'react-i18next'; -import { Form, Input, Button, Row, Col, Select } from 'antd' +import { Form, Input, Row, Col, Select, InputNumber, Radio } from 'antd' import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons'; import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' import VariableSelect from '../VariableSelect' @@ -11,6 +11,23 @@ interface AssignmentListProps { options: Suggestion[]; } +const operationsObj = { + number: [ + { value: 'cover', label: 'workflow.config.assigner.cover' }, + { value: 'clear', label: 'workflow.config.assigner.clear' }, + { value: 'assign', label: 'workflow.config.assigner.assign' }, + { value: 'add', label: 'workflow.config.assigner.add' }, + { value: 'subtract', label: 'workflow.config.assigner.subtract' }, + { value: 'multiply', label: 'workflow.config.assigner.multiply' }, + { value: 'divide', label: 'workflow.config.assigner.divide' }, + ], + default: [ + { value: 'cover', label: 'workflow.config.assigner.cover' }, + { value: 'clear', label: 'workflow.config.assigner.clear' }, + { value: 'assign', label: 'workflow.config.assigner.assign' }, + ], +} + const AssignmentList: FC = ({ parentName, options = [], @@ -27,6 +44,11 @@ const AssignmentList: FC = ({ add({ operation: 'cover'})} />
{fields.map(({ key, name, ...restField }) => { + const variableSelector = form.getFieldValue([parentName, name, 'variable_selector']); + const selectedOption = options.find(option => `{{${option.value}}}` === variableSelector); + const dataType = selectedOption?.dataType; + const operationOptions = dataType === 'number' ? operationsObj.number : operationsObj.default; + return (
@@ -50,11 +72,10 @@ const AssignmentList: FC = ({ noStyle > ({ - value: key, - label: t(`workflow.config.if-else.${key}`) + options={operatorList.map(vo => ({ + ...vo, + label: t(String(vo?.label || '')) }))} size="small" popupMatchSelectWidth={false} + placeholder={t('common.pleaseSelect')} /> @@ -280,11 +321,48 @@ const CaseList: FC = ({ - {!hideRightField && ( - - - - )} + {!hideRightField && <> + {leftFieldType === 'number' + ? + + + ({ - value: key, - label: t(`workflow.config.if-else.${key}`) + options={operatorList.map(vo => ({ + ...vo, + label: t(String(vo?.label || '')) }))} size="small" popupMatchSelectWidth={false} @@ -104,14 +139,53 @@ const ConditionList: FC = ({ onClick={() => remove(field.name)} /> - - {!hideRightField && ( - - - - - - )} + + {!hideRightField && <> + {leftFieldType === 'number' + ? + + + { - console.log('value record', value) - handleChange(record.key, 'type', value) - }} - /> - ), - }, - { - title: t('workflow.config.value'), + const columns = useMemo(() => { + const baseColumns = [ + { + title: typeOptions.length > 0 ? t('workflow.config.name') : '键', + dataIndex: 'name', + width: typeOptions.length > 0 ? '35%' : '45%', + render: (text: string, record: TableRow) => ( + handleChange(record.key, 'name', value || '')} + /> + ), + } + ]; + + if (typeOptions.length > 0) { + baseColumns.push({ + title: t('workflow.config.type'), + dataIndex: 'type', + width: '20%', + render: (text: string, record: TableRow) => ( + - - - - - remove(name)} /> - + + + + + + + + + + + + + remove(name)} /> + + ))}