Merge branch 'refs/heads/develop' into fix/memory_bug_fix
# Conflicts: # api/app/services/user_memory_service.py
This commit is contained in:
@@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"])
|
|||||||
@router.post("/{app_id}/workflow")
|
@router.post("/{app_id}/workflow")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def create_workflow_config(
|
async def create_workflow_config(
|
||||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
config: WorkflowConfigCreate,
|
config: WorkflowConfigCreate,
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
):
|
):
|
||||||
"""创建工作流配置
|
"""创建工作流配置
|
||||||
|
|
||||||
@@ -96,6 +96,7 @@ async def create_workflow_config(
|
|||||||
msg=f"创建工作流配置失败: {str(e)}"
|
msg=f"创建工作流配置失败: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# @router.get("/{app_id}/workflow")
|
# @router.get("/{app_id}/workflow")
|
||||||
# async def get_workflow_config(
|
# async def get_workflow_config(
|
||||||
@@ -199,10 +200,10 @@ async def create_workflow_config(
|
|||||||
|
|
||||||
@router.delete("/{app_id}/workflow")
|
@router.delete("/{app_id}/workflow")
|
||||||
async def delete_workflow_config(
|
async def delete_workflow_config(
|
||||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
):
|
):
|
||||||
"""删除工作流配置
|
"""删除工作流配置
|
||||||
|
|
||||||
@@ -243,11 +244,11 @@ async def delete_workflow_config(
|
|||||||
|
|
||||||
@router.post("/{app_id}/workflow/validate")
|
@router.post("/{app_id}/workflow/validate")
|
||||||
async def validate_workflow_config(
|
async def validate_workflow_config(
|
||||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||||
):
|
):
|
||||||
"""验证工作流配置
|
"""验证工作流配置
|
||||||
|
|
||||||
@@ -312,12 +313,12 @@ async def validate_workflow_config(
|
|||||||
|
|
||||||
@router.get("/{app_id}/workflow/executions")
|
@router.get("/{app_id}/workflow/executions")
|
||||||
async def get_workflow_executions(
|
async def get_workflow_executions(
|
||||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||||
offset: Annotated[int, Query(ge=0)] = 0
|
offset: Annotated[int, Query(ge=0)] = 0
|
||||||
):
|
):
|
||||||
"""获取工作流执行记录列表
|
"""获取工作流执行记录列表
|
||||||
|
|
||||||
@@ -365,10 +366,10 @@ async def get_workflow_executions(
|
|||||||
|
|
||||||
@router.get("/workflow/executions/{execution_id}")
|
@router.get("/workflow/executions/{execution_id}")
|
||||||
async def get_workflow_execution(
|
async def get_workflow_execution(
|
||||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
):
|
):
|
||||||
"""获取工作流执行详情
|
"""获取工作流执行详情
|
||||||
|
|
||||||
@@ -417,16 +418,14 @@ async def get_workflow_execution(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 工作流执行 ====================
|
# ==================== 工作流执行 ====================
|
||||||
|
|
||||||
@router.post("/{app_id}/workflow/run")
|
@router.post("/{app_id}/workflow/run")
|
||||||
async def run_workflow(
|
async def run_workflow(
|
||||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
request: WorkflowExecutionRequest,
|
request: WorkflowExecutionRequest,
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
):
|
):
|
||||||
"""执行工作流
|
"""执行工作流
|
||||||
|
|
||||||
@@ -487,11 +486,11 @@ async def run_workflow(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async for event in await service.run_workflow(
|
async for event in await service.run_workflow(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
triggered_by=current_user.id,
|
triggered_by=current_user.id,
|
||||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||||
stream=True
|
stream=True
|
||||||
):
|
):
|
||||||
# 提取事件类型和数据
|
# 提取事件类型和数据
|
||||||
event_type = event.get("event", "message")
|
event_type = event.get("event", "message")
|
||||||
@@ -554,10 +553,10 @@ async def run_workflow(
|
|||||||
|
|
||||||
@router.post("/workflow/executions/{execution_id}/cancel")
|
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||||
async def cancel_workflow_execution(
|
async def cancel_workflow_execution(
|
||||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||||
db: Annotated[Session, Depends(get_db)],
|
db: Annotated[Session, Depends(get_db)],
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
):
|
):
|
||||||
"""取消工作流执行
|
"""取消工作流执行
|
||||||
|
|
||||||
@@ -602,7 +601,7 @@ async def cancel_workflow_execution(
|
|||||||
|
|
||||||
except BusinessException as e:
|
except BusinessException as e:
|
||||||
logger.warning(f"取消工作流执行失败: {e.message}")
|
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||||
return fail(code=e.error_code, msg=e.message)
|
return fail(code=e.code, msg=e.message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||||
return fail(
|
return fail(
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from dotenv import load_dotenv
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
class Settings:
|
class Settings:
|
||||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||||
# API Keys Configuration
|
# API Keys Configuration
|
||||||
@@ -142,7 +143,6 @@ class Settings:
|
|||||||
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
|
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
|
||||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||||
|
|
||||||
|
|
||||||
# Celery configuration (internal)
|
# Celery configuration (internal)
|
||||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||||
@@ -150,7 +150,7 @@ class Settings:
|
|||||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||||
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||||
|
|
||||||
# Memory Cache Regeneration Configuration
|
# Memory Cache Regeneration Configuration
|
||||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||||
@@ -168,6 +168,9 @@ class Settings:
|
|||||||
# official environment system version
|
# official environment system version
|
||||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
||||||
|
|
||||||
|
# workflow config
|
||||||
|
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||||
|
|
||||||
def get_memory_output_path(self, filename: str = "") -> str:
|
def get_memory_output_path(self, filename: str = "") -> str:
|
||||||
"""
|
"""
|
||||||
Get the full path for memory module output files.
|
Get the full path for memory module output files.
|
||||||
|
|||||||
@@ -425,15 +425,9 @@ async def Input_Summary(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Extract services from context
|
# Extract services from context
|
||||||
template_service = get_context_resource(ctx, "template_service")
|
|
||||||
session_service = get_context_resource(ctx, "session_service")
|
session_service = get_context_resource(ctx, "session_service")
|
||||||
search_service = get_context_resource(ctx, "search_service")
|
search_service = get_context_resource(ctx, "search_service")
|
||||||
|
|
||||||
# Get LLM client from memory_config
|
|
||||||
with get_db_context() as db:
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
|
||||||
|
|
||||||
# Resolve session ID
|
# Resolve session ID
|
||||||
sessionid = Resolve_username(usermessages) or ""
|
sessionid = Resolve_username(usermessages) or ""
|
||||||
sessionid = sessionid.replace('call_id_', '')
|
sessionid = sessionid.replace('call_id_', '')
|
||||||
@@ -539,31 +533,11 @@ async def Input_Summary(
|
|||||||
)
|
)
|
||||||
retrieve_info, question, raw_results = "", query, []
|
retrieve_info, question, raw_results = "", query, []
|
||||||
|
|
||||||
|
# Return retrieved information directly without LLM processing
|
||||||
|
# Use the raw retrieved info as the answer
|
||||||
|
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
|
||||||
|
|
||||||
# Render template
|
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
|
||||||
system_prompt = await template_service.render_template(
|
|
||||||
template_name='Retrieve_Summary_prompt.jinja2',
|
|
||||||
operation_name='input_summary',
|
|
||||||
query=query,
|
|
||||||
history=history,
|
|
||||||
retrieve_info=retrieve_info
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call LLM with structured response
|
|
||||||
try:
|
|
||||||
structured = await llm_client.response_structured(
|
|
||||||
messages=[{"role": "system", "content": system_prompt}],
|
|
||||||
response_model=RetrieveSummaryResponse
|
|
||||||
)
|
|
||||||
aimessages = structured.data.query_answer or "信息不足,无法回答"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Input_Summary: response_structured failed, using default answer: {e}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
aimessages = "信息不足,无法回答"
|
|
||||||
|
|
||||||
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
|
|
||||||
|
|
||||||
# Emit intermediate output for frontend
|
# Emit intermediate output for frontend
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -10,9 +10,6 @@ from app.core.logging_config import get_business_logger
|
|||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
# 为了兼容性,创建别名
|
|
||||||
# SchemaParser = OpenAPISchemaParser = None
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAPISchemaParser:
|
class OpenAPISchemaParser:
|
||||||
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
|
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
|
||||||
@@ -214,6 +211,8 @@ class OpenAPISchemaParser:
|
|||||||
if not isinstance(operation, dict):
|
if not isinstance(operation, dict):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
summary = operation.get("summary", "")
|
||||||
|
|
||||||
# 生成操作ID
|
# 生成操作ID
|
||||||
operation_id = operation.get("operationId")
|
operation_id = operation.get("operationId")
|
||||||
if not operation_id:
|
if not operation_id:
|
||||||
@@ -223,7 +222,7 @@ class OpenAPISchemaParser:
|
|||||||
operations[operation_id] = {
|
operations[operation_id] = {
|
||||||
"method": method.upper(),
|
"method": method.upper(),
|
||||||
"path": path,
|
"path": path,
|
||||||
"summary": operation.get("summary", ""),
|
"summary": summary if summary else operation_id,
|
||||||
"description": operation.get("description", ""),
|
"description": operation.get("description", ""),
|
||||||
"parameters": self._extract_parameters(operation),
|
"parameters": self._extract_parameters(operation),
|
||||||
"request_body": self._extract_request_body(operation),
|
"request_body": self._extract_request_body(operation),
|
||||||
|
|||||||
@@ -232,7 +232,7 @@ class LangchainAdapter:
|
|||||||
# 添加验证约束
|
# 添加验证约束
|
||||||
if param.enum:
|
if param.enum:
|
||||||
# 枚举值约束
|
# 枚举值约束
|
||||||
field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$"
|
field_kwargs["pattern"] = f"^({'|'.join(map(str, param.enum))})$"
|
||||||
|
|
||||||
if param.minimum is not None:
|
if param.minimum is not None:
|
||||||
field_kwargs["ge"] = param.minimum
|
field_kwargs["ge"] = param.minimum
|
||||||
@@ -241,7 +241,7 @@ class LangchainAdapter:
|
|||||||
field_kwargs["le"] = param.maximum
|
field_kwargs["le"] = param.maximum
|
||||||
|
|
||||||
if param.pattern:
|
if param.pattern:
|
||||||
field_kwargs["regex"] = param.pattern
|
field_kwargs["pattern"] = param.pattern
|
||||||
|
|
||||||
fields[param.name] = Field(**field_kwargs)
|
fields[param.name] = Field(**field_kwargs)
|
||||||
annotations[param.name] = python_type
|
annotations[param.name] = python_type
|
||||||
|
|||||||
@@ -27,20 +27,22 @@ class SimpleMCPClient:
|
|||||||
|
|
||||||
# 确定连接类型
|
# 确定连接类型
|
||||||
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
||||||
|
self.is_sse = "/sse" in server_url.lower()
|
||||||
|
|
||||||
# 连接状态
|
# 连接状态
|
||||||
self._websocket = None
|
self._websocket = None
|
||||||
self._session = None
|
self._session = None
|
||||||
self._request_id = 0
|
self._request_id = 0
|
||||||
self._pending_requests = {}
|
self._pending_requests = {}
|
||||||
|
self._server_capabilities = {}
|
||||||
|
self._endpoint_url = None # SSE endpoint URL
|
||||||
|
self._sse_task = None
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
"""异步上下文管理器入口"""
|
|
||||||
await self.connect()
|
await self.connect()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""异步上下文管理器出口"""
|
|
||||||
await self.disconnect()
|
await self.disconnect()
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
@@ -57,47 +59,157 @@ class SimpleMCPClient:
|
|||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""断开连接"""
|
"""断开连接"""
|
||||||
try:
|
try:
|
||||||
|
if self._sse_task:
|
||||||
|
self._sse_task.cancel()
|
||||||
if self._websocket:
|
if self._websocket:
|
||||||
await self._websocket.close()
|
await self._websocket.close()
|
||||||
self._websocket = None
|
self._websocket = None
|
||||||
|
|
||||||
if self._session:
|
if self._session:
|
||||||
await self._session.close()
|
await self._session.close()
|
||||||
self._session = None
|
self._session = None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"断开连接失败: {e}")
|
logger.error(f"断开连接失败: {e}")
|
||||||
|
|
||||||
async def _connect_websocket(self):
|
async def _connect_websocket(self):
|
||||||
"""WebSocket 连接"""
|
"""WebSocket 连接"""
|
||||||
headers = self._build_headers()
|
headers = self._build_headers()
|
||||||
|
|
||||||
self._websocket = await websockets.connect(
|
self._websocket = await websockets.connect(
|
||||||
self.server_url,
|
self.server_url,
|
||||||
extra_headers=headers,
|
extra_headers=headers,
|
||||||
timeout=self.timeout
|
timeout=self.timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
# 启动消息处理
|
|
||||||
asyncio.create_task(self._handle_websocket_messages())
|
asyncio.create_task(self._handle_websocket_messages())
|
||||||
|
|
||||||
# 发送初始化消息
|
|
||||||
await self._send_initialize()
|
await self._send_initialize()
|
||||||
|
|
||||||
async def _connect_http(self):
|
async def _connect_http(self):
|
||||||
"""HTTP 连接"""
|
"""HTTP 连接"""
|
||||||
headers = self._build_headers()
|
headers = self._build_headers()
|
||||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||||
|
self._session = aiohttp.ClientSession(headers=headers, timeout=timeout)
|
||||||
|
|
||||||
self._session = aiohttp.ClientSession(
|
if self.is_sse:
|
||||||
headers=headers,
|
await self._initialize_sse_session()
|
||||||
timeout=timeout
|
elif "modelscope.net" in self.server_url:
|
||||||
)
|
|
||||||
|
|
||||||
# 对于 ModelScope MCP 服务,需要先发送初始化请求
|
|
||||||
if "modelscope.net" in self.server_url:
|
|
||||||
await self._initialize_modelscope_session()
|
await self._initialize_modelscope_session()
|
||||||
|
|
||||||
|
async def _initialize_sse_session(self):
|
||||||
|
"""初始化 SSE MCP 会话 - 参考 Dify 实现"""
|
||||||
|
try:
|
||||||
|
# 建立 SSE 连接
|
||||||
|
response = await self._session.get(
|
||||||
|
self.server_url,
|
||||||
|
headers={"Accept": "text/event-stream"}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status != 200:
|
||||||
|
error_text = await response.text()
|
||||||
|
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
||||||
|
|
||||||
|
# 启动 SSE 读取任务
|
||||||
|
self._sse_task = asyncio.create_task(self._read_sse_stream(response))
|
||||||
|
|
||||||
|
# 等待获取 endpoint URL
|
||||||
|
for _ in range(10):
|
||||||
|
if self._endpoint_url:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
if not self._endpoint_url:
|
||||||
|
raise MCPConnectionError("未能获取 endpoint URL")
|
||||||
|
|
||||||
|
# 发送 initialize 请求到 endpoint
|
||||||
|
init_request = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": self._get_request_id(),
|
||||||
|
"method": "initialize",
|
||||||
|
"params": {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {"tools": {}},
|
||||||
|
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
init_response = await self._send_sse_request(init_request)
|
||||||
|
if "error" in init_response:
|
||||||
|
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||||
|
|
||||||
|
result = init_response.get("result", {})
|
||||||
|
self._server_capabilities = result.get("capabilities", {})
|
||||||
|
|
||||||
|
# 发送 initialized 通知
|
||||||
|
await self._send_sse_notification({"jsonrpc": "2.0", "method": "notifications/initialized"})
|
||||||
|
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
raise MCPConnectionError(f"初始化连接失败: {e}")
|
||||||
|
|
||||||
|
async def _read_sse_stream(self, response):
|
||||||
|
"""读取 SSE 流"""
|
||||||
|
try:
|
||||||
|
async for line in response.content:
|
||||||
|
line = line.decode('utf-8').strip()
|
||||||
|
|
||||||
|
if line.startswith('event:'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if line.startswith('data:'):
|
||||||
|
data = line[5:].strip() # 去除 'data:' 后的空格
|
||||||
|
if not data or data == '[DONE]':
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 处理 endpoint 事件(相对路径或绝对路径)
|
||||||
|
if not self._endpoint_url:
|
||||||
|
# 如果是相对路径,拼接成完整 URL
|
||||||
|
if data.startswith('/'):
|
||||||
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
parsed = urlparse(self.server_url)
|
||||||
|
self._endpoint_url = f"{parsed.scheme}://{parsed.netloc}{data}"
|
||||||
|
else:
|
||||||
|
self._endpoint_url = data
|
||||||
|
logger.info(f"获取到 endpoint URL: {self._endpoint_url}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 处理 message 事件
|
||||||
|
message = json.loads(data)
|
||||||
|
request_id = message.get("id")
|
||||||
|
if request_id and request_id in self._pending_requests:
|
||||||
|
future = self._pending_requests.pop(request_id)
|
||||||
|
if not future.done():
|
||||||
|
future.set_result(message)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"SSE 流读取错误: {e}")
|
||||||
|
|
||||||
|
async def _send_sse_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""通过 SSE endpoint 发送请求"""
|
||||||
|
if not self._endpoint_url:
|
||||||
|
raise MCPConnectionError("endpoint URL 未初始化")
|
||||||
|
|
||||||
|
request_id = request["id"]
|
||||||
|
future = asyncio.Future()
|
||||||
|
self._pending_requests[request_id] = future
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self._session.post(self._endpoint_url, json=request) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
error_text = await response.text()
|
||||||
|
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
||||||
|
|
||||||
|
return await asyncio.wait_for(future, timeout=self.timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._pending_requests.pop(request_id, None)
|
||||||
|
raise MCPConnectionError("请求超时")
|
||||||
|
|
||||||
|
async def _send_sse_notification(self, notification: Dict[str, Any]):
|
||||||
|
"""发送通知(无需响应)"""
|
||||||
|
if not self._endpoint_url:
|
||||||
|
raise MCPConnectionError("endpoint URL 未初始化")
|
||||||
|
|
||||||
|
async with self._session.post(self._endpoint_url, json=notification) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
logger.warning(f"通知发送失败: {response.status}")
|
||||||
|
|
||||||
async def _initialize_modelscope_session(self):
|
async def _initialize_modelscope_session(self):
|
||||||
"""初始化 ModelScope MCP 会话"""
|
"""初始化 ModelScope MCP 会话"""
|
||||||
init_request = {
|
init_request = {
|
||||||
@@ -107,18 +219,12 @@ class SimpleMCPClient:
|
|||||||
"params": {
|
"params": {
|
||||||
"protocolVersion": "2024-11-05",
|
"protocolVersion": "2024-11-05",
|
||||||
"capabilities": {"tools": {}},
|
"capabilities": {"tools": {}},
|
||||||
"clientInfo": {
|
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||||
"name": "MemoryBear",
|
|
||||||
"version": "1.0.0"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._session.post(
|
async with self._session.post(self.server_url, json=init_request) as response:
|
||||||
self.server_url,
|
|
||||||
json=init_request
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
||||||
@@ -127,21 +233,16 @@ class SimpleMCPClient:
|
|||||||
if "error" in init_response:
|
if "error" in init_response:
|
||||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||||
|
|
||||||
# 获取 session ID
|
|
||||||
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
|
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
|
||||||
if session_id:
|
if session_id:
|
||||||
self._session.headers.update({"Mcp-Session-Id": session_id})
|
self._session.headers.update({"Mcp-Session-Id": session_id})
|
||||||
|
|
||||||
# 发送 initialized 通知
|
|
||||||
initialized_notification = {
|
initialized_notification = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"method": "notifications/initialized"
|
"method": "notifications/initialized"
|
||||||
}
|
}
|
||||||
|
|
||||||
async with self._session.post(
|
async with self._session.post(self.server_url, json=initialized_notification):
|
||||||
self.server_url,
|
|
||||||
json=initialized_notification
|
|
||||||
) as notif_response:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
@@ -149,12 +250,18 @@ class SimpleMCPClient:
|
|||||||
|
|
||||||
def _build_headers(self) -> Dict[str, str]:
|
def _build_headers(self) -> Dict[str, str]:
|
||||||
"""构建请求头"""
|
"""构建请求头"""
|
||||||
|
# 基础 headers
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Accept": "application/json, text/event-stream"
|
"Accept": "application/json, text/event-stream"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 添加认证头
|
# 合并 connection_config 中的自定义 headers
|
||||||
|
custom_headers = self.connection_config.get("headers", {})
|
||||||
|
if custom_headers:
|
||||||
|
headers.update(custom_headers)
|
||||||
|
|
||||||
|
# 处理认证配置(认证 headers 优先级更高)
|
||||||
auth_config = self.connection_config.get("auth_config", {})
|
auth_config = self.connection_config.get("auth_config", {})
|
||||||
auth_type = self.connection_config.get("auth_type", "none")
|
auth_type = self.connection_config.get("auth_type", "none")
|
||||||
|
|
||||||
@@ -178,7 +285,7 @@ class SimpleMCPClient:
|
|||||||
return headers
|
return headers
|
||||||
|
|
||||||
async def _send_initialize(self):
|
async def _send_initialize(self):
|
||||||
"""发送初始化消息"""
|
"""发送初始化消息(WebSocket)"""
|
||||||
init_message = {
|
init_message = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"id": self._get_request_id(),
|
"id": self._get_request_id(),
|
||||||
@@ -186,124 +293,90 @@ class SimpleMCPClient:
|
|||||||
"params": {
|
"params": {
|
||||||
"protocolVersion": "2024-11-05",
|
"protocolVersion": "2024-11-05",
|
||||||
"capabilities": {"tools": {}},
|
"capabilities": {"tools": {}},
|
||||||
"clientInfo": {
|
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||||
"name": "MemoryBear",
|
|
||||||
"version": "1.0.0"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
await self._websocket.send(json.dumps(init_message))
|
await self._websocket.send(json.dumps(init_message))
|
||||||
|
response = await self._websocket.recv()
|
||||||
|
response_data = json.loads(response)
|
||||||
|
|
||||||
# 等待初始化响应
|
if "error" in response_data:
|
||||||
response = await asyncio.wait_for(
|
raise MCPConnectionError(f"初始化失败: {response_data['error']}")
|
||||||
self._websocket.recv(),
|
|
||||||
timeout=self.timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
init_response = json.loads(response)
|
result = response_data.get("result", {})
|
||||||
if "error" in init_response:
|
self._server_capabilities = result.get("capabilities", {})
|
||||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
|
||||||
|
await self._websocket.send(json.dumps({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "notifications/initialized"
|
||||||
|
}))
|
||||||
|
|
||||||
|
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||||
|
"""获取工具列表"""
|
||||||
|
request = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": self._get_request_id(),
|
||||||
|
"method": "tools/list"
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.is_websocket:
|
||||||
|
await self._websocket.send(json.dumps(request))
|
||||||
|
response = await self._websocket.recv()
|
||||||
|
response_data = json.loads(response)
|
||||||
|
elif self.is_sse:
|
||||||
|
response_data = await self._send_sse_request(request)
|
||||||
|
else:
|
||||||
|
async with self._session.post(self.server_url, json=request) as response:
|
||||||
|
response_data = await response.json()
|
||||||
|
|
||||||
|
if "error" in response_data:
|
||||||
|
raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}")
|
||||||
|
|
||||||
|
result = response_data.get("result", {})
|
||||||
|
return result.get("tools", [])
|
||||||
|
|
||||||
|
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||||
|
"""调用工具"""
|
||||||
|
request = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": self._get_request_id(),
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {"name": tool_name, "arguments": arguments}
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.is_websocket:
|
||||||
|
await self._websocket.send(json.dumps(request))
|
||||||
|
response = await self._websocket.recv()
|
||||||
|
response_data = json.loads(response)
|
||||||
|
elif self.is_sse:
|
||||||
|
response_data = await self._send_sse_request(request)
|
||||||
|
else:
|
||||||
|
async with self._session.post(self.server_url, json=request) as response:
|
||||||
|
response_data = await response.json()
|
||||||
|
|
||||||
|
if "error" in response_data:
|
||||||
|
error = response_data["error"]
|
||||||
|
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||||
|
|
||||||
|
return response_data.get("result", {})
|
||||||
|
|
||||||
|
def _get_request_id(self) -> int:
|
||||||
|
"""生成请求 ID"""
|
||||||
|
self._request_id += 1
|
||||||
|
return self._request_id
|
||||||
|
|
||||||
async def _handle_websocket_messages(self):
|
async def _handle_websocket_messages(self):
|
||||||
"""处理 WebSocket 消息"""
|
"""处理 WebSocket 消息"""
|
||||||
try:
|
try:
|
||||||
while self._websocket and not self._websocket.closed:
|
async for message in self._websocket:
|
||||||
try:
|
data = json.loads(message)
|
||||||
message = await self._websocket.recv()
|
request_id = data.get("id")
|
||||||
data = json.loads(message)
|
if request_id and request_id in self._pending_requests:
|
||||||
|
future = self._pending_requests.pop(request_id)
|
||||||
# 处理响应
|
if not future.done():
|
||||||
if "id" in data:
|
future.set_result(data)
|
||||||
request_id = str(data["id"])
|
except ConnectionClosed:
|
||||||
if request_id in self._pending_requests:
|
logger.info("WebSocket 连接已关闭")
|
||||||
future = self._pending_requests.pop(request_id)
|
|
||||||
if not future.done():
|
|
||||||
future.set_result(data)
|
|
||||||
|
|
||||||
except ConnectionClosed:
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理WebSocket消息失败: {e}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"WebSocket消息处理异常: {e}")
|
logger.error(f"WebSocket 消息处理错误: {e}")
|
||||||
|
|
||||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
|
||||||
"""调用工具"""
|
|
||||||
request_data = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": self._get_request_id(),
|
|
||||||
"method": "tools/call",
|
|
||||||
"params": {
|
|
||||||
"name": tool_name,
|
|
||||||
"arguments": arguments
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.is_websocket:
|
|
||||||
response = await self._send_websocket_request(request_data)
|
|
||||||
else:
|
|
||||||
response = await self._send_http_request(request_data)
|
|
||||||
|
|
||||||
if "error" in response:
|
|
||||||
error = response["error"]
|
|
||||||
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
|
|
||||||
|
|
||||||
return response.get("result", {})
|
|
||||||
|
|
||||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
|
||||||
"""获取工具列表"""
|
|
||||||
request_data = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": self._get_request_id(),
|
|
||||||
"method": "tools/list",
|
|
||||||
"params": {}
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.is_websocket:
|
|
||||||
response = await self._send_websocket_request(request_data)
|
|
||||||
else:
|
|
||||||
response = await self._send_http_request(request_data)
|
|
||||||
|
|
||||||
if "error" in response:
|
|
||||||
error = response["error"]
|
|
||||||
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
|
||||||
|
|
||||||
result = response.get("result", {})
|
|
||||||
return result.get("tools", [])
|
|
||||||
|
|
||||||
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""发送WebSocket请求"""
|
|
||||||
request_id = str(request_data["id"])
|
|
||||||
future = asyncio.Future()
|
|
||||||
self._pending_requests[request_id] = future
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self._websocket.send(json.dumps(request_data))
|
|
||||||
response = await asyncio.wait_for(future, timeout=self.timeout)
|
|
||||||
return response
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
self._pending_requests.pop(request_id, None)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""发送HTTP请求"""
|
|
||||||
try:
|
|
||||||
async with self._session.post(
|
|
||||||
self.server_url,
|
|
||||||
json=request_data
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
error_text = await response.text()
|
|
||||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
|
||||||
|
|
||||||
return await response.json()
|
|
||||||
|
|
||||||
except aiohttp.ClientError as e:
|
|
||||||
raise MCPConnectionError(f"HTTP请求失败: {e}")
|
|
||||||
|
|
||||||
def _get_request_id(self) -> str:
|
|
||||||
"""获取请求ID"""
|
|
||||||
self._request_id += 1
|
|
||||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ class WorkflowExecutor:
|
|||||||
初始化的工作流状态
|
初始化的工作流状态
|
||||||
"""
|
"""
|
||||||
user_message = input_data.get("message") or ""
|
user_message = input_data.get("message") or ""
|
||||||
|
conversation_messages = input_data.get("conv_messages") or []
|
||||||
|
|
||||||
# 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value)
|
# 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value)
|
||||||
config_variables_list = self.workflow_config.get("variables") or []
|
config_variables_list = self.workflow_config.get("variables") or []
|
||||||
@@ -114,7 +115,7 @@ class WorkflowExecutor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [('user', user_message)],
|
"messages": conversation_messages,
|
||||||
"variables": variables,
|
"variables": variables,
|
||||||
"node_outputs": {},
|
"node_outputs": {},
|
||||||
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
||||||
|
|||||||
@@ -7,13 +7,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from operator import add
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import AnyMessage, AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from langgraph.config import get_stream_writer
|
from langgraph.config import get_stream_writer
|
||||||
from typing_extensions import TypedDict, Annotated
|
from typing_extensions import TypedDict, Annotated
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
from app.core.workflow.variable_pool import VariablePool
|
from app.core.workflow.variable_pool import VariablePool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
|
|||||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||||
"""
|
"""
|
||||||
# List of messages (append mode)
|
# List of messages (append mode)
|
||||||
messages: Annotated[list[tuple[str, str]], add]
|
messages: list[dict[str, str]]
|
||||||
|
|
||||||
# Set of loop node IDs, used for assigning values in loop nodes
|
# Set of loop node IDs, used for assigning values in loop nodes
|
||||||
cycle_nodes: list
|
cycle_nodes: list
|
||||||
@@ -154,7 +154,7 @@ class BaseNode(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
超时时间
|
超时时间
|
||||||
"""
|
"""
|
||||||
return 60
|
return settings.WORKFLOW_NODE_TIMEOUT
|
||||||
# return self.error_handling.get("timeout", 60)
|
# return self.error_handling.get("timeout", 60)
|
||||||
|
|
||||||
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
@@ -203,6 +203,7 @@ class BaseNode(ABC):
|
|||||||
# 返回包装后的输出和运行时变量
|
# 返回包装后的输出和运行时变量
|
||||||
return {
|
return {
|
||||||
**wrapped_output,
|
**wrapped_output,
|
||||||
|
"messages": state["messages"],
|
||||||
"variables": state["variables"],
|
"variables": state["variables"],
|
||||||
"runtime_vars": {
|
"runtime_vars": {
|
||||||
self.node_id: runtime_var
|
self.node_id: runtime_var
|
||||||
@@ -356,6 +357,7 @@ class BaseNode(ABC):
|
|||||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||||
state_update = {
|
state_update = {
|
||||||
**final_output,
|
**final_output,
|
||||||
|
"messages": state["messages"],
|
||||||
"variables": state["variables"],
|
"variables": state["variables"],
|
||||||
"runtime_vars": {
|
"runtime_vars": {
|
||||||
self.node_id: runtime_var
|
self.node_id: runtime_var
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ End 节点实现
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
@@ -38,7 +37,23 @@ class EndNode(BaseNode):
|
|||||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||||
if output_template:
|
if output_template:
|
||||||
output = self._render_template(output_template, state, strict=False)
|
output = self._render_template(output_template, state, strict=False)
|
||||||
|
state['messages'].extend([
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.get_variable("sys.message", state)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": output
|
||||||
|
}
|
||||||
|
])
|
||||||
else:
|
else:
|
||||||
|
state['messages'].extend([
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.get_variable("sys.message", state)
|
||||||
|
},
|
||||||
|
])
|
||||||
output = "工作流已完成"
|
output = "工作流已完成"
|
||||||
|
|
||||||
# 统计信息(用于日志)
|
# 统计信息(用于日志)
|
||||||
@@ -166,6 +181,12 @@ class EndNode(BaseNode):
|
|||||||
"chunk_index": 1,
|
"chunk_index": 1,
|
||||||
"is_suffix": False
|
"is_suffix": False
|
||||||
})
|
})
|
||||||
|
state['messages'].extend([
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.get_variable("sys.message", state)
|
||||||
|
}
|
||||||
|
])
|
||||||
yield {"__final__": True, "result": output}
|
yield {"__final__": True, "result": output}
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -176,7 +197,6 @@ class EndNode(BaseNode):
|
|||||||
source_node_id = edge.get("source")
|
source_node_id = edge.get("source")
|
||||||
# Check if the source node is an LLM node
|
# Check if the source node is an LLM node
|
||||||
for node in self.workflow_config.get("nodes", []):
|
for node in self.workflow_config.get("nodes", []):
|
||||||
print("="*50)
|
|
||||||
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
|
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
|
||||||
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
|
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
|
||||||
direct_upstream_llm_nodes.append(source_node_id)
|
direct_upstream_llm_nodes.append(source_node_id)
|
||||||
@@ -216,12 +236,24 @@ class EndNode(BaseNode):
|
|||||||
})
|
})
|
||||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
||||||
|
|
||||||
|
state['messages'].extend([
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.get_variable("sys.message", state)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": output
|
||||||
|
}
|
||||||
|
])
|
||||||
|
|
||||||
# yield completion marker
|
# yield completion marker
|
||||||
yield {"__final__": True, "result": output}
|
yield {"__final__": True, "result": output}
|
||||||
return
|
return
|
||||||
|
|
||||||
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
|
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
|
||||||
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
logger.info(
|
||||||
|
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
||||||
|
|
||||||
# Collect suffix parts
|
# Collect suffix parts
|
||||||
suffix_parts = []
|
suffix_parts = []
|
||||||
@@ -258,6 +290,17 @@ class EndNode(BaseNode):
|
|||||||
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
||||||
full_output = self._render_template(output_template, state, strict=False)
|
full_output = self._render_template(output_template, state, strict=False)
|
||||||
|
|
||||||
|
state['messages'].extend([
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.get_variable("sys.message", state)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_output
|
||||||
|
}
|
||||||
|
])
|
||||||
|
|
||||||
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
||||||
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
||||||
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
|
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
|
||||||
@@ -280,7 +323,8 @@ class EndNode(BaseNode):
|
|||||||
})
|
})
|
||||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
|
||||||
|
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
node_outputs = state.get("node_outputs", {})
|
node_outputs = state.get("node_outputs", {})
|
||||||
|
|||||||
@@ -11,12 +11,12 @@ class MessageConfig(BaseModel):
|
|||||||
"""消息配置"""
|
"""消息配置"""
|
||||||
|
|
||||||
role: str = Field(
|
role: str = Field(
|
||||||
...,
|
default='user',
|
||||||
description="消息角色:system, user, assistant"
|
description="消息角色:system, user, assistant"
|
||||||
)
|
)
|
||||||
|
|
||||||
content: str = Field(
|
content: str = Field(
|
||||||
...,
|
default="",
|
||||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -30,6 +30,23 @@ class MessageConfig(BaseModel):
|
|||||||
return v.lower()
|
return v.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryWindowSetting(BaseModel):
|
||||||
|
enable: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="启用记忆"
|
||||||
|
)
|
||||||
|
|
||||||
|
enable_window: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="启用记忆窗口"
|
||||||
|
)
|
||||||
|
|
||||||
|
window_size: int = Field(
|
||||||
|
default=20,
|
||||||
|
description="记忆窗口大小"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LLMNodeConfig(BaseNodeConfig):
|
class LLMNodeConfig(BaseNodeConfig):
|
||||||
"""LLM 节点配置
|
"""LLM 节点配置
|
||||||
|
|
||||||
@@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
description="上下文"
|
description="上下文"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
memory: MemoryWindowSetting = Field(
|
||||||
|
...,
|
||||||
|
description="对话上下文窗口"
|
||||||
|
)
|
||||||
|
|
||||||
# 简单模式
|
# 简单模式
|
||||||
prompt: str | None = Field(
|
prompt: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@@ -85,28 +85,31 @@ class LLMNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 1. 处理消息格式(优先使用 messages)
|
# 1. 处理消息格式(优先使用 messages)
|
||||||
messages_config = self.config.get("messages")
|
messages_config = self.typed_config.messages
|
||||||
|
|
||||||
if messages_config:
|
if messages_config:
|
||||||
# 使用 LangChain 消息格式
|
# 使用 LangChain 消息格式
|
||||||
messages = []
|
messages = []
|
||||||
for msg_config in messages_config:
|
for msg_config in messages_config:
|
||||||
role = msg_config.get("role", "user").lower()
|
role = msg_config.role.lower()
|
||||||
content_template = msg_config.get("content", "")
|
content_template = msg_config.content
|
||||||
content_template = self._render_context(content_template, state)
|
content_template = self._render_context(content_template, state)
|
||||||
content = self._render_template(content_template, state)
|
content = self._render_template(content_template, state)
|
||||||
|
|
||||||
# 根据角色创建对应的消息对象
|
# 根据角色创建对应的消息对象
|
||||||
if role == "system":
|
if role == "system":
|
||||||
messages.append(SystemMessage(content=content))
|
messages.append({"role": "system", "content": content})
|
||||||
elif role in ["user", "human"]:
|
elif role in ["user", "human"]:
|
||||||
messages.append(HumanMessage(content=content))
|
messages.append({"role": "user", "content": content})
|
||||||
elif role in ["ai", "assistant"]:
|
elif role in ["ai", "assistant"]:
|
||||||
messages.append(AIMessage(content=content))
|
messages.append({"role": "assistant", "content": content})
|
||||||
else:
|
else:
|
||||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
messages.append(HumanMessage(content=content))
|
messages.append({"role": "user", "content": content})
|
||||||
|
|
||||||
|
if self.typed_config.memory.enable:
|
||||||
|
# if self.typed_config.memory.enable_window:
|
||||||
|
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
||||||
prompt_or_messages = messages
|
prompt_or_messages = messages
|
||||||
else:
|
else:
|
||||||
# 使用简单的 prompt 格式(向后兼容)
|
# 使用简单的 prompt 格式(向后兼容)
|
||||||
@@ -189,7 +192,7 @@ class LLMNode(BaseNode):
|
|||||||
return {
|
return {
|
||||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
|
{"role": msg.get("role"), "content": msg.get("content", "")}
|
||||||
for msg in prompt_or_messages
|
for msg in prompt_or_messages
|
||||||
] if isinstance(prompt_or_messages, list) else None,
|
] if isinstance(prompt_or_messages, list) else None,
|
||||||
"config": {
|
"config": {
|
||||||
|
|||||||
@@ -3,8 +3,9 @@ from typing import Any
|
|||||||
from app.core.workflow.nodes import WorkflowState
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||||
from app.db import get_db_read, get_db_context
|
from app.db import get_db_read
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
|
from app.tasks import write_message_task
|
||||||
|
|
||||||
|
|
||||||
class MemoryReadNode(BaseNode):
|
class MemoryReadNode(BaseNode):
|
||||||
@@ -15,11 +16,8 @@ class MemoryReadNode(BaseNode):
|
|||||||
async def execute(self, state: WorkflowState) -> Any:
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
|
||||||
end_user_id = self.get_variable("sys.user_id", state)
|
end_user_id = self.get_variable("sys.user_id", state)
|
||||||
|
|
||||||
if not workspace_id:
|
|
||||||
raise RuntimeError("Workspace id is required")
|
|
||||||
if not end_user_id:
|
if not end_user_id:
|
||||||
raise RuntimeError("End user id is required")
|
raise RuntimeError("End user id is required")
|
||||||
|
|
||||||
@@ -41,20 +39,17 @@ class MemoryWriteNode(BaseNode):
|
|||||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> Any:
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
with get_db_context() as db:
|
end_user_id = self.get_variable("sys.user_id", state)
|
||||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
|
||||||
end_user_id = self.get_variable("sys.user_id", state)
|
|
||||||
|
|
||||||
if not workspace_id:
|
if not end_user_id:
|
||||||
raise RuntimeError("Workspace id is required")
|
raise RuntimeError("End user id is required")
|
||||||
if not end_user_id:
|
|
||||||
raise RuntimeError("End user id is required")
|
|
||||||
|
|
||||||
return await MemoryAgentService().write_memory(
|
write_message_task.delay(
|
||||||
group_id=end_user_id,
|
end_user_id,
|
||||||
message=self._render_template(self.typed_config.message, state),
|
self._render_template(self.typed_config.message, state),
|
||||||
config_id=str(self.typed_config.config_id),
|
str(self.typed_config.config_id),
|
||||||
db=db,
|
"neo4j",
|
||||||
storage_type="neo4j",
|
""
|
||||||
user_rag_memory_id=""
|
)
|
||||||
)
|
|
||||||
|
return "success"
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ class ToolConfig(BaseModel):
|
|||||||
tool_id: Optional[str] = Field(default=None, description="工具ID")
|
tool_id: Optional[str] = Field(default=None, description="工具ID")
|
||||||
operation: Optional[str] = Field(default=None, description="工具特定配置")
|
operation: Optional[str] = Field(default=None, description="工具特定配置")
|
||||||
|
|
||||||
|
|
||||||
class ToolOldConfig(BaseModel):
|
class ToolOldConfig(BaseModel):
|
||||||
"""工具配置"""
|
"""工具配置"""
|
||||||
enabled: bool = Field(default=False, description="是否启用该工具")
|
enabled: bool = Field(default=False, description="是否启用该工具")
|
||||||
@@ -348,6 +349,7 @@ class AppChatRequest(BaseModel):
|
|||||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||||
stream: bool = Field(default=False, description="是否流式返回")
|
stream: bool = Field(default=False, description="是否流式返回")
|
||||||
|
|
||||||
|
|
||||||
class DraftRunRequest(BaseModel):
|
class DraftRunRequest(BaseModel):
|
||||||
"""试运行请求"""
|
"""试运行请求"""
|
||||||
message: str = Field(..., description="用户消息")
|
message: str = Field(..., description="用户消息")
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from app.core.exceptions import BusinessException
|
|||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.db import get_db, get_db_context
|
from app.db import get_db, get_db_context
|
||||||
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
||||||
|
from app.schemas import DraftRunRequest
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
@@ -59,7 +60,7 @@ class AppChatService:
|
|||||||
|
|
||||||
# 获取模型配置ID
|
# 获取模型配置ID
|
||||||
model_config_id = config.default_model_config_id
|
model_config_id = config.default_model_config_id
|
||||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||||
# 处理系统提示词(支持变量替换)
|
# 处理系统提示词(支持变量替换)
|
||||||
system_prompt = config.system_prompt
|
system_prompt = config.system_prompt
|
||||||
if variables:
|
if variables:
|
||||||
@@ -210,7 +211,7 @@ class AppChatService:
|
|||||||
|
|
||||||
# 获取模型配置ID
|
# 获取模型配置ID
|
||||||
model_config_id = config.default_model_config_id
|
model_config_id = config.default_model_config_id
|
||||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||||
# 处理系统提示词(支持变量替换)
|
# 处理系统提示词(支持变量替换)
|
||||||
system_prompt = config.system_prompt
|
system_prompt = config.system_prompt
|
||||||
if variables:
|
if variables:
|
||||||
@@ -511,7 +512,6 @@ class AppChatService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
except (GeneratorExit, asyncio.CancelledError):
|
except (GeneratorExit, asyncio.CancelledError):
|
||||||
# 生成器被关闭或任务被取消,正常退出
|
# 生成器被关闭或任务被取消,正常退出
|
||||||
logger.debug("多 Agent 流式聊天被中断")
|
logger.debug("多 Agent 流式聊天被中断")
|
||||||
@@ -537,83 +537,19 @@ class AppChatService:
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""聊天(非流式)"""
|
"""聊天(非流式)"""
|
||||||
workflow_service = WorkflowService(self.db)
|
workflow_service = WorkflowService(self.db)
|
||||||
|
payload = DraftRunRequest(
|
||||||
input_data = {"message":message, "variables": variables,
|
message=message,
|
||||||
"conversation_id": str(conversation_id)}
|
variables=variables,
|
||||||
inconfig = workflow_service.get_workflow_config(app_id)
|
conversation_id=str(conversation_id),
|
||||||
|
stream=True,
|
||||||
# 2. 创建执行记录
|
user_id=user_id
|
||||||
execution = workflow_service.create_execution(
|
)
|
||||||
workflow_config_id=inconfig.id,
|
return await workflow_service.run(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
trigger_type="manual",
|
payload=payload,
|
||||||
triggered_by=None,
|
config=config,
|
||||||
conversation_id=conversation_id,
|
workspace_id=workspace_id,
|
||||||
input_data=input_data
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 构建工作流配置字典
|
|
||||||
workflow_config_dict = {
|
|
||||||
"nodes": config.nodes,
|
|
||||||
"edges": config.edges,
|
|
||||||
"variables": config.variables,
|
|
||||||
"execution_config": config.execution_config
|
|
||||||
}
|
|
||||||
|
|
||||||
# 4. 获取工作空间 ID(从 app 获取)
|
|
||||||
|
|
||||||
# 5. 执行工作流
|
|
||||||
from app.core.workflow.executor import execute_workflow
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 更新状态为运行中
|
|
||||||
workflow_service.update_execution_status(execution.execution_id, "running")
|
|
||||||
|
|
||||||
result = await execute_workflow(
|
|
||||||
workflow_config=workflow_config_dict,
|
|
||||||
input_data=input_data,
|
|
||||||
execution_id=execution.execution_id,
|
|
||||||
workspace_id=str(workspace_id),
|
|
||||||
user_id=user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# 更新执行结果
|
|
||||||
if result.get("status") == "completed":
|
|
||||||
workflow_service.update_execution_status(
|
|
||||||
execution.execution_id,
|
|
||||||
"completed",
|
|
||||||
output_data=result.get("node_outputs", {})
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
workflow_service.update_execution_status(
|
|
||||||
execution.execution_id,
|
|
||||||
"failed",
|
|
||||||
error_message=result.get("error")
|
|
||||||
)
|
|
||||||
|
|
||||||
# 返回增强的响应结构
|
|
||||||
return {
|
|
||||||
"execution_id": execution.execution_id,
|
|
||||||
"status": result.get("status"),
|
|
||||||
"output": result.get("output"), # 最终输出(字符串)
|
|
||||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
|
||||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
|
||||||
"error_message": result.get("error"),
|
|
||||||
"elapsed_time": result.get("elapsed_time"),
|
|
||||||
"token_usage": result.get("token_usage")
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
|
||||||
workflow_service.update_execution_status(
|
|
||||||
execution.execution_id,
|
|
||||||
"failed",
|
|
||||||
error_message=str(e)
|
|
||||||
)
|
|
||||||
raise BusinessException(
|
|
||||||
code=BizCode.INTERNAL_ERROR,
|
|
||||||
message=f"工作流执行失败: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def workflow_chat_stream(
|
async def workflow_chat_stream(
|
||||||
self,
|
self,
|
||||||
@@ -632,62 +568,21 @@ class AppChatService:
|
|||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""聊天(流式)"""
|
"""聊天(流式)"""
|
||||||
workflow_service = WorkflowService(self.db)
|
workflow_service = WorkflowService(self.db)
|
||||||
input_data = {"message": message, "variables": variables,
|
payload = DraftRunRequest(
|
||||||
"conversation_id": str(conversation_id)}
|
message=message,
|
||||||
inconfig = workflow_service.get_workflow_config(app_id)
|
variables=variables,
|
||||||
# 2. 创建执行记录
|
conversation_id=str(conversation_id),
|
||||||
execution = workflow_service.create_execution(
|
stream=True,
|
||||||
workflow_config_id=inconfig.id,
|
user_id=user_id
|
||||||
app_id=app_id,
|
|
||||||
trigger_type="manual",
|
|
||||||
triggered_by=None,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
input_data=input_data
|
|
||||||
)
|
)
|
||||||
|
async for event in workflow_service.run_stream(
|
||||||
|
app_id=app_id,
|
||||||
|
payload=payload,
|
||||||
|
config=config,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
|
||||||
# 3. 构建工作流配置字典
|
|
||||||
workflow_config_dict = {
|
|
||||||
"nodes": config.nodes,
|
|
||||||
"edges": config.edges,
|
|
||||||
"variables": config.variables,
|
|
||||||
"execution_config": config.execution_config
|
|
||||||
}
|
|
||||||
|
|
||||||
# 4. 获取工作空间 ID(从 app 获取)
|
|
||||||
|
|
||||||
# 5. 流式执行工作流
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 更新状态为运行中
|
|
||||||
workflow_service.update_execution_status(execution.execution_id, "running")
|
|
||||||
|
|
||||||
|
|
||||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
|
||||||
async for event in workflow_service._run_workflow_stream(
|
|
||||||
workflow_config=workflow_config_dict,
|
|
||||||
input_data=input_data,
|
|
||||||
execution_id=execution.execution_id,
|
|
||||||
workspace_id=str(workspace_id),
|
|
||||||
user_id=user_id
|
|
||||||
):
|
|
||||||
# 直接转发 executor 的事件(已经是正确的格式)
|
|
||||||
yield event
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
|
||||||
workflow_service.update_execution_status(
|
|
||||||
execution.execution_id,
|
|
||||||
"failed",
|
|
||||||
error_message=str(e)
|
|
||||||
)
|
|
||||||
# 发送错误事件
|
|
||||||
yield {
|
|
||||||
"event": "error",
|
|
||||||
"data": {
|
|
||||||
"execution_id": execution.execution_id,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# ==================== 依赖注入函数 ====================
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|
||||||
|
|||||||
@@ -13,12 +13,14 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
|
from app.repositories.conversation_repository import ConversationRepository
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_episodic_schema import type_mapping, EmotionType, EmotionSubject
|
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||||
|
|
||||||
from app.services.memory_base_service import MemoryBaseService
|
from app.services.memory_base_service import MemoryBaseService
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||||
|
from app.services.memory_short_service import ShortService
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -1198,18 +1200,17 @@ async def analytics_memory_types(
|
|||||||
end_user_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
统计9种记忆类型的数量和百分比
|
统计8种记忆类型的数量和百分比
|
||||||
|
|
||||||
计算规则:
|
计算规则:
|
||||||
1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity
|
1. 感知记忆 (PERCEPTUAL_MEMORY) = 通过 MemoryPerceptualService.get_memory_count 获取的 total_count
|
||||||
2. 工作记忆 (WORKING_MEMORY) = chunk + entity
|
2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取)
|
||||||
3. 短期记忆 (SHORT_TERM_MEMORY) = chunk
|
3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量
|
||||||
4. 长期记忆 (LONG_TERM_MEMORY) = entity
|
4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
|
||||||
5. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
|
5. 隐性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一
|
||||||
6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity
|
6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
|
||||||
7. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
|
7. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
|
||||||
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
|
8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
|
||||||
9. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
@@ -1229,7 +1230,6 @@ async def analytics_memory_types(
|
|||||||
- PERCEPTUAL_MEMORY: 感知记忆
|
- PERCEPTUAL_MEMORY: 感知记忆
|
||||||
- WORKING_MEMORY: 工作记忆
|
- WORKING_MEMORY: 工作记忆
|
||||||
- SHORT_TERM_MEMORY: 短期记忆
|
- SHORT_TERM_MEMORY: 短期记忆
|
||||||
- LONG_TERM_MEMORY: 长期记忆
|
|
||||||
- EXPLICIT_MEMORY: 显性记忆
|
- EXPLICIT_MEMORY: 显性记忆
|
||||||
- IMPLICIT_MEMORY: 隐性记忆
|
- IMPLICIT_MEMORY: 隐性记忆
|
||||||
- EMOTIONAL_MEMORY: 情绪记忆
|
- EMOTIONAL_MEMORY: 情绪记忆
|
||||||
@@ -1239,40 +1239,78 @@ async def analytics_memory_types(
|
|||||||
# 初始化基础服务
|
# 初始化基础服务
|
||||||
base_service = MemoryBaseService()
|
base_service = MemoryBaseService()
|
||||||
|
|
||||||
# 定义需要查询的基础节点类型
|
# 初始化感知记忆服务
|
||||||
node_types = {
|
perceptual_service = MemoryPerceptualService(db)
|
||||||
"Statement": "Statement",
|
|
||||||
"Entity": "ExtractedEntity",
|
|
||||||
"Chunk": "Chunk"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 存储每种节点类型的计数
|
# 获取感知记忆数量
|
||||||
node_counts = {}
|
if end_user_id:
|
||||||
|
perceptual_stats = perceptual_service.get_memory_count(uuid.UUID(end_user_id))
|
||||||
|
perceptual_count = perceptual_stats.get("total", 0)
|
||||||
|
else:
|
||||||
|
perceptual_count = 0
|
||||||
|
|
||||||
# 查询每种节点类型的数量
|
# 获取工作记忆数量(基于会话数量)
|
||||||
for key, node_type in node_types.items():
|
work_count = 0
|
||||||
if end_user_id:
|
if end_user_id:
|
||||||
query = f"""
|
try:
|
||||||
MATCH (n:{node_type})
|
conversation_repo = ConversationRepository(db)
|
||||||
|
conversations = conversation_repo.get_conversation_by_user_id(
|
||||||
|
user_id=uuid.UUID(end_user_id),
|
||||||
|
limit=100, # 获取更多会话以准确统计
|
||||||
|
is_activate=True
|
||||||
|
)
|
||||||
|
work_count = len(conversations)
|
||||||
|
logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}")
|
||||||
|
work_count = 0
|
||||||
|
|
||||||
|
# 获取隐性记忆数量(基于 Statement 节点数量的三分之一)
|
||||||
|
implicit_count = 0
|
||||||
|
if end_user_id:
|
||||||
|
try:
|
||||||
|
# 查询 Statement 节点数量
|
||||||
|
query = """
|
||||||
|
MATCH (n:Statement)
|
||||||
WHERE n.group_id = $group_id
|
WHERE n.group_id = $group_id
|
||||||
RETURN count(n) as count
|
RETURN count(n) as count
|
||||||
"""
|
"""
|
||||||
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
|
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
|
||||||
else:
|
statement_count = result[0]["count"] if result and len(result) > 0 else 0
|
||||||
query = f"""
|
# 取三分之一作为隐性记忆数量
|
||||||
MATCH (n:{node_type})
|
implicit_count = round(statement_count / 3)
|
||||||
RETURN count(n) as count
|
logger.debug(f"隐性记忆数量(Statement数量的1/3): {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})")
|
||||||
"""
|
except Exception as e:
|
||||||
result = await _neo4j_connector.execute_query(query)
|
logger.warning(f"获取Statement数量失败,隐性记忆数量设为0: {str(e)}")
|
||||||
|
implicit_count = 0
|
||||||
|
|
||||||
# 提取计数结果
|
# 原有的基于行为习惯的统计方式(已注释)
|
||||||
count = result[0]["count"] if result and len(result) > 0 else 0
|
# implicit_count = 0
|
||||||
node_counts[key] = count
|
# if end_user_id:
|
||||||
|
# try:
|
||||||
|
# implicit_service = ImplicitMemoryService(db, end_user_id)
|
||||||
|
# behavior_habits = await implicit_service.get_behavior_habits(
|
||||||
|
# user_id=end_user_id
|
||||||
|
# )
|
||||||
|
# implicit_count = len(behavior_habits)
|
||||||
|
# logger.debug(f"隐性记忆数量(行为习惯数): {implicit_count} (end_user_id={end_user_id})")
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.warning(f"获取行为习惯数量失败,隐性记忆数量设为0: {str(e)}")
|
||||||
|
# implicit_count = 0
|
||||||
|
|
||||||
# 获取各节点类型的数量
|
# 获取短期记忆数量(基于 /short_term 接口返回的问答对数量)
|
||||||
statement_count = node_counts.get("Statement", 0)
|
short_term_count = 0
|
||||||
entity_count = node_counts.get("Entity", 0)
|
if end_user_id:
|
||||||
chunk_count = node_counts.get("Chunk", 0)
|
try:
|
||||||
|
short_term_service = ShortService(end_user_id)
|
||||||
|
short_term_data = short_term_service.get_short_databasets()
|
||||||
|
# 统计 short_term 数组的长度
|
||||||
|
if short_term_data:
|
||||||
|
short_term_count = len(short_term_data)
|
||||||
|
logger.debug(f"短期记忆数量(问答对数): {short_term_count} (end_user_id={end_user_id})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取短期记忆数量失败,短期记忆数量设为0: {str(e)}")
|
||||||
|
short_term_count = 0
|
||||||
|
|
||||||
# 获取用户的遗忘阈值配置
|
# 获取用户的遗忘阈值配置
|
||||||
forgetting_threshold = 0.3 # 默认值
|
forgetting_threshold = 0.3 # 默认值
|
||||||
@@ -1298,17 +1336,16 @@ async def analytics_memory_types(
|
|||||||
# 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量
|
# 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量
|
||||||
episodic_count = await base_service.get_episodic_memory_count(end_user_id)
|
episodic_count = await base_service.get_episodic_memory_count(end_user_id)
|
||||||
explicit_count = await base_service.get_explicit_memory_count(end_user_id)
|
explicit_count = await base_service.get_explicit_memory_count(end_user_id)
|
||||||
emotion_count = await base_service.get_emotional_memory_count(end_user_id, statement_count)
|
emotion_count = await base_service.get_emotional_memory_count(end_user_id, perceptual_count)
|
||||||
forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold)
|
forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold)
|
||||||
|
|
||||||
# 按规则计算9种记忆类型的数量(使用英文枚举作为key)
|
# 按规则计算8种记忆类型的数量(使用英文枚举作为key)
|
||||||
memory_counts = {
|
memory_counts = {
|
||||||
"PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆
|
"PERCEPTUAL_MEMORY": perceptual_count, # 感知记忆
|
||||||
"WORKING_MEMORY": chunk_count + entity_count, # 工作记忆
|
"WORKING_MEMORY": work_count, # 工作记忆(基于会话数量)
|
||||||
"SHORT_TERM_MEMORY": chunk_count, # 短期记忆
|
"SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量)
|
||||||
"LONG_TERM_MEMORY": entity_count, # 长期记忆
|
|
||||||
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
|
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
|
||||||
"IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity)
|
"IMPLICIT_MEMORY": implicit_count, # 隐性记忆(Statement数量的1/3)
|
||||||
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
|
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
|
||||||
"EPISODIC_MEMORY": episodic_count, # 情景记忆
|
"EPISODIC_MEMORY": episodic_count, # 情景记忆
|
||||||
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)
|
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)
|
||||||
|
|||||||
@@ -2,12 +2,11 @@
|
|||||||
工作流服务层
|
工作流服务层
|
||||||
"""
|
"""
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
import datetime
|
|
||||||
from typing import Any, Annotated, AsyncGenerator
|
from typing import Any, Annotated, AsyncGenerator
|
||||||
|
|
||||||
|
from deprecated import deprecated
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -16,15 +15,16 @@ from app.core.exceptions import BusinessException
|
|||||||
from app.core.workflow.validator import validate_workflow_config
|
from app.core.workflow.validator import validate_workflow_config
|
||||||
from app.db import get_db, get_db_context
|
from app.db import get_db, get_db_context
|
||||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||||
|
from app.repositories.conversation_repository import MessageRepository
|
||||||
|
from app.models.conversation_model import Message
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
from app.services.multi_agent_service import convert_uuids_to_str
|
|
||||||
from app.repositories.workflow_repository import (
|
from app.repositories.workflow_repository import (
|
||||||
WorkflowConfigRepository,
|
WorkflowConfigRepository,
|
||||||
WorkflowExecutionRepository,
|
WorkflowExecutionRepository,
|
||||||
WorkflowNodeExecutionRepository
|
WorkflowNodeExecutionRepository
|
||||||
)
|
)
|
||||||
from app.schemas import DraftRunRequest
|
from app.schemas import DraftRunRequest
|
||||||
from app.utils.sse_utils import format_sse_message
|
from app.services.multi_agent_service import convert_uuids_to_str
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -37,6 +37,7 @@ class WorkflowService:
|
|||||||
self.config_repo = WorkflowConfigRepository(db)
|
self.config_repo = WorkflowConfigRepository(db)
|
||||||
self.execution_repo = WorkflowExecutionRepository(db)
|
self.execution_repo = WorkflowExecutionRepository(db)
|
||||||
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
||||||
|
self.message_repo = MessageRepository(db)
|
||||||
|
|
||||||
# ==================== 配置管理 ====================
|
# ==================== 配置管理 ====================
|
||||||
|
|
||||||
@@ -418,14 +419,13 @@ class WorkflowService:
|
|||||||
"""运行工作流
|
"""运行工作流
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
workspace_id:
|
||||||
|
config:
|
||||||
|
payload:
|
||||||
app_id: 应用 ID
|
app_id: 应用 ID
|
||||||
input_data: 输入数据(包含 message 和 variables)
|
|
||||||
triggered_by: 触发用户 ID
|
|
||||||
conversation_id: 会话 ID(可选)
|
|
||||||
stream: 是否流式返回
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
执行结果(非流式)或生成器(流式)
|
执行结果(非流式)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
BusinessException: 配置不存在或执行失败时抛出
|
BusinessException: 配置不存在或执行失败时抛出
|
||||||
@@ -438,7 +438,8 @@ class WorkflowService:
|
|||||||
code=BizCode.CONFIG_MISSING,
|
code=BizCode.CONFIG_MISSING,
|
||||||
message=f"工作流配置不存在: app_id={app_id}"
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
)
|
)
|
||||||
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
|
input_data = {"message": payload.message, "variables": payload.variables,
|
||||||
|
"conversation_id": payload.conversation_id}
|
||||||
|
|
||||||
# 转换 user_id 为 UUID
|
# 转换 user_id 为 UUID
|
||||||
triggered_by_uuid = None
|
triggered_by_uuid = None
|
||||||
@@ -461,7 +462,7 @@ class WorkflowService:
|
|||||||
workflow_config_id=config.id,
|
workflow_config_id=config.id,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
trigger_type="manual",
|
trigger_type="manual",
|
||||||
triggered_by=triggered_by_uuid,
|
triggered_by=None,
|
||||||
conversation_id=conversation_id_uuid,
|
conversation_id=conversation_id_uuid,
|
||||||
input_data=input_data
|
input_data=input_data
|
||||||
)
|
)
|
||||||
@@ -500,8 +501,11 @@ class WorkflowService:
|
|||||||
variables = last_state.get("variables", {})
|
variables = last_state.get("variables", {})
|
||||||
conv_vars = variables.get("conv", {})
|
conv_vars = variables.get("conv", {})
|
||||||
input_data["conv"] = conv_vars
|
input_data["conv"] = conv_vars
|
||||||
|
input_data["conv_messages"] = last_state.get("messages") or []
|
||||||
break
|
break
|
||||||
|
|
||||||
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
|
|
||||||
result = await execute_workflow(
|
result = await execute_workflow(
|
||||||
workflow_config=workflow_config_dict,
|
workflow_config=workflow_config_dict,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
@@ -517,6 +521,17 @@ class WorkflowService:
|
|||||||
"completed",
|
"completed",
|
||||||
output_data=result
|
output_data=result
|
||||||
)
|
)
|
||||||
|
final_messages = result.get("messages", [])[init_message_length:]
|
||||||
|
for message in final_messages:
|
||||||
|
message_obj = Message(
|
||||||
|
conversation_id=conversation_id_uuid,
|
||||||
|
role=message["role"],
|
||||||
|
content=message["content"],
|
||||||
|
)
|
||||||
|
self.message_repo.add_message(message_obj)
|
||||||
|
self.db.commit()
|
||||||
|
logger.info(f"Workflow Run Success, "
|
||||||
|
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||||
else:
|
else:
|
||||||
self.update_execution_status(
|
self.update_execution_status(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
@@ -529,6 +544,7 @@ class WorkflowService:
|
|||||||
"execution_id": execution.execution_id,
|
"execution_id": execution.execution_id,
|
||||||
"status": result.get("status"),
|
"status": result.get("status"),
|
||||||
"variables": result.get("variables"),
|
"variables": result.get("variables"),
|
||||||
|
"messages": result.get("messages"),
|
||||||
"output": result.get("output"), # 最终输出(字符串)
|
"output": result.get("output"), # 最终输出(字符串)
|
||||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||||
@@ -559,6 +575,7 @@ class WorkflowService:
|
|||||||
"""运行工作流(流式)
|
"""运行工作流(流式)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
workspace_id:
|
||||||
app_id: 应用 ID
|
app_id: 应用 ID
|
||||||
payload: 请求对象(包含 message, variables, conversation_id 等)
|
payload: 请求对象(包含 message, variables, conversation_id 等)
|
||||||
config: 存储类型(可选)
|
config: 存储类型(可选)
|
||||||
@@ -601,7 +618,7 @@ class WorkflowService:
|
|||||||
workflow_config_id=config.id,
|
workflow_config_id=config.id,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
trigger_type="manual",
|
trigger_type="manual",
|
||||||
triggered_by=triggered_by_uuid,
|
triggered_by=None,
|
||||||
conversation_id=conversation_id_uuid,
|
conversation_id=conversation_id_uuid,
|
||||||
input_data=input_data
|
input_data=input_data
|
||||||
)
|
)
|
||||||
@@ -638,17 +655,46 @@ class WorkflowService:
|
|||||||
variables = last_state.get("variables", {})
|
variables = last_state.get("variables", {})
|
||||||
conv_vars = variables.get("conv", {})
|
conv_vars = variables.get("conv", {})
|
||||||
input_data["conv"] = conv_vars
|
input_data["conv"] = conv_vars
|
||||||
|
input_data["conv_messages"] = last_state.get("messages") or []
|
||||||
break
|
break
|
||||||
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
|
from app.core.workflow.executor import execute_workflow_stream
|
||||||
|
|
||||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
async for event in execute_workflow_stream(
|
||||||
async for event in self._run_workflow_stream(
|
|
||||||
workflow_config=workflow_config_dict,
|
workflow_config=workflow_config_dict,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
execution_id=execution.execution_id,
|
execution_id=execution.execution_id,
|
||||||
workspace_id=str(workspace_id),
|
workspace_id=str(workspace_id),
|
||||||
user_id=end_user_id
|
user_id=end_user_id
|
||||||
):
|
):
|
||||||
# 直接转发 executor 的事件(已经是正确的格式)
|
if event.get("event") == "workflow_end":
|
||||||
|
|
||||||
|
status = event.get("data", {}).get("status")
|
||||||
|
if status == "completed":
|
||||||
|
self.update_execution_status(
|
||||||
|
execution.execution_id,
|
||||||
|
"completed",
|
||||||
|
output_data=event.get("data")
|
||||||
|
)
|
||||||
|
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
|
||||||
|
for message in final_messages:
|
||||||
|
message_obj = Message(
|
||||||
|
conversation_id=conversation_id_uuid,
|
||||||
|
role=message["role"],
|
||||||
|
content=message["content"],
|
||||||
|
)
|
||||||
|
self.message_repo.add_message(message_obj)
|
||||||
|
self.db.commit()
|
||||||
|
logger.info(f"Workflow Run Success, "
|
||||||
|
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||||
|
elif status == "failed":
|
||||||
|
self.update_execution_status(
|
||||||
|
execution.execution_id,
|
||||||
|
"failed",
|
||||||
|
output_data=event.get("data")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(f"unexpect workflow run status, status: {status}")
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -667,6 +713,8 @@ class WorkflowService:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@deprecated(reason="This method is deprecated. "
|
||||||
|
"Please use WorkflowService.run / run_stream instead.")
|
||||||
async def run_workflow(
|
async def run_workflow(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
@@ -819,6 +867,7 @@ class WorkflowService:
|
|||||||
|
|
||||||
return clean_value(event)
|
return clean_value(event)
|
||||||
|
|
||||||
|
@deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.")
|
||||||
async def _run_workflow_stream(
|
async def _run_workflow_stream(
|
||||||
self,
|
self,
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
|
|||||||
@@ -136,7 +136,8 @@ dependencies = [
|
|||||||
"markdown-to-json==2.1.1",
|
"markdown-to-json==2.1.1",
|
||||||
"valkey==6.0.2",
|
"valkey==6.0.2",
|
||||||
"python-calamine>=0.4.0",
|
"python-calamine>=0.4.0",
|
||||||
"xlrd==2.0.2"
|
"xlrd==2.0.2",
|
||||||
|
"deprecated>=1.3.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ const ChatContent: FC<ChatContentProps> = ({
|
|||||||
</div>
|
</div>
|
||||||
}
|
}
|
||||||
{/* 消息气泡框 */}
|
{/* 消息气泡框 */}
|
||||||
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-100 rb:wrap-break-word', contentClassNames, {
|
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-[520px] rb:wrap-break-word', contentClassNames, {
|
||||||
// 错误消息样式(内容为null且非助手消息)
|
// 错误消息样式(内容为null且非助手消息)
|
||||||
'rb:border-[rgba(255,93,52,0.30)] rb:bg-[rgba(255,93,52,0.08)] rb:text-[#FF5D34]': errorDesc && item.role === 'assistant' && item.content === null,
|
'rb:border-[rgba(255,93,52,0.30)] rb:bg-[rgba(255,93,52,0.08)] rb:text-[#FF5D34]': errorDesc && item.role === 'assistant' && item.content === null,
|
||||||
// 助手消息样式
|
// 助手消息样式
|
||||||
@@ -68,7 +68,7 @@ const ChatContent: FC<ChatContentProps> = ({
|
|||||||
</div>
|
</div>
|
||||||
{/* 底部标签(如时间戳、用户名等) */}
|
{/* 底部标签(如时间戳、用户名等) */}
|
||||||
{labelPosition === 'bottom' &&
|
{labelPosition === 'bottom' &&
|
||||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4 rb:font-regular">
|
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4 rb:font-regular rb:mt-2">
|
||||||
{labelFormat(item)}
|
{labelFormat(item)}
|
||||||
</div>
|
</div>
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1265,6 +1265,7 @@ export const en = {
|
|||||||
emotionLine: 'Emotion Changes Over Time',
|
emotionLine: 'Emotion Changes Over Time',
|
||||||
interaction: 'Interaction Frequency & Relationship Stages',
|
interaction: 'Interaction Frequency & Relationship Stages',
|
||||||
timelines_memory: 'All',
|
timelines_memory: 'All',
|
||||||
|
Chunk: 'Chunk',
|
||||||
MemorySummary: 'Long-term Accumulation',
|
MemorySummary: 'Long-term Accumulation',
|
||||||
Statement: 'Emotional Memory',
|
Statement: 'Emotional Memory',
|
||||||
ExtractedEntity: 'Episodic Memory',
|
ExtractedEntity: 'Episodic Memory',
|
||||||
@@ -1786,6 +1787,9 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
|||||||
temperature: 'Temperature',
|
temperature: 'Temperature',
|
||||||
max_tokens: 'Max Tokens',
|
max_tokens: 'Max Tokens',
|
||||||
context: 'Context',
|
context: 'Context',
|
||||||
|
memory: 'Memory',
|
||||||
|
enable_window: 'Memory Window',
|
||||||
|
inner: 'Built-in',
|
||||||
},
|
},
|
||||||
start: {
|
start: {
|
||||||
variables: 'Input Fields',
|
variables: 'Input Fields',
|
||||||
|
|||||||
@@ -1343,6 +1343,7 @@ export const zh = {
|
|||||||
emotionLine: '情绪随时间变化',
|
emotionLine: '情绪随时间变化',
|
||||||
interaction: '互动频率 & 关系阶段',
|
interaction: '互动频率 & 关系阶段',
|
||||||
timelines_memory: '全部',
|
timelines_memory: '全部',
|
||||||
|
Chunk: '工作记忆',
|
||||||
MemorySummary: '长期沉淀',
|
MemorySummary: '长期沉淀',
|
||||||
Statement: '情绪记忆',
|
Statement: '情绪记忆',
|
||||||
ExtractedEntity: '情景记忆',
|
ExtractedEntity: '情景记忆',
|
||||||
@@ -1883,6 +1884,9 @@ export const zh = {
|
|||||||
temperature: '温度',
|
temperature: '温度',
|
||||||
max_tokens: '最大令牌数',
|
max_tokens: '最大令牌数',
|
||||||
context: '上下文',
|
context: '上下文',
|
||||||
|
memory: '记忆',
|
||||||
|
enable_window: '记忆窗口',
|
||||||
|
inner: '内置',
|
||||||
},
|
},
|
||||||
start: {
|
start: {
|
||||||
variables: '输入字段',
|
variables: '输入字段',
|
||||||
|
|||||||
@@ -176,6 +176,9 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
|||||||
if (response?.knowledge_retrieval?.knowledge_bases?.length) {
|
if (response?.knowledge_retrieval?.knowledge_bases?.length) {
|
||||||
getDefaultKnowledgeList(response)
|
getDefaultKnowledgeList(response)
|
||||||
}
|
}
|
||||||
|
if (response?.tools?.length) {
|
||||||
|
setToolList(response?.tools)
|
||||||
|
}
|
||||||
}).finally(() => {
|
}).finally(() => {
|
||||||
setLoading(false)
|
setLoading(false)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -79,8 +79,6 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi
|
|||||||
}
|
}
|
||||||
}, [data])
|
}, [data])
|
||||||
|
|
||||||
console.log('toolList', toolList)
|
|
||||||
|
|
||||||
const handleAddTool = () => {
|
const handleAddTool = () => {
|
||||||
toolModalRef.current?.handleOpen()
|
toolModalRef.current?.handleOpen()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -259,9 +259,10 @@ const Conversation: FC = () => {
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="rb:relative rb:h-screen rb:px-4 rb:flex-[1_1_auto]">
|
<div className="rb:relative rb:h-screen rb:px-4 rb:flex-[1_1_auto]">
|
||||||
|
<div className='rb:w-[760px] rb:h-screen rb:mx-auto rb:pt-10'>
|
||||||
<Chat
|
<Chat
|
||||||
empty={<Empty url={AnalysisEmptyIcon} className="rb:h-full" subTitle={t('memoryConversation.emptyDesc')} />}
|
empty={<Empty url={BgImg} className="rb:h-full" size={[320,180]} subTitle={t('memoryConversation.emptyDesc')} />}
|
||||||
contentClassName="rb:h-[calc(100%-152px)]"
|
contentClassName="rb:h-[calc(100%-152px)] "
|
||||||
data={chatList}
|
data={chatList}
|
||||||
streamLoading={streamLoading}
|
streamLoading={streamLoading}
|
||||||
loading={loading}
|
loading={loading}
|
||||||
@@ -290,6 +291,7 @@ const Conversation: FC = () => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</Form>
|
</Form>
|
||||||
</Chat>
|
</Chat>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</Flex>
|
</Flex>
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import React, { useState, useImperativeHandle, forwardRef, useRef } from 'react';
|
import { useState, useImperativeHandle, forwardRef, useRef } from 'react';
|
||||||
import { Button, Input, Space, Typography, Tooltip, message, List } from 'antd';
|
import { Button, Space, List } from 'antd';
|
||||||
import { PlusOutlined, EditOutlined, DeleteOutlined } from '@ant-design/icons';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { ChatVariable, AddChatVariableRef } from '../../types';
|
import type { ChatVariable, AddChatVariableRef } from '../../types';
|
||||||
import type { ChatVariableModalRef } from './types'
|
import type { ChatVariableModalRef } from './types'
|
||||||
|
|||||||
@@ -131,7 +131,7 @@ const EditableTable: React.FC<EditableTableProps> = ({
|
|||||||
const AddButton = ({ block = false }: { block?: boolean }) => (
|
const AddButton = ({ block = false }: { block?: boolean }) => (
|
||||||
<Button
|
<Button
|
||||||
type={block ? "dashed" : "text"}
|
type={block ? "dashed" : "text"}
|
||||||
icon={<PlusOutlined />}
|
icon={block ? undefined : <PlusOutlined />}
|
||||||
onClick={() => add(createNewRow())}
|
onClick={() => add(createNewRow())}
|
||||||
size="small"
|
size="small"
|
||||||
block={block}
|
block={block}
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
import { type FC } from "react";
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import { Form, Row, Col, Divider, Switch, Slider } from 'antd'
|
||||||
|
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||||
|
import MessageEditor from '../MessageEditor'
|
||||||
|
|
||||||
|
const MemoryConfig: FC<{ options: Suggestion[]; parentName: string; }> = ({
|
||||||
|
options,
|
||||||
|
parentName
|
||||||
|
}) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const form = Form.useFormInstance();
|
||||||
|
const values = Form.useWatch([], form) || {}
|
||||||
|
|
||||||
|
console.log('MemoryConfig', values)
|
||||||
|
|
||||||
|
const handleChangeEnable = (value: boolean) => {
|
||||||
|
if (value) {
|
||||||
|
form.setFieldsValue({
|
||||||
|
memory: {
|
||||||
|
...form.getFieldValue(parentName),
|
||||||
|
enable_window: false,
|
||||||
|
window_size: 20,
|
||||||
|
messages: "{{sys.message}}"
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{values?.memory?.enable && <>
|
||||||
|
<div className="rb:flex rb:items-center rb:justify-between rb:py-1.5 rb:px-2 rb:bg-[#F6F8FC] rb:rounded-md rb:mb-2">
|
||||||
|
{t('workflow.config.llm.memory')}
|
||||||
|
<span>{t('workflow.config.llm.inner')}</span>
|
||||||
|
</div>
|
||||||
|
<Form.Item layout="horizontal" name={[parentName, 'messages']}>
|
||||||
|
<MessageEditor
|
||||||
|
title="USER"
|
||||||
|
isArray={false}
|
||||||
|
parentName={[parentName, 'messages']}
|
||||||
|
options={options}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Divider />
|
||||||
|
</>}
|
||||||
|
<Form.Item layout="horizontal" name={[parentName, 'enable']} label={t('workflow.config.llm.memory')}>
|
||||||
|
<Switch onChange={handleChangeEnable} />
|
||||||
|
</Form.Item>
|
||||||
|
{values?.memory?.enable && <>
|
||||||
|
<Row className="rb:mb-3">
|
||||||
|
<Col span={10}>
|
||||||
|
<Form.Item layout="horizontal" name={[parentName, 'enable_window']} noStyle>
|
||||||
|
<Switch />
|
||||||
|
</Form.Item>
|
||||||
|
<span className="rb:ml-2">{t('workflow.config.llm.enable_window')}</span>
|
||||||
|
</Col>
|
||||||
|
<Col span={14}>
|
||||||
|
<Form.Item layout="horizontal" name={[parentName, 'window_size']} noStyle>
|
||||||
|
<Slider min={1} max={100} step={1} className="rb:my-0!" disabled={!values?.memory?.enable_window} />
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
</>}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
export default MemoryConfig;
|
||||||
@@ -127,7 +127,7 @@ const MessageEditor: FC<MessageEditor> = ({
|
|||||||
</Space>
|
</Space>
|
||||||
);
|
);
|
||||||
})}
|
})}
|
||||||
<Form.Item>
|
<Form.Item noStyle>
|
||||||
<Button type="dashed" onClick={() => handleAdd(add)} block>
|
<Button type="dashed" onClick={() => handleAdd(add)} block>
|
||||||
+{t('workflow.addMessage')}
|
+{t('workflow.addMessage')}
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import ConditionList from './ConditionList'
|
|||||||
import CycleVarsList from './CycleVarsList'
|
import CycleVarsList from './CycleVarsList'
|
||||||
import AssignmentList from './AssignmentList'
|
import AssignmentList from './AssignmentList'
|
||||||
import ToolConfig from './ToolConfig'
|
import ToolConfig from './ToolConfig'
|
||||||
|
import MemoryConfig from './MemoryConfig'
|
||||||
// import { calculateVariableList } from './utils/variableListCalculator'
|
// import { calculateVariableList } from './utils/variableListCalculator'
|
||||||
|
|
||||||
interface PropertiesProps {
|
interface PropertiesProps {
|
||||||
@@ -1230,6 +1231,20 @@ const Properties: FC<PropertiesProps> = ({
|
|||||||
</Form.Item>
|
</Form.Item>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
if (config.type === 'memoryConfig') {
|
||||||
|
return (
|
||||||
|
<Form.Item
|
||||||
|
key={key}
|
||||||
|
name={key}
|
||||||
|
noStyle
|
||||||
|
>
|
||||||
|
<MemoryConfig
|
||||||
|
parentName={key}
|
||||||
|
options={getFilteredVariableList('llm')}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Form.Item
|
<Form.Item
|
||||||
|
|||||||
@@ -135,6 +135,14 @@ export const nodeLibrary: NodeLibrary[] = [
|
|||||||
readonly: true
|
readonly: true
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
memory: {
|
||||||
|
type: 'memoryConfig',
|
||||||
|
defaultValue: {
|
||||||
|
enable: false,
|
||||||
|
enable_window: false,
|
||||||
|
window_size: 20
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -750,10 +758,6 @@ export const outputVariable: { [key: string]: OutputVariable } = {
|
|||||||
{ name: "body", type: "string" },
|
{ name: "body", type: "string" },
|
||||||
{ name: "status_code", type: "number" },
|
{ name: "status_code", type: "number" },
|
||||||
],
|
],
|
||||||
error: [
|
|
||||||
{ name: "error_message", type: "string" },
|
|
||||||
{ name: "error_type", type: "string" },
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
'tool': {
|
'tool': {
|
||||||
default: [
|
default: [
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from '
|
|||||||
import { register } from '@antv/x6-react-shape';
|
import { register } from '@antv/x6-react-shape';
|
||||||
|
|
||||||
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs } from '../constant';
|
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs } from '../constant';
|
||||||
import type { WorkflowConfig, NodeProperties } from '../types';
|
import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types';
|
||||||
import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application'
|
import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application'
|
||||||
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
||||||
|
|
||||||
@@ -35,6 +35,8 @@ export interface UseWorkflowGraphReturn {
|
|||||||
copyEvent: () => boolean | void;
|
copyEvent: () => boolean | void;
|
||||||
parseEvent: () => boolean | void;
|
parseEvent: () => boolean | void;
|
||||||
handleSave: (flag?: boolean) => Promise<unknown>;
|
handleSave: (flag?: boolean) => Promise<unknown>;
|
||||||
|
chatVariables: ChatVariable[];
|
||||||
|
setChatVariables: React.Dispatch<React.SetStateAction<ChatVariable[]>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const edge_color = '#155EEF';
|
export const edge_color = '#155EEF';
|
||||||
@@ -54,6 +56,7 @@ export const useWorkflowGraph = ({
|
|||||||
const [canRedo, setCanRedo] = useState(false);
|
const [canRedo, setCanRedo] = useState(false);
|
||||||
const [isHandMode, setIsHandMode] = useState(false);
|
const [isHandMode, setIsHandMode] = useState(false);
|
||||||
const [config, setConfig] = useState<WorkflowConfig | null>(null);
|
const [config, setConfig] = useState<WorkflowConfig | null>(null);
|
||||||
|
const [chatVariables, setChatVariables] = useState<ChatVariable[]>([])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
getConfig()
|
getConfig()
|
||||||
@@ -63,16 +66,15 @@ export const useWorkflowGraph = ({
|
|||||||
getWorkflowConfig(id)
|
getWorkflowConfig(id)
|
||||||
.then(res => {
|
.then(res => {
|
||||||
const { variables, ...rest } = res as WorkflowConfig
|
const { variables, ...rest } = res as WorkflowConfig
|
||||||
setConfig({
|
const initChatVariables = variables.map(v => {
|
||||||
...rest,
|
const { default: _, ...cleanV } = v
|
||||||
variables: variables.map(v => {
|
return {
|
||||||
const { default: _, ...cleanV } = v
|
...cleanV,
|
||||||
return {
|
defaultValue: v.default ?? ''
|
||||||
...cleanV,
|
}
|
||||||
defaultValue: v.default ?? ''
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
setChatVariables(initChatVariables)
|
||||||
|
setConfig({ ...rest, variables: initChatVariables })
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,7 +96,17 @@ export const useWorkflowGraph = ({
|
|||||||
|
|
||||||
if (nodeLibraryConfig?.config) {
|
if (nodeLibraryConfig?.config) {
|
||||||
Object.keys(nodeLibraryConfig.config).forEach(key => {
|
Object.keys(nodeLibraryConfig.config).forEach(key => {
|
||||||
if (key === 'knowledge_retrieval' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
|
if (key === 'memory' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
|
||||||
|
const { memory, messages } = config as any;
|
||||||
|
if (memory?.enable && messages && messages.length > 0) {
|
||||||
|
const lastMessage = messages[messages.length - 1]
|
||||||
|
nodeLibraryConfig.config[key].defaultValue = {
|
||||||
|
...memory,
|
||||||
|
messages: lastMessage.content
|
||||||
|
}
|
||||||
|
nodeLibraryConfig.config.messages.defaultValue.splice(-1, 1)
|
||||||
|
}
|
||||||
|
} else if (key === 'knowledge_retrieval' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
|
||||||
const { query, ...rest } = config
|
const { query, ...rest } = config
|
||||||
nodeLibraryConfig.config[key].defaultValue = {
|
nodeLibraryConfig.config[key].defaultValue = {
|
||||||
...rest
|
...rest
|
||||||
@@ -917,13 +929,13 @@ export const useWorkflowGraph = ({
|
|||||||
|
|
||||||
const params = {
|
const params = {
|
||||||
...config,
|
...config,
|
||||||
variables: config.variables.map(v => {
|
variables: chatVariables.map(v => {
|
||||||
const { defaultValue, ...cleanV } = v
|
const { defaultValue, ...cleanV } = v
|
||||||
return {
|
return {
|
||||||
...cleanV,
|
...cleanV,
|
||||||
default: defaultValue ?? ''
|
default: defaultValue ?? ''
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
nodes: nodes.map((node: Node) => {
|
nodes: nodes.map((node: Node) => {
|
||||||
const data = node.getData();
|
const data = node.getData();
|
||||||
const position = node.getPosition();
|
const position = node.getPosition();
|
||||||
@@ -931,7 +943,15 @@ export const useWorkflowGraph = ({
|
|||||||
|
|
||||||
if (data.config) {
|
if (data.config) {
|
||||||
Object.keys(data.config).forEach(key => {
|
Object.keys(data.config).forEach(key => {
|
||||||
if (data.config[key] && 'defaultValue' in data.config[key] && key === 'group_variables') {
|
if (key === 'memory' && data.config[key] && 'defaultValue' in data.config[key]) {
|
||||||
|
const { messages, ...rest } = data.config[key].defaultValue
|
||||||
|
let memoryMessage = { role: 'USER', content: data.config[key].defaultValue.messages }
|
||||||
|
itemConfig = {
|
||||||
|
...itemConfig,
|
||||||
|
messages: rest.enable ? [...itemConfig.messages, memoryMessage] : itemConfig.messages,
|
||||||
|
memory: { ...rest },
|
||||||
|
}
|
||||||
|
} else if (data.config[key] && 'defaultValue' in data.config[key] && key === 'group_variables') {
|
||||||
let group_variables = data.config.group.defaultValue ? {} : data.config[key].defaultValue
|
let group_variables = data.config.group.defaultValue ? {} : data.config[key].defaultValue
|
||||||
if (data.config.group.defaultValue) {
|
if (data.config.group.defaultValue) {
|
||||||
data.config[key].defaultValue.map((vo: any) => {
|
data.config[key].defaultValue.map((vo: any) => {
|
||||||
@@ -1077,5 +1097,7 @@ export const useWorkflowGraph = ({
|
|||||||
copyEvent,
|
copyEvent,
|
||||||
parseEvent,
|
parseEvent,
|
||||||
handleSave,
|
handleSave,
|
||||||
|
chatVariables,
|
||||||
|
setChatVariables
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import PortClickHandler from './components/PortClickHandler';
|
|||||||
import { useWorkflowGraph } from './hooks/useWorkflowGraph';
|
import { useWorkflowGraph } from './hooks/useWorkflowGraph';
|
||||||
import type { WorkflowRef } from '@/views/ApplicationConfig/types'
|
import type { WorkflowRef } from '@/views/ApplicationConfig/types'
|
||||||
import Chat from './components/Chat/Chat';
|
import Chat from './components/Chat/Chat';
|
||||||
import type { ChatRef, AddChatVariableRef, ChatVariable } from './types'
|
import type { ChatRef, AddChatVariableRef } from './types'
|
||||||
import arrowIcon from '@/assets/images/workflow/arrow.png'
|
import arrowIcon from '@/assets/images/workflow/arrow.png'
|
||||||
import AddChatVariable from './components/AddChatVariable';
|
import AddChatVariable from './components/AddChatVariable';
|
||||||
|
|
||||||
@@ -21,7 +21,6 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
|
|||||||
// 使用自定义Hook初始化工作流图
|
// 使用自定义Hook初始化工作流图
|
||||||
const {
|
const {
|
||||||
config,
|
config,
|
||||||
setConfig,
|
|
||||||
graphRef,
|
graphRef,
|
||||||
selectedNode,
|
selectedNode,
|
||||||
setSelectedNode,
|
setSelectedNode,
|
||||||
@@ -38,6 +37,8 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
|
|||||||
copyEvent,
|
copyEvent,
|
||||||
parseEvent,
|
parseEvent,
|
||||||
handleSave,
|
handleSave,
|
||||||
|
chatVariables,
|
||||||
|
setChatVariables
|
||||||
} = useWorkflowGraph({ containerRef, miniMapRef });
|
} = useWorkflowGraph({ containerRef, miniMapRef });
|
||||||
|
|
||||||
const onDragOver = (event: React.DragEvent) => {
|
const onDragOver = (event: React.DragEvent) => {
|
||||||
@@ -52,15 +53,6 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
|
|||||||
const addVariable = () => {
|
const addVariable = () => {
|
||||||
addChatVariableRef.current?.handleOpen()
|
addChatVariableRef.current?.handleOpen()
|
||||||
}
|
}
|
||||||
const handleUpdateChatVariable = (variables: ChatVariable[]) => {
|
|
||||||
setConfig(prev => {
|
|
||||||
if (!prev) return null
|
|
||||||
return {
|
|
||||||
...prev,
|
|
||||||
variables
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
useImperativeHandle(ref, () => ({
|
useImperativeHandle(ref, () => ({
|
||||||
handleSave,
|
handleSave,
|
||||||
@@ -125,8 +117,8 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
|
|||||||
|
|
||||||
<AddChatVariable
|
<AddChatVariable
|
||||||
ref={addChatVariableRef}
|
ref={addChatVariableRef}
|
||||||
variables={config?.variables}
|
variables={chatVariables}
|
||||||
onChange={handleUpdateChatVariable}
|
onChange={setChatVariables}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user