Merge branch 'refs/heads/develop' into fix/memory_bug_fix

# Conflicts:
#	api/app/services/user_memory_service.py
This commit is contained in:
lixinyue
2026-01-14 18:25:47 +08:00
31 changed files with 731 additions and 520 deletions

View File

@@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"])
@router.post("/{app_id}/workflow") @router.post("/{app_id}/workflow")
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def create_workflow_config( async def create_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")], app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
config: WorkflowConfigCreate, config: WorkflowConfigCreate,
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)] service: Annotated[WorkflowService, Depends(get_workflow_service)]
): ):
"""创建工作流配置 """创建工作流配置
@@ -96,6 +96,7 @@ async def create_workflow_config(
msg=f"创建工作流配置失败: {str(e)}" msg=f"创建工作流配置失败: {str(e)}"
) )
# #
# @router.get("/{app_id}/workflow") # @router.get("/{app_id}/workflow")
# async def get_workflow_config( # async def get_workflow_config(
@@ -199,10 +200,10 @@ async def create_workflow_config(
@router.delete("/{app_id}/workflow") @router.delete("/{app_id}/workflow")
async def delete_workflow_config( async def delete_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")], app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)] service: Annotated[WorkflowService, Depends(get_workflow_service)]
): ):
"""删除工作流配置 """删除工作流配置
@@ -243,11 +244,11 @@ async def delete_workflow_config(
@router.post("/{app_id}/workflow/validate") @router.post("/{app_id}/workflow/validate")
async def validate_workflow_config( async def validate_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")], app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)], service: Annotated[WorkflowService, Depends(get_workflow_service)],
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
): ):
"""验证工作流配置 """验证工作流配置
@@ -312,12 +313,12 @@ async def validate_workflow_config(
@router.get("/{app_id}/workflow/executions") @router.get("/{app_id}/workflow/executions")
async def get_workflow_executions( async def get_workflow_executions(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")], app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)], service: Annotated[WorkflowService, Depends(get_workflow_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50, limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0 offset: Annotated[int, Query(ge=0)] = 0
): ):
"""获取工作流执行记录列表 """获取工作流执行记录列表
@@ -365,10 +366,10 @@ async def get_workflow_executions(
@router.get("/workflow/executions/{execution_id}") @router.get("/workflow/executions/{execution_id}")
async def get_workflow_execution( async def get_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")], execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)] service: Annotated[WorkflowService, Depends(get_workflow_service)]
): ):
"""获取工作流执行详情 """获取工作流执行详情
@@ -417,16 +418,14 @@ async def get_workflow_execution(
) )
# ==================== 工作流执行 ==================== # ==================== 工作流执行 ====================
@router.post("/{app_id}/workflow/run") @router.post("/{app_id}/workflow/run")
async def run_workflow( async def run_workflow(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")], app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
request: WorkflowExecutionRequest, request: WorkflowExecutionRequest,
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)] service: Annotated[WorkflowService, Depends(get_workflow_service)]
): ):
"""执行工作流 """执行工作流
@@ -487,11 +486,11 @@ async def run_workflow(
""" """
try: try:
async for event in await service.run_workflow( async for event in await service.run_workflow(
app_id=app_id, app_id=app_id,
input_data=input_data, input_data=input_data,
triggered_by=current_user.id, triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True stream=True
): ):
# 提取事件类型和数据 # 提取事件类型和数据
event_type = event.get("event", "message") event_type = event.get("event", "message")
@@ -554,10 +553,10 @@ async def run_workflow(
@router.post("/workflow/executions/{execution_id}/cancel") @router.post("/workflow/executions/{execution_id}/cancel")
async def cancel_workflow_execution( async def cancel_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")], execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)], db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)], current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)] service: Annotated[WorkflowService, Depends(get_workflow_service)]
): ):
"""取消工作流执行 """取消工作流执行
@@ -602,7 +601,7 @@ async def cancel_workflow_execution(
except BusinessException as e: except BusinessException as e:
logger.warning(f"取消工作流执行失败: {e.message}") logger.warning(f"取消工作流执行失败: {e.message}")
return fail(code=e.error_code, msg=e.message) return fail(code=e.code, msg=e.message)
except Exception as e: except Exception as e:
logger.error(f"取消工作流执行异常: {e}", exc_info=True) logger.error(f"取消工作流执行异常: {e}", exc_info=True)
return fail( return fail(

View File

@@ -7,6 +7,7 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
class Settings: class Settings:
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true" ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration # API Keys Configuration
@@ -142,7 +143,6 @@ class Settings:
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal) # Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
@@ -150,7 +150,7 @@ class Settings:
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None) DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
# Memory Cache Regeneration Configuration # Memory Cache Regeneration Configuration
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24")) MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
@@ -168,6 +168,9 @@ class Settings:
# official environment system version # official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0") SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
# workflow config
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
def get_memory_output_path(self, filename: str = "") -> str: def get_memory_output_path(self, filename: str = "") -> str:
""" """
Get the full path for memory module output files. Get the full path for memory module output files.

View File

@@ -425,15 +425,9 @@ async def Input_Summary(
try: try:
# Extract services from context # Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service") session_service = get_context_resource(ctx, "session_service")
search_service = get_context_resource(ctx, "search_service") search_service = get_context_resource(ctx, "search_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID # Resolve session ID
sessionid = Resolve_username(usermessages) or "" sessionid = Resolve_username(usermessages) or ""
sessionid = sessionid.replace('call_id_', '') sessionid = sessionid.replace('call_id_', '')
@@ -539,31 +533,11 @@ async def Input_Summary(
) )
retrieve_info, question, raw_results = "", query, [] retrieve_info, question, raw_results = "", query, []
# Return retrieved information directly without LLM processing
# Use the raw retrieved info as the answer
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
# Render template logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
system_prompt = await template_service.render_template(
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='input_summary',
query=query,
history=history,
retrieve_info=retrieve_info
)
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=RetrieveSummaryResponse
)
aimessages = structured.data.query_answer or "信息不足,无法回答"
except Exception as e:
logger.error(
f"Input_Summary: response_structured failed, using default answer: {e}",
exc_info=True
)
aimessages = "信息不足,无法回答"
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
# Emit intermediate output for frontend # Emit intermediate output for frontend
return { return {

View File

@@ -10,9 +10,6 @@ from app.core.logging_config import get_business_logger
logger = get_business_logger() logger = get_business_logger()
# 为了兼容性,创建别名
# SchemaParser = OpenAPISchemaParser = None
class OpenAPISchemaParser: class OpenAPISchemaParser:
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范""" """OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
@@ -214,6 +211,8 @@ class OpenAPISchemaParser:
if not isinstance(operation, dict): if not isinstance(operation, dict):
continue continue
summary = operation.get("summary", "")
# 生成操作ID # 生成操作ID
operation_id = operation.get("operationId") operation_id = operation.get("operationId")
if not operation_id: if not operation_id:
@@ -223,7 +222,7 @@ class OpenAPISchemaParser:
operations[operation_id] = { operations[operation_id] = {
"method": method.upper(), "method": method.upper(),
"path": path, "path": path,
"summary": operation.get("summary", ""), "summary": summary if summary else operation_id,
"description": operation.get("description", ""), "description": operation.get("description", ""),
"parameters": self._extract_parameters(operation), "parameters": self._extract_parameters(operation),
"request_body": self._extract_request_body(operation), "request_body": self._extract_request_body(operation),

View File

@@ -232,7 +232,7 @@ class LangchainAdapter:
# 添加验证约束 # 添加验证约束
if param.enum: if param.enum:
# 枚举值约束 # 枚举值约束
field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$" field_kwargs["pattern"] = f"^({'|'.join(map(str, param.enum))})$"
if param.minimum is not None: if param.minimum is not None:
field_kwargs["ge"] = param.minimum field_kwargs["ge"] = param.minimum
@@ -241,7 +241,7 @@ class LangchainAdapter:
field_kwargs["le"] = param.maximum field_kwargs["le"] = param.maximum
if param.pattern: if param.pattern:
field_kwargs["regex"] = param.pattern field_kwargs["pattern"] = param.pattern
fields[param.name] = Field(**field_kwargs) fields[param.name] = Field(**field_kwargs)
annotations[param.name] = python_type annotations[param.name] = python_type

View File

@@ -27,20 +27,22 @@ class SimpleMCPClient:
# 确定连接类型 # 确定连接类型
self.is_websocket = server_url.startswith(("ws://", "wss://")) self.is_websocket = server_url.startswith(("ws://", "wss://"))
self.is_sse = "/sse" in server_url.lower()
# 连接状态 # 连接状态
self._websocket = None self._websocket = None
self._session = None self._session = None
self._request_id = 0 self._request_id = 0
self._pending_requests = {} self._pending_requests = {}
self._server_capabilities = {}
self._endpoint_url = None # SSE endpoint URL
self._sse_task = None
async def __aenter__(self): async def __aenter__(self):
"""异步上下文管理器入口"""
await self.connect() await self.connect()
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.disconnect() await self.disconnect()
async def connect(self): async def connect(self):
@@ -57,47 +59,157 @@ class SimpleMCPClient:
async def disconnect(self): async def disconnect(self):
"""断开连接""" """断开连接"""
try: try:
if self._sse_task:
self._sse_task.cancel()
if self._websocket: if self._websocket:
await self._websocket.close() await self._websocket.close()
self._websocket = None self._websocket = None
if self._session: if self._session:
await self._session.close() await self._session.close()
self._session = None self._session = None
except Exception as e: except Exception as e:
logger.error(f"断开连接失败: {e}") logger.error(f"断开连接失败: {e}")
async def _connect_websocket(self): async def _connect_websocket(self):
"""WebSocket 连接""" """WebSocket 连接"""
headers = self._build_headers() headers = self._build_headers()
self._websocket = await websockets.connect( self._websocket = await websockets.connect(
self.server_url, self.server_url,
extra_headers=headers, extra_headers=headers,
timeout=self.timeout timeout=self.timeout
) )
# 启动消息处理
asyncio.create_task(self._handle_websocket_messages()) asyncio.create_task(self._handle_websocket_messages())
# 发送初始化消息
await self._send_initialize() await self._send_initialize()
async def _connect_http(self): async def _connect_http(self):
"""HTTP 连接""" """HTTP 连接"""
headers = self._build_headers() headers = self._build_headers()
timeout = aiohttp.ClientTimeout(total=self.timeout) timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(headers=headers, timeout=timeout)
self._session = aiohttp.ClientSession( if self.is_sse:
headers=headers, await self._initialize_sse_session()
timeout=timeout elif "modelscope.net" in self.server_url:
)
# 对于 ModelScope MCP 服务,需要先发送初始化请求
if "modelscope.net" in self.server_url:
await self._initialize_modelscope_session() await self._initialize_modelscope_session()
async def _initialize_sse_session(self):
"""初始化 SSE MCP 会话 - 参考 Dify 实现"""
try:
# 建立 SSE 连接
response = await self._session.get(
self.server_url,
headers={"Accept": "text/event-stream"}
)
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
# 启动 SSE 读取任务
self._sse_task = asyncio.create_task(self._read_sse_stream(response))
# 等待获取 endpoint URL
for _ in range(10):
if self._endpoint_url:
break
await asyncio.sleep(1)
if not self._endpoint_url:
raise MCPConnectionError("未能获取 endpoint URL")
# 发送 initialize 请求到 endpoint
init_request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
}
}
init_response = await self._send_sse_request(init_request)
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
result = init_response.get("result", {})
self._server_capabilities = result.get("capabilities", {})
# 发送 initialized 通知
await self._send_sse_notification({"jsonrpc": "2.0", "method": "notifications/initialized"})
except aiohttp.ClientError as e:
raise MCPConnectionError(f"初始化连接失败: {e}")
async def _read_sse_stream(self, response):
"""读取 SSE 流"""
try:
async for line in response.content:
line = line.decode('utf-8').strip()
if line.startswith('event:'):
continue
if line.startswith('data:'):
data = line[5:].strip() # 去除 'data:' 后的空格
if not data or data == '[DONE]':
continue
try:
# 处理 endpoint 事件(相对路径或绝对路径)
if not self._endpoint_url:
# 如果是相对路径,拼接成完整 URL
if data.startswith('/'):
from urllib.parse import urlparse, urlunparse
parsed = urlparse(self.server_url)
self._endpoint_url = f"{parsed.scheme}://{parsed.netloc}{data}"
else:
self._endpoint_url = data
logger.info(f"获取到 endpoint URL: {self._endpoint_url}")
continue
# 处理 message 事件
message = json.loads(data)
request_id = message.get("id")
if request_id and request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(message)
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"SSE 流读取错误: {e}")
async def _send_sse_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""通过 SSE endpoint 发送请求"""
if not self._endpoint_url:
raise MCPConnectionError("endpoint URL 未初始化")
request_id = request["id"]
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
async with self._session.post(self._endpoint_url, json=request) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
return await asyncio.wait_for(future, timeout=self.timeout)
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise MCPConnectionError("请求超时")
async def _send_sse_notification(self, notification: Dict[str, Any]):
"""发送通知(无需响应)"""
if not self._endpoint_url:
raise MCPConnectionError("endpoint URL 未初始化")
async with self._session.post(self._endpoint_url, json=notification) as response:
if response.status != 200:
logger.warning(f"通知发送失败: {response.status}")
async def _initialize_modelscope_session(self): async def _initialize_modelscope_session(self):
"""初始化 ModelScope MCP 会话""" """初始化 ModelScope MCP 会话"""
init_request = { init_request = {
@@ -107,18 +219,12 @@ class SimpleMCPClient:
"params": { "params": {
"protocolVersion": "2024-11-05", "protocolVersion": "2024-11-05",
"capabilities": {"tools": {}}, "capabilities": {"tools": {}},
"clientInfo": { "clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
"name": "MemoryBear",
"version": "1.0.0"
}
} }
} }
try: try:
async with self._session.post( async with self._session.post(self.server_url, json=init_request) as response:
self.server_url,
json=init_request
) as response:
if response.status != 200: if response.status != 200:
error_text = await response.text() error_text = await response.text()
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}") raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
@@ -127,21 +233,16 @@ class SimpleMCPClient:
if "error" in init_response: if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}") raise MCPConnectionError(f"初始化失败: {init_response['error']}")
# 获取 session ID
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id") session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
if session_id: if session_id:
self._session.headers.update({"Mcp-Session-Id": session_id}) self._session.headers.update({"Mcp-Session-Id": session_id})
# 发送 initialized 通知
initialized_notification = { initialized_notification = {
"jsonrpc": "2.0", "jsonrpc": "2.0",
"method": "notifications/initialized" "method": "notifications/initialized"
} }
async with self._session.post( async with self._session.post(self.server_url, json=initialized_notification):
self.server_url,
json=initialized_notification
) as notif_response:
pass pass
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
@@ -149,12 +250,18 @@ class SimpleMCPClient:
def _build_headers(self) -> Dict[str, str]: def _build_headers(self) -> Dict[str, str]:
"""构建请求头""" """构建请求头"""
# 基础 headers
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Accept": "application/json, text/event-stream" "Accept": "application/json, text/event-stream"
} }
# 添加认证头 # 合并 connection_config 中的自定义 headers
custom_headers = self.connection_config.get("headers", {})
if custom_headers:
headers.update(custom_headers)
# 处理认证配置(认证 headers 优先级更高)
auth_config = self.connection_config.get("auth_config", {}) auth_config = self.connection_config.get("auth_config", {})
auth_type = self.connection_config.get("auth_type", "none") auth_type = self.connection_config.get("auth_type", "none")
@@ -178,7 +285,7 @@ class SimpleMCPClient:
return headers return headers
async def _send_initialize(self): async def _send_initialize(self):
"""发送初始化消息""" """发送初始化消息WebSocket"""
init_message = { init_message = {
"jsonrpc": "2.0", "jsonrpc": "2.0",
"id": self._get_request_id(), "id": self._get_request_id(),
@@ -186,124 +293,90 @@ class SimpleMCPClient:
"params": { "params": {
"protocolVersion": "2024-11-05", "protocolVersion": "2024-11-05",
"capabilities": {"tools": {}}, "capabilities": {"tools": {}},
"clientInfo": { "clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
"name": "MemoryBear",
"version": "1.0.0"
}
} }
} }
await self._websocket.send(json.dumps(init_message)) await self._websocket.send(json.dumps(init_message))
response = await self._websocket.recv()
response_data = json.loads(response)
# 等待初始化响应 if "error" in response_data:
response = await asyncio.wait_for( raise MCPConnectionError(f"初始化失败: {response_data['error']}")
self._websocket.recv(),
timeout=self.timeout
)
init_response = json.loads(response) result = response_data.get("result", {})
if "error" in init_response: self._server_capabilities = result.get("capabilities", {})
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
await self._websocket.send(json.dumps({
"jsonrpc": "2.0",
"method": "notifications/initialized"
}))
async def list_tools(self) -> List[Dict[str, Any]]:
"""获取工具列表"""
request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/list"
}
if self.is_websocket:
await self._websocket.send(json.dumps(request))
response = await self._websocket.recv()
response_data = json.loads(response)
elif self.is_sse:
response_data = await self._send_sse_request(request)
else:
async with self._session.post(self.server_url, json=request) as response:
response_data = await response.json()
if "error" in response_data:
raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}")
result = response_data.get("result", {})
return result.get("tools", [])
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""调用工具"""
request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/call",
"params": {"name": tool_name, "arguments": arguments}
}
if self.is_websocket:
await self._websocket.send(json.dumps(request))
response = await self._websocket.recv()
response_data = json.loads(response)
elif self.is_sse:
response_data = await self._send_sse_request(request)
else:
async with self._session.post(self.server_url, json=request) as response:
response_data = await response.json()
if "error" in response_data:
error = response_data["error"]
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
return response_data.get("result", {})
def _get_request_id(self) -> int:
"""生成请求 ID"""
self._request_id += 1
return self._request_id
async def _handle_websocket_messages(self): async def _handle_websocket_messages(self):
"""处理 WebSocket 消息""" """处理 WebSocket 消息"""
try: try:
while self._websocket and not self._websocket.closed: async for message in self._websocket:
try: data = json.loads(message)
message = await self._websocket.recv() request_id = data.get("id")
data = json.loads(message) if request_id and request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
# 处理响应 if not future.done():
if "id" in data: future.set_result(data)
request_id = str(data["id"]) except ConnectionClosed:
if request_id in self._pending_requests: logger.info("WebSocket 连接已关闭")
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(data)
except ConnectionClosed:
break
except Exception as e:
logger.error(f"处理WebSocket消息失败: {e}")
except Exception as e: except Exception as e:
logger.error(f"WebSocket消息处理异常: {e}") logger.error(f"WebSocket 消息处理错误: {e}")
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""调用工具"""
request_data = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": arguments
}
}
if self.is_websocket:
response = await self._send_websocket_request(request_data)
else:
response = await self._send_http_request(request_data)
if "error" in response:
error = response["error"]
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
return response.get("result", {})
async def list_tools(self) -> List[Dict[str, Any]]:
"""获取工具列表"""
request_data = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/list",
"params": {}
}
if self.is_websocket:
response = await self._send_websocket_request(request_data)
else:
response = await self._send_http_request(request_data)
if "error" in response:
error = response["error"]
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
result = response.get("result", {})
return result.get("tools", [])
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""发送WebSocket请求"""
request_id = str(request_data["id"])
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
await self._websocket.send(json.dumps(request_data))
response = await asyncio.wait_for(future, timeout=self.timeout)
return response
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""发送HTTP请求"""
try:
async with self._session.post(
self.server_url,
json=request_data
) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
return await response.json()
except aiohttp.ClientError as e:
raise MCPConnectionError(f"HTTP请求失败: {e}")
def _get_request_id(self) -> str:
"""获取请求ID"""
self._request_id += 1
return f"req_{self._request_id}_{int(time.time() * 1000)}"

View File

@@ -74,6 +74,7 @@ class WorkflowExecutor:
初始化的工作流状态 初始化的工作流状态
""" """
user_message = input_data.get("message") or "" user_message = input_data.get("message") or ""
conversation_messages = input_data.get("conv_messages") or []
# 会话变量处理从配置文件获取变量定义列表转换为字典name -> default value # 会话变量处理从配置文件获取变量定义列表转换为字典name -> default value
config_variables_list = self.workflow_config.get("variables") or [] config_variables_list = self.workflow_config.get("variables") or []
@@ -114,7 +115,7 @@ class WorkflowExecutor:
} }
return { return {
"messages": [('user', user_message)], "messages": conversation_messages,
"variables": variables, "variables": variables,
"node_outputs": {}, "node_outputs": {},
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问) "runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)

View File

@@ -7,13 +7,13 @@
import asyncio import asyncio
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from operator import add
from typing import Any from typing import Any
from langchain_core.messages import AnyMessage, AIMessage from langchain_core.messages import AIMessage
from langgraph.config import get_stream_writer from langgraph.config import get_stream_writer
from typing_extensions import TypedDict, Annotated from typing_extensions import TypedDict, Annotated
from app.core.config import settings
from app.core.workflow.variable_pool import VariablePool from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc. The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
""" """
# List of messages (append mode) # List of messages (append mode)
messages: Annotated[list[tuple[str, str]], add] messages: list[dict[str, str]]
# Set of loop node IDs, used for assigning values in loop nodes # Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list cycle_nodes: list
@@ -154,7 +154,7 @@ class BaseNode(ABC):
Returns: Returns:
超时时间 超时时间
""" """
return 60 return settings.WORKFLOW_NODE_TIMEOUT
# return self.error_handling.get("timeout", 60) # return self.error_handling.get("timeout", 60)
async def run(self, state: WorkflowState) -> dict[str, Any]: async def run(self, state: WorkflowState) -> dict[str, Any]:
@@ -203,6 +203,7 @@ class BaseNode(ABC):
# 返回包装后的输出和运行时变量 # 返回包装后的输出和运行时变量
return { return {
**wrapped_output, **wrapped_output,
"messages": state["messages"],
"variables": state["variables"], "variables": state["variables"],
"runtime_vars": { "runtime_vars": {
self.node_id: runtime_var self.node_id: runtime_var
@@ -356,6 +357,7 @@ class BaseNode(ABC):
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = { state_update = {
**final_output, **final_output,
"messages": state["messages"],
"variables": state["variables"], "variables": state["variables"],
"runtime_vars": { "runtime_vars": {
self.node_id: runtime_var self.node_id: runtime_var

View File

@@ -6,7 +6,6 @@ End 节点实现
import logging import logging
import re import re
import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
@@ -38,7 +37,23 @@ class EndNode(BaseNode):
# 如果配置了输出模板,使用模板渲染;否则使用默认输出 # 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template: if output_template:
output = self._render_template(output_template, state, strict=False) output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
else: else:
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
])
output = "工作流已完成" output = "工作流已完成"
# 统计信息(用于日志) # 统计信息(用于日志)
@@ -166,6 +181,12 @@ class EndNode(BaseNode):
"chunk_index": 1, "chunk_index": 1,
"is_suffix": False "is_suffix": False
}) })
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
}
])
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
@@ -176,7 +197,6 @@ class EndNode(BaseNode):
source_node_id = edge.get("source") source_node_id = edge.get("source")
# Check if the source node is an LLM node # Check if the source node is an LLM node
for node in self.workflow_config.get("nodes", []): for node in self.workflow_config.get("nodes", []):
print("="*50)
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}") logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM: if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
direct_upstream_llm_nodes.append(source_node_id) direct_upstream_llm_nodes.append(source_node_id)
@@ -216,12 +236,24 @@ class EndNode(BaseNode):
}) })
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
# yield completion marker # yield completion marker
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
# Has reference to direct upstream LLM node, only output the part after that reference (suffix) # Has reference to direct upstream LLM node, only output the part after that reference (suffix)
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)") logger.info(
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
# Collect suffix parts # Collect suffix parts
suffix_parts = [] suffix_parts = []
@@ -258,6 +290,17 @@ class EndNode(BaseNode):
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state, strict=False) full_output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": full_output
}
])
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}") logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
@@ -280,7 +323,8 @@ class EndNode(BaseNode):
}) })
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}") logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else: else:
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}") logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
# 统计信息 # 统计信息
node_outputs = state.get("node_outputs", {}) node_outputs = state.get("node_outputs", {})

View File

@@ -11,12 +11,12 @@ class MessageConfig(BaseModel):
"""消息配置""" """消息配置"""
role: str = Field( role: str = Field(
..., default='user',
description="消息角色system, user, assistant" description="消息角色system, user, assistant"
) )
content: str = Field( content: str = Field(
..., default="",
description="消息内容,支持模板变量,如:{{ sys.message }}" description="消息内容,支持模板变量,如:{{ sys.message }}"
) )
@@ -30,6 +30,23 @@ class MessageConfig(BaseModel):
return v.lower() return v.lower()
class MemoryWindowSetting(BaseModel):
enable: bool = Field(
default=False,
description="启用记忆"
)
enable_window: bool = Field(
default=False,
description="启用记忆窗口"
)
window_size: int = Field(
default=20,
description="记忆窗口大小"
)
class LLMNodeConfig(BaseNodeConfig): class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置 """LLM 节点配置
@@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig):
description="上下文" description="上下文"
) )
memory: MemoryWindowSetting = Field(
...,
description="对话上下文窗口"
)
# 简单模式 # 简单模式
prompt: str | None = Field( prompt: str | None = Field(
default=None, default=None,

View File

@@ -85,28 +85,31 @@ class LLMNode(BaseNode):
""" """
# 1. 处理消息格式(优先使用 messages # 1. 处理消息格式(优先使用 messages
messages_config = self.config.get("messages") messages_config = self.typed_config.messages
if messages_config: if messages_config:
# 使用 LangChain 消息格式 # 使用 LangChain 消息格式
messages = [] messages = []
for msg_config in messages_config: for msg_config in messages_config:
role = msg_config.get("role", "user").lower() role = msg_config.role.lower()
content_template = msg_config.get("content", "") content_template = msg_config.content
content_template = self._render_context(content_template, state) content_template = self._render_context(content_template, state)
content = self._render_template(content_template, state) content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象 # 根据角色创建对应的消息对象
if role == "system": if role == "system":
messages.append(SystemMessage(content=content)) messages.append({"role": "system", "content": content})
elif role in ["user", "human"]: elif role in ["user", "human"]:
messages.append(HumanMessage(content=content)) messages.append({"role": "user", "content": content})
elif role in ["ai", "assistant"]: elif role in ["ai", "assistant"]:
messages.append(AIMessage(content=content)) messages.append({"role": "assistant", "content": content})
else: else:
logger.warning(f"未知的消息角色: {role},默认使用 user") logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append(HumanMessage(content=content)) messages.append({"role": "user", "content": content})
if self.typed_config.memory.enable:
# if self.typed_config.memory.enable_window:
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
prompt_or_messages = messages prompt_or_messages = messages
else: else:
# 使用简单的 prompt 格式(向后兼容) # 使用简单的 prompt 格式(向后兼容)
@@ -189,7 +192,7 @@ class LLMNode(BaseNode):
return { return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"messages": [ "messages": [
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content} {"role": msg.get("role"), "content": msg.get("content", "")}
for msg in prompt_or_messages for msg in prompt_or_messages
] if isinstance(prompt_or_messages, list) else None, ] if isinstance(prompt_or_messages, list) else None,
"config": { "config": {

View File

@@ -3,8 +3,9 @@ from typing import Any
from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.db import get_db_read, get_db_context from app.db import get_db_read
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task
class MemoryReadNode(BaseNode): class MemoryReadNode(BaseNode):
@@ -15,11 +16,8 @@ class MemoryReadNode(BaseNode):
async def execute(self, state: WorkflowState) -> Any: async def execute(self, state: WorkflowState) -> Any:
self.typed_config = MemoryReadNodeConfig(**self.config) self.typed_config = MemoryReadNodeConfig(**self.config)
with get_db_read() as db: with get_db_read() as db:
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state) end_user_id = self.get_variable("sys.user_id", state)
if not workspace_id:
raise RuntimeError("Workspace id is required")
if not end_user_id: if not end_user_id:
raise RuntimeError("End user id is required") raise RuntimeError("End user id is required")
@@ -41,20 +39,17 @@ class MemoryWriteNode(BaseNode):
self.typed_config = MemoryWriteNodeConfig(**self.config) self.typed_config = MemoryWriteNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any: async def execute(self, state: WorkflowState) -> Any:
with get_db_context() as db: end_user_id = self.get_variable("sys.user_id", state)
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state)
if not workspace_id: if not end_user_id:
raise RuntimeError("Workspace id is required") raise RuntimeError("End user id is required")
if not end_user_id:
raise RuntimeError("End user id is required")
return await MemoryAgentService().write_memory( write_message_task.delay(
group_id=end_user_id, end_user_id,
message=self._render_template(self.typed_config.message, state), self._render_template(self.typed_config.message, state),
config_id=str(self.typed_config.config_id), str(self.typed_config.config_id),
db=db, "neo4j",
storage_type="neo4j", ""
user_rag_memory_id="" )
)
return "success"

View File

@@ -41,6 +41,7 @@ class ToolConfig(BaseModel):
tool_id: Optional[str] = Field(default=None, description="工具ID") tool_id: Optional[str] = Field(default=None, description="工具ID")
operation: Optional[str] = Field(default=None, description="工具特定配置") operation: Optional[str] = Field(default=None, description="工具特定配置")
class ToolOldConfig(BaseModel): class ToolOldConfig(BaseModel):
"""工具配置""" """工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具") enabled: bool = Field(default=False, description="是否启用该工具")
@@ -348,6 +349,7 @@ class AppChatRequest(BaseModel):
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
stream: bool = Field(default=False, description="是否流式返回") stream: bool = Field(default=False, description="是否流式返回")
class DraftRunRequest(BaseModel): class DraftRunRequest(BaseModel):
"""试运行请求""" """试运行请求"""
message: str = Field(..., description="用户消息") message: str = Field(..., description="用户消息")

View File

@@ -14,6 +14,7 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.db import get_db, get_db_context from app.db import get_db, get_db_context
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
from app.schemas import DraftRunRequest
from app.services.tool_service import ToolService from app.services.tool_service import ToolService
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
from app.db import get_db from app.db import get_db
@@ -59,7 +60,7 @@ class AppChatService:
# 获取模型配置ID # 获取模型配置ID
model_config_id = config.default_model_config_id model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
# 处理系统提示词(支持变量替换) # 处理系统提示词(支持变量替换)
system_prompt = config.system_prompt system_prompt = config.system_prompt
if variables: if variables:
@@ -210,7 +211,7 @@ class AppChatService:
# 获取模型配置ID # 获取模型配置ID
model_config_id = config.default_model_config_id model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
# 处理系统提示词(支持变量替换) # 处理系统提示词(支持变量替换)
system_prompt = config.system_prompt system_prompt = config.system_prompt
if variables: if variables:
@@ -511,7 +512,6 @@ class AppChatService:
} }
) )
except (GeneratorExit, asyncio.CancelledError): except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出 # 生成器被关闭或任务被取消,正常退出
logger.debug("多 Agent 流式聊天被中断") logger.debug("多 Agent 流式聊天被中断")
@@ -537,83 +537,19 @@ class AppChatService:
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""聊天(非流式)""" """聊天(非流式)"""
workflow_service = WorkflowService(self.db) workflow_service = WorkflowService(self.db)
payload = DraftRunRequest(
input_data = {"message":message, "variables": variables, message=message,
"conversation_id": str(conversation_id)} variables=variables,
inconfig = workflow_service.get_workflow_config(app_id) conversation_id=str(conversation_id),
stream=True,
# 2. 创建执行记录 user_id=user_id
execution = workflow_service.create_execution( )
workflow_config_id=inconfig.id, return await workflow_service.run(
app_id=app_id, app_id=app_id,
trigger_type="manual", payload=payload,
triggered_by=None, config=config,
conversation_id=conversation_id, workspace_id=workspace_id,
input_data=input_data
) )
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow
try:
# 更新状态为运行中
workflow_service.update_execution_status(execution.execution_id, "running")
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=user_id
)
# 更新执行结果
if result.get("status") == "completed":
workflow_service.update_execution_status(
execution.execution_id,
"completed",
output_data=result.get("node_outputs", {})
)
else:
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=result.get("error")
)
# 返回增强的响应结构
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"),
"token_usage": result.get("token_usage")
}
except Exception as e:
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
raise BusinessException(
code=BizCode.INTERNAL_ERROR,
message=f"工作流执行失败: {str(e)}"
)
async def workflow_chat_stream( async def workflow_chat_stream(
self, self,
@@ -632,62 +568,21 @@ class AppChatService:
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""聊天(流式)""" """聊天(流式)"""
workflow_service = WorkflowService(self.db) workflow_service = WorkflowService(self.db)
input_data = {"message": message, "variables": variables, payload = DraftRunRequest(
"conversation_id": str(conversation_id)} message=message,
inconfig = workflow_service.get_workflow_config(app_id) variables=variables,
# 2. 创建执行记录 conversation_id=str(conversation_id),
execution = workflow_service.create_execution( stream=True,
workflow_config_id=inconfig.id, user_id=user_id
app_id=app_id,
trigger_type="manual",
triggered_by=None,
conversation_id=conversation_id,
input_data=input_data
) )
async for event in workflow_service.run_stream(
app_id=app_id,
payload=payload,
config=config,
workspace_id=workspace_id,
):
yield event
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
# 5. 流式执行工作流
try:
# 更新状态为运行中
workflow_service.update_execution_status(execution.execution_id, "running")
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件)
async for event in workflow_service._run_workflow_stream(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=user_id
):
# 直接转发 executor 的事件(已经是正确的格式)
yield event
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
# 发送错误事件
yield {
"event": "error",
"data": {
"execution_id": execution.execution_id,
"error": str(e)
}
}
# ==================== 依赖注入函数 ==================== # ==================== 依赖注入函数 ====================

View File

@@ -13,12 +13,14 @@ from typing import Any, Dict, List, Optional, Tuple
from app.core.logging_config import get_logger from app.core.logging_config import get_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.repositories.conversation_repository import ConversationRepository
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import type_mapping, EmotionType, EmotionSubject from app.services.implicit_memory_service import ImplicitMemoryService
from app.services.memory_base_service import MemoryBaseService from app.services.memory_base_service import MemoryBaseService
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -1198,18 +1200,17 @@ async def analytics_memory_types(
end_user_id: Optional[str] = None end_user_id: Optional[str] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
统计9种记忆类型的数量和百分比 统计8种记忆类型的数量和百分比
计算规则: 计算规则:
1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity 1. 感知记忆 (PERCEPTUAL_MEMORY) = 通过 MemoryPerceptualService.get_memory_count 获取的 total_count
2. 工作记忆 (WORKING_MEMORY) = chunk + entity 2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取)
3. 短期记忆 (SHORT_TERM_MEMORY) = chunk 3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量
4. 长期记忆 (LONG_TERM_MEMORY) = entity 4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
5. 性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取) 5. 性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一
6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity 6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
7. 情记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取) 7. 情记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取) 8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
9. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
Args: Args:
db: 数据库会话 db: 数据库会话
@@ -1229,7 +1230,6 @@ async def analytics_memory_types(
- PERCEPTUAL_MEMORY: 感知记忆 - PERCEPTUAL_MEMORY: 感知记忆
- WORKING_MEMORY: 工作记忆 - WORKING_MEMORY: 工作记忆
- SHORT_TERM_MEMORY: 短期记忆 - SHORT_TERM_MEMORY: 短期记忆
- LONG_TERM_MEMORY: 长期记忆
- EXPLICIT_MEMORY: 显性记忆 - EXPLICIT_MEMORY: 显性记忆
- IMPLICIT_MEMORY: 隐性记忆 - IMPLICIT_MEMORY: 隐性记忆
- EMOTIONAL_MEMORY: 情绪记忆 - EMOTIONAL_MEMORY: 情绪记忆
@@ -1239,40 +1239,78 @@ async def analytics_memory_types(
# 初始化基础服务 # 初始化基础服务
base_service = MemoryBaseService() base_service = MemoryBaseService()
# 定义需要查询的基础节点类型 # 初始化感知记忆服务
node_types = { perceptual_service = MemoryPerceptualService(db)
"Statement": "Statement",
"Entity": "ExtractedEntity",
"Chunk": "Chunk"
}
# 存储每种节点类型的计数 # 获取感知记忆数量
node_counts = {} if end_user_id:
perceptual_stats = perceptual_service.get_memory_count(uuid.UUID(end_user_id))
perceptual_count = perceptual_stats.get("total", 0)
else:
perceptual_count = 0
# 查询每种节点类型的数量 # 获取工作记忆数量(基于会话数量
for key, node_type in node_types.items(): work_count = 0
if end_user_id: if end_user_id:
query = f""" try:
MATCH (n:{node_type}) conversation_repo = ConversationRepository(db)
conversations = conversation_repo.get_conversation_by_user_id(
user_id=uuid.UUID(end_user_id),
limit=100, # 获取更多会话以准确统计
is_activate=True
)
work_count = len(conversations)
logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取会话数量失败工作记忆数量设为0: {str(e)}")
work_count = 0
# 获取隐性记忆数量(基于 Statement 节点数量的三分之一)
implicit_count = 0
if end_user_id:
try:
# 查询 Statement 节点数量
query = """
MATCH (n:Statement)
WHERE n.group_id = $group_id WHERE n.group_id = $group_id
RETURN count(n) as count RETURN count(n) as count
""" """
result = await _neo4j_connector.execute_query(query, group_id=end_user_id) result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
else: statement_count = result[0]["count"] if result and len(result) > 0 else 0
query = f""" # 取三分之一作为隐性记忆数量
MATCH (n:{node_type}) implicit_count = round(statement_count / 3)
RETURN count(n) as count logger.debug(f"隐性记忆数量Statement数量的1/3: {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})")
""" except Exception as e:
result = await _neo4j_connector.execute_query(query) logger.warning(f"获取Statement数量失败隐性记忆数量设为0: {str(e)}")
implicit_count = 0
# 提取计数结果 # 原有的基于行为习惯的统计方式(已注释)
count = result[0]["count"] if result and len(result) > 0 else 0 # implicit_count = 0
node_counts[key] = count # if end_user_id:
# try:
# implicit_service = ImplicitMemoryService(db, end_user_id)
# behavior_habits = await implicit_service.get_behavior_habits(
# user_id=end_user_id
# )
# implicit_count = len(behavior_habits)
# logger.debug(f"隐性记忆数量(行为习惯数): {implicit_count} (end_user_id={end_user_id})")
# except Exception as e:
# logger.warning(f"获取行为习惯数量失败隐性记忆数量设为0: {str(e)}")
# implicit_count = 0
# 获取各节点类型的数量 # 获取短期记忆数量(基于 /short_term 接口返回的问答对数量
statement_count = node_counts.get("Statement", 0) short_term_count = 0
entity_count = node_counts.get("Entity", 0) if end_user_id:
chunk_count = node_counts.get("Chunk", 0) try:
short_term_service = ShortService(end_user_id)
short_term_data = short_term_service.get_short_databasets()
# 统计 short_term 数组的长度
if short_term_data:
short_term_count = len(short_term_data)
logger.debug(f"短期记忆数量(问答对数): {short_term_count} (end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取短期记忆数量失败短期记忆数量设为0: {str(e)}")
short_term_count = 0
# 获取用户的遗忘阈值配置 # 获取用户的遗忘阈值配置
forgetting_threshold = 0.3 # 默认值 forgetting_threshold = 0.3 # 默认值
@@ -1298,17 +1336,16 @@ async def analytics_memory_types(
# 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量 # 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量
episodic_count = await base_service.get_episodic_memory_count(end_user_id) episodic_count = await base_service.get_episodic_memory_count(end_user_id)
explicit_count = await base_service.get_explicit_memory_count(end_user_id) explicit_count = await base_service.get_explicit_memory_count(end_user_id)
emotion_count = await base_service.get_emotional_memory_count(end_user_id, statement_count) emotion_count = await base_service.get_emotional_memory_count(end_user_id, perceptual_count)
forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold) forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold)
# 按规则计算9种记忆类型的数量使用英文枚举作为key # 按规则计算8种记忆类型的数量使用英文枚举作为key
memory_counts = { memory_counts = {
"PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆 "PERCEPTUAL_MEMORY": perceptual_count, # 感知记忆
"WORKING_MEMORY": chunk_count + entity_count, # 工作记忆 "WORKING_MEMORY": work_count, # 工作记忆(基于会话数量)
"SHORT_TERM_MEMORY": chunk_count, # 短期记忆 "SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量)
"LONG_TERM_MEMORY": entity_count, # 长期记忆
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆) "EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
"IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity) "IMPLICIT_MEMORY": implicit_count, # 隐性记忆Statement数量的1/3
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计) "EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
"EPISODIC_MEMORY": episodic_count, # 情景记忆 "EPISODIC_MEMORY": episodic_count, # 情景记忆
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值) "FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)

View File

@@ -2,12 +2,11 @@
工作流服务层 工作流服务层
""" """
import datetime import datetime
import json
import logging import logging
import uuid import uuid
import datetime
from typing import Any, Annotated, AsyncGenerator from typing import Any, Annotated, AsyncGenerator
from deprecated import deprecated
from fastapi import Depends from fastapi import Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -16,15 +15,16 @@ from app.core.exceptions import BusinessException
from app.core.workflow.validator import validate_workflow_config from app.core.workflow.validator import validate_workflow_config
from app.db import get_db, get_db_context from app.db import get_db, get_db_context
from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories.conversation_repository import MessageRepository
from app.models.conversation_model import Message
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
from app.services.multi_agent_service import convert_uuids_to_str
from app.repositories.workflow_repository import ( from app.repositories.workflow_repository import (
WorkflowConfigRepository, WorkflowConfigRepository,
WorkflowExecutionRepository, WorkflowExecutionRepository,
WorkflowNodeExecutionRepository WorkflowNodeExecutionRepository
) )
from app.schemas import DraftRunRequest from app.schemas import DraftRunRequest
from app.utils.sse_utils import format_sse_message from app.services.multi_agent_service import convert_uuids_to_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -37,6 +37,7 @@ class WorkflowService:
self.config_repo = WorkflowConfigRepository(db) self.config_repo = WorkflowConfigRepository(db)
self.execution_repo = WorkflowExecutionRepository(db) self.execution_repo = WorkflowExecutionRepository(db)
self.node_execution_repo = WorkflowNodeExecutionRepository(db) self.node_execution_repo = WorkflowNodeExecutionRepository(db)
self.message_repo = MessageRepository(db)
# ==================== 配置管理 ==================== # ==================== 配置管理 ====================
@@ -418,14 +419,13 @@ class WorkflowService:
"""运行工作流 """运行工作流
Args: Args:
workspace_id:
config:
payload:
app_id: 应用 ID app_id: 应用 ID
input_data: 输入数据(包含 message 和 variables
triggered_by: 触发用户 ID
conversation_id: 会话 ID可选
stream: 是否流式返回
Returns: Returns:
执行结果(非流式)或生成器(流式) 执行结果(非流式)
Raises: Raises:
BusinessException: 配置不存在或执行失败时抛出 BusinessException: 配置不存在或执行失败时抛出
@@ -438,7 +438,8 @@ class WorkflowService:
code=BizCode.CONFIG_MISSING, code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}" message=f"工作流配置不存在: app_id={app_id}"
) )
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id} input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id}
# 转换 user_id 为 UUID # 转换 user_id 为 UUID
triggered_by_uuid = None triggered_by_uuid = None
@@ -461,7 +462,7 @@ class WorkflowService:
workflow_config_id=config.id, workflow_config_id=config.id,
app_id=app_id, app_id=app_id,
trigger_type="manual", trigger_type="manual",
triggered_by=triggered_by_uuid, triggered_by=None,
conversation_id=conversation_id_uuid, conversation_id=conversation_id_uuid,
input_data=input_data input_data=input_data
) )
@@ -500,8 +501,11 @@ class WorkflowService:
variables = last_state.get("variables", {}) variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {}) conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break break
init_message_length = len(input_data.get("conv_messages", []))
result = await execute_workflow( result = await execute_workflow(
workflow_config=workflow_config_dict, workflow_config=workflow_config_dict,
input_data=input_data, input_data=input_data,
@@ -517,6 +521,17 @@ class WorkflowService:
"completed", "completed",
output_data=result output_data=result
) )
final_messages = result.get("messages", [])[init_message_length:]
for message in final_messages:
message_obj = Message(
conversation_id=conversation_id_uuid,
role=message["role"],
content=message["content"],
)
self.message_repo.add_message(message_obj)
self.db.commit()
logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
else: else:
self.update_execution_status( self.update_execution_status(
execution.execution_id, execution.execution_id,
@@ -529,6 +544,7 @@ class WorkflowService:
"execution_id": execution.execution_id, "execution_id": execution.execution_id,
"status": result.get("status"), "status": result.get("status"),
"variables": result.get("variables"), "variables": result.get("variables"),
"messages": result.get("messages"),
"output": result.get("output"), # 最终输出(字符串) "output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID "conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
@@ -559,6 +575,7 @@ class WorkflowService:
"""运行工作流(流式) """运行工作流(流式)
Args: Args:
workspace_id:
app_id: 应用 ID app_id: 应用 ID
payload: 请求对象(包含 message, variables, conversation_id 等) payload: 请求对象(包含 message, variables, conversation_id 等)
config: 存储类型(可选) config: 存储类型(可选)
@@ -601,7 +618,7 @@ class WorkflowService:
workflow_config_id=config.id, workflow_config_id=config.id,
app_id=app_id, app_id=app_id,
trigger_type="manual", trigger_type="manual",
triggered_by=triggered_by_uuid, triggered_by=None,
conversation_id=conversation_id_uuid, conversation_id=conversation_id_uuid,
input_data=input_data input_data=input_data
) )
@@ -638,17 +655,46 @@ class WorkflowService:
variables = last_state.get("variables", {}) variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {}) conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break break
init_message_length = len(input_data.get("conv_messages", []))
from app.core.workflow.executor import execute_workflow_stream
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件) async for event in execute_workflow_stream(
async for event in self._run_workflow_stream(
workflow_config=workflow_config_dict, workflow_config=workflow_config_dict,
input_data=input_data, input_data=input_data,
execution_id=execution.execution_id, execution_id=execution.execution_id,
workspace_id=str(workspace_id), workspace_id=str(workspace_id),
user_id=end_user_id user_id=end_user_id
): ):
# 直接转发 executor 的事件(已经是正确的格式) if event.get("event") == "workflow_end":
status = event.get("data", {}).get("status")
if status == "completed":
self.update_execution_status(
execution.execution_id,
"completed",
output_data=event.get("data")
)
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
for message in final_messages:
message_obj = Message(
conversation_id=conversation_id_uuid,
role=message["role"],
content=message["content"],
)
self.message_repo.add_message(message_obj)
self.db.commit()
logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
elif status == "failed":
self.update_execution_status(
execution.execution_id,
"failed",
output_data=event.get("data")
)
else:
logger.error(f"unexpect workflow run status, status: {status}")
yield event yield event
except Exception as e: except Exception as e:
@@ -667,6 +713,8 @@ class WorkflowService:
} }
} }
@deprecated(reason="This method is deprecated. "
"Please use WorkflowService.run / run_stream instead.")
async def run_workflow( async def run_workflow(
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
@@ -819,6 +867,7 @@ class WorkflowService:
return clean_value(event) return clean_value(event)
@deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.")
async def _run_workflow_stream( async def _run_workflow_stream(
self, self,
workflow_config: dict[str, Any], workflow_config: dict[str, Any],

View File

@@ -136,7 +136,8 @@ dependencies = [
"markdown-to-json==2.1.1", "markdown-to-json==2.1.1",
"valkey==6.0.2", "valkey==6.0.2",
"python-calamine>=0.4.0", "python-calamine>=0.4.0",
"xlrd==2.0.2" "xlrd==2.0.2",
"deprecated>=1.3.1",
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]

View File

@@ -55,7 +55,7 @@ const ChatContent: FC<ChatContentProps> = ({
</div> </div>
} }
{/* 消息气泡框 */} {/* 消息气泡框 */}
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-100 rb:wrap-break-word', contentClassNames, { <div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-[520px] rb:wrap-break-word', contentClassNames, {
// 错误消息样式内容为null且非助手消息 // 错误消息样式内容为null且非助手消息
'rb:border-[rgba(255,93,52,0.30)] rb:bg-[rgba(255,93,52,0.08)] rb:text-[#FF5D34]': errorDesc && item.role === 'assistant' && item.content === null, 'rb:border-[rgba(255,93,52,0.30)] rb:bg-[rgba(255,93,52,0.08)] rb:text-[#FF5D34]': errorDesc && item.role === 'assistant' && item.content === null,
// 助手消息样式 // 助手消息样式
@@ -68,7 +68,7 @@ const ChatContent: FC<ChatContentProps> = ({
</div> </div>
{/* 底部标签(如时间戳、用户名等) */} {/* 底部标签(如时间戳、用户名等) */}
{labelPosition === 'bottom' && {labelPosition === 'bottom' &&
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4 rb:font-regular"> <div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4 rb:font-regular rb:mt-2">
{labelFormat(item)} {labelFormat(item)}
</div> </div>
} }

View File

@@ -1265,6 +1265,7 @@ export const en = {
emotionLine: 'Emotion Changes Over Time', emotionLine: 'Emotion Changes Over Time',
interaction: 'Interaction Frequency & Relationship Stages', interaction: 'Interaction Frequency & Relationship Stages',
timelines_memory: 'All', timelines_memory: 'All',
Chunk: 'Chunk',
MemorySummary: 'Long-term Accumulation', MemorySummary: 'Long-term Accumulation',
Statement: 'Emotional Memory', Statement: 'Emotional Memory',
ExtractedEntity: 'Episodic Memory', ExtractedEntity: 'Episodic Memory',
@@ -1786,6 +1787,9 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
temperature: 'Temperature', temperature: 'Temperature',
max_tokens: 'Max Tokens', max_tokens: 'Max Tokens',
context: 'Context', context: 'Context',
memory: 'Memory',
enable_window: 'Memory Window',
inner: 'Built-in',
}, },
start: { start: {
variables: 'Input Fields', variables: 'Input Fields',

View File

@@ -1343,6 +1343,7 @@ export const zh = {
emotionLine: '情绪随时间变化', emotionLine: '情绪随时间变化',
interaction: '互动频率 & 关系阶段', interaction: '互动频率 & 关系阶段',
timelines_memory: '全部', timelines_memory: '全部',
Chunk: '工作记忆',
MemorySummary: '长期沉淀', MemorySummary: '长期沉淀',
Statement: '情绪记忆', Statement: '情绪记忆',
ExtractedEntity: '情景记忆', ExtractedEntity: '情景记忆',
@@ -1883,6 +1884,9 @@ export const zh = {
temperature: '温度', temperature: '温度',
max_tokens: '最大令牌数', max_tokens: '最大令牌数',
context: '上下文', context: '上下文',
memory: '记忆',
enable_window: '记忆窗口',
inner: '内置',
}, },
start: { start: {
variables: '输入字段', variables: '输入字段',

View File

@@ -176,6 +176,9 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
if (response?.knowledge_retrieval?.knowledge_bases?.length) { if (response?.knowledge_retrieval?.knowledge_bases?.length) {
getDefaultKnowledgeList(response) getDefaultKnowledgeList(response)
} }
if (response?.tools?.length) {
setToolList(response?.tools)
}
}).finally(() => { }).finally(() => {
setLoading(false) setLoading(false)
}) })

View File

@@ -79,8 +79,6 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi
} }
}, [data]) }, [data])
console.log('toolList', toolList)
const handleAddTool = () => { const handleAddTool = () => {
toolModalRef.current?.handleOpen() toolModalRef.current?.handleOpen()
} }

View File

@@ -259,9 +259,10 @@ const Conversation: FC = () => {
</div> </div>
<div className="rb:relative rb:h-screen rb:px-4 rb:flex-[1_1_auto]"> <div className="rb:relative rb:h-screen rb:px-4 rb:flex-[1_1_auto]">
<div className='rb:w-[760px] rb:h-screen rb:mx-auto rb:pt-10'>
<Chat <Chat
empty={<Empty url={AnalysisEmptyIcon} className="rb:h-full" subTitle={t('memoryConversation.emptyDesc')} />} empty={<Empty url={BgImg} className="rb:h-full" size={[320,180]} subTitle={t('memoryConversation.emptyDesc')} />}
contentClassName="rb:h-[calc(100%-152px)]" contentClassName="rb:h-[calc(100%-152px)] "
data={chatList} data={chatList}
streamLoading={streamLoading} streamLoading={streamLoading}
loading={loading} loading={loading}
@@ -290,6 +291,7 @@ const Conversation: FC = () => {
</Flex> </Flex>
</Form> </Form>
</Chat> </Chat>
</div>
</div> </div>
</Flex> </Flex>
) )

View File

@@ -1,6 +1,5 @@
import React, { useState, useImperativeHandle, forwardRef, useRef } from 'react'; import { useState, useImperativeHandle, forwardRef, useRef } from 'react';
import { Button, Input, Space, Typography, Tooltip, message, List } from 'antd'; import { Button, Space, List } from 'antd';
import { PlusOutlined, EditOutlined, DeleteOutlined } from '@ant-design/icons';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { ChatVariable, AddChatVariableRef } from '../../types'; import type { ChatVariable, AddChatVariableRef } from '../../types';
import type { ChatVariableModalRef } from './types' import type { ChatVariableModalRef } from './types'

View File

@@ -131,7 +131,7 @@ const EditableTable: React.FC<EditableTableProps> = ({
const AddButton = ({ block = false }: { block?: boolean }) => ( const AddButton = ({ block = false }: { block?: boolean }) => (
<Button <Button
type={block ? "dashed" : "text"} type={block ? "dashed" : "text"}
icon={<PlusOutlined />} icon={block ? undefined : <PlusOutlined />}
onClick={() => add(createNewRow())} onClick={() => add(createNewRow())}
size="small" size="small"
block={block} block={block}

View File

@@ -0,0 +1,69 @@
import { type FC } from "react";
import { useTranslation } from 'react-i18next'
import { Form, Row, Col, Divider, Switch, Slider } from 'antd'
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
import MessageEditor from '../MessageEditor'
const MemoryConfig: FC<{ options: Suggestion[]; parentName: string; }> = ({
options,
parentName
}) => {
const { t } = useTranslation()
const form = Form.useFormInstance();
const values = Form.useWatch([], form) || {}
console.log('MemoryConfig', values)
const handleChangeEnable = (value: boolean) => {
if (value) {
form.setFieldsValue({
memory: {
...form.getFieldValue(parentName),
enable_window: false,
window_size: 20,
messages: "{{sys.message}}"
}
})
}
}
return (
<>
{values?.memory?.enable && <>
<div className="rb:flex rb:items-center rb:justify-between rb:py-1.5 rb:px-2 rb:bg-[#F6F8FC] rb:rounded-md rb:mb-2">
{t('workflow.config.llm.memory')}
<span>{t('workflow.config.llm.inner')}</span>
</div>
<Form.Item layout="horizontal" name={[parentName, 'messages']}>
<MessageEditor
title="USER"
isArray={false}
parentName={[parentName, 'messages']}
options={options}
/>
</Form.Item>
<Divider />
</>}
<Form.Item layout="horizontal" name={[parentName, 'enable']} label={t('workflow.config.llm.memory')}>
<Switch onChange={handleChangeEnable} />
</Form.Item>
{values?.memory?.enable && <>
<Row className="rb:mb-3">
<Col span={10}>
<Form.Item layout="horizontal" name={[parentName, 'enable_window']} noStyle>
<Switch />
</Form.Item>
<span className="rb:ml-2">{t('workflow.config.llm.enable_window')}</span>
</Col>
<Col span={14}>
<Form.Item layout="horizontal" name={[parentName, 'window_size']} noStyle>
<Slider min={1} max={100} step={1} className="rb:my-0!" disabled={!values?.memory?.enable_window} />
</Form.Item>
</Col>
</Row>
</>}
</>
);
};
export default MemoryConfig;

View File

@@ -127,7 +127,7 @@ const MessageEditor: FC<MessageEditor> = ({
</Space> </Space>
); );
})} })}
<Form.Item> <Form.Item noStyle>
<Button type="dashed" onClick={() => handleAdd(add)} block> <Button type="dashed" onClick={() => handleAdd(add)} block>
+{t('workflow.addMessage')} +{t('workflow.addMessage')}
</Button> </Button>

View File

@@ -22,6 +22,7 @@ import ConditionList from './ConditionList'
import CycleVarsList from './CycleVarsList' import CycleVarsList from './CycleVarsList'
import AssignmentList from './AssignmentList' import AssignmentList from './AssignmentList'
import ToolConfig from './ToolConfig' import ToolConfig from './ToolConfig'
import MemoryConfig from './MemoryConfig'
// import { calculateVariableList } from './utils/variableListCalculator' // import { calculateVariableList } from './utils/variableListCalculator'
interface PropertiesProps { interface PropertiesProps {
@@ -1230,6 +1231,20 @@ const Properties: FC<PropertiesProps> = ({
</Form.Item> </Form.Item>
) )
} }
if (config.type === 'memoryConfig') {
return (
<Form.Item
key={key}
name={key}
noStyle
>
<MemoryConfig
parentName={key}
options={getFilteredVariableList('llm')}
/>
</Form.Item>
)
}
return ( return (
<Form.Item <Form.Item

View File

@@ -135,6 +135,14 @@ export const nodeLibrary: NodeLibrary[] = [
readonly: true readonly: true
}, },
] ]
},
memory: {
type: 'memoryConfig',
defaultValue: {
enable: false,
enable_window: false,
window_size: 20
}
} }
} }
}, },
@@ -750,10 +758,6 @@ export const outputVariable: { [key: string]: OutputVariable } = {
{ name: "body", type: "string" }, { name: "body", type: "string" },
{ name: "status_code", type: "number" }, { name: "status_code", type: "number" },
], ],
error: [
{ name: "error_message", type: "string" },
{ name: "error_type", type: "string" },
]
}, },
'tool': { 'tool': {
default: [ default: [

View File

@@ -6,7 +6,7 @@ import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from '
import { register } from '@antv/x6-react-shape'; import { register } from '@antv/x6-react-shape';
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs } from '../constant'; import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs } from '../constant';
import type { WorkflowConfig, NodeProperties } from '../types'; import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types';
import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application' import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application'
import type { PortMetadata } from '@antv/x6/lib/model/port'; import type { PortMetadata } from '@antv/x6/lib/model/port';
@@ -35,6 +35,8 @@ export interface UseWorkflowGraphReturn {
copyEvent: () => boolean | void; copyEvent: () => boolean | void;
parseEvent: () => boolean | void; parseEvent: () => boolean | void;
handleSave: (flag?: boolean) => Promise<unknown>; handleSave: (flag?: boolean) => Promise<unknown>;
chatVariables: ChatVariable[];
setChatVariables: React.Dispatch<React.SetStateAction<ChatVariable[]>>;
} }
export const edge_color = '#155EEF'; export const edge_color = '#155EEF';
@@ -54,6 +56,7 @@ export const useWorkflowGraph = ({
const [canRedo, setCanRedo] = useState(false); const [canRedo, setCanRedo] = useState(false);
const [isHandMode, setIsHandMode] = useState(false); const [isHandMode, setIsHandMode] = useState(false);
const [config, setConfig] = useState<WorkflowConfig | null>(null); const [config, setConfig] = useState<WorkflowConfig | null>(null);
const [chatVariables, setChatVariables] = useState<ChatVariable[]>([])
useEffect(() => { useEffect(() => {
getConfig() getConfig()
@@ -63,16 +66,15 @@ export const useWorkflowGraph = ({
getWorkflowConfig(id) getWorkflowConfig(id)
.then(res => { .then(res => {
const { variables, ...rest } = res as WorkflowConfig const { variables, ...rest } = res as WorkflowConfig
setConfig({ const initChatVariables = variables.map(v => {
...rest, const { default: _, ...cleanV } = v
variables: variables.map(v => { return {
const { default: _, ...cleanV } = v ...cleanV,
return { defaultValue: v.default ?? ''
...cleanV, }
defaultValue: v.default ?? ''
}
})
}) })
setChatVariables(initChatVariables)
setConfig({ ...rest, variables: initChatVariables })
}) })
} }
@@ -94,7 +96,17 @@ export const useWorkflowGraph = ({
if (nodeLibraryConfig?.config) { if (nodeLibraryConfig?.config) {
Object.keys(nodeLibraryConfig.config).forEach(key => { Object.keys(nodeLibraryConfig.config).forEach(key => {
if (key === 'knowledge_retrieval' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) { if (key === 'memory' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
const { memory, messages } = config as any;
if (memory?.enable && messages && messages.length > 0) {
const lastMessage = messages[messages.length - 1]
nodeLibraryConfig.config[key].defaultValue = {
...memory,
messages: lastMessage.content
}
nodeLibraryConfig.config.messages.defaultValue.splice(-1, 1)
}
} else if (key === 'knowledge_retrieval' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
const { query, ...rest } = config const { query, ...rest } = config
nodeLibraryConfig.config[key].defaultValue = { nodeLibraryConfig.config[key].defaultValue = {
...rest ...rest
@@ -917,13 +929,13 @@ export const useWorkflowGraph = ({
const params = { const params = {
...config, ...config,
variables: config.variables.map(v => { variables: chatVariables.map(v => {
const { defaultValue, ...cleanV } = v const { defaultValue, ...cleanV } = v
return { return {
...cleanV, ...cleanV,
default: defaultValue ?? '' default: defaultValue ?? ''
} }
}), }),
nodes: nodes.map((node: Node) => { nodes: nodes.map((node: Node) => {
const data = node.getData(); const data = node.getData();
const position = node.getPosition(); const position = node.getPosition();
@@ -931,7 +943,15 @@ export const useWorkflowGraph = ({
if (data.config) { if (data.config) {
Object.keys(data.config).forEach(key => { Object.keys(data.config).forEach(key => {
if (data.config[key] && 'defaultValue' in data.config[key] && key === 'group_variables') { if (key === 'memory' && data.config[key] && 'defaultValue' in data.config[key]) {
const { messages, ...rest } = data.config[key].defaultValue
let memoryMessage = { role: 'USER', content: data.config[key].defaultValue.messages }
itemConfig = {
...itemConfig,
messages: rest.enable ? [...itemConfig.messages, memoryMessage] : itemConfig.messages,
memory: { ...rest },
}
} else if (data.config[key] && 'defaultValue' in data.config[key] && key === 'group_variables') {
let group_variables = data.config.group.defaultValue ? {} : data.config[key].defaultValue let group_variables = data.config.group.defaultValue ? {} : data.config[key].defaultValue
if (data.config.group.defaultValue) { if (data.config.group.defaultValue) {
data.config[key].defaultValue.map((vo: any) => { data.config[key].defaultValue.map((vo: any) => {
@@ -1077,5 +1097,7 @@ export const useWorkflowGraph = ({
copyEvent, copyEvent,
parseEvent, parseEvent,
handleSave, handleSave,
chatVariables,
setChatVariables
}; };
}; };

View File

@@ -8,7 +8,7 @@ import PortClickHandler from './components/PortClickHandler';
import { useWorkflowGraph } from './hooks/useWorkflowGraph'; import { useWorkflowGraph } from './hooks/useWorkflowGraph';
import type { WorkflowRef } from '@/views/ApplicationConfig/types' import type { WorkflowRef } from '@/views/ApplicationConfig/types'
import Chat from './components/Chat/Chat'; import Chat from './components/Chat/Chat';
import type { ChatRef, AddChatVariableRef, ChatVariable } from './types' import type { ChatRef, AddChatVariableRef } from './types'
import arrowIcon from '@/assets/images/workflow/arrow.png' import arrowIcon from '@/assets/images/workflow/arrow.png'
import AddChatVariable from './components/AddChatVariable'; import AddChatVariable from './components/AddChatVariable';
@@ -21,7 +21,6 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
// 使用自定义Hook初始化工作流图 // 使用自定义Hook初始化工作流图
const { const {
config, config,
setConfig,
graphRef, graphRef,
selectedNode, selectedNode,
setSelectedNode, setSelectedNode,
@@ -38,6 +37,8 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
copyEvent, copyEvent,
parseEvent, parseEvent,
handleSave, handleSave,
chatVariables,
setChatVariables
} = useWorkflowGraph({ containerRef, miniMapRef }); } = useWorkflowGraph({ containerRef, miniMapRef });
const onDragOver = (event: React.DragEvent) => { const onDragOver = (event: React.DragEvent) => {
@@ -52,15 +53,6 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
const addVariable = () => { const addVariable = () => {
addChatVariableRef.current?.handleOpen() addChatVariableRef.current?.handleOpen()
} }
const handleUpdateChatVariable = (variables: ChatVariable[]) => {
setConfig(prev => {
if (!prev) return null
return {
...prev,
variables
}
})
}
useImperativeHandle(ref, () => ({ useImperativeHandle(ref, () => ({
handleSave, handleSave,
@@ -125,8 +117,8 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
<AddChatVariable <AddChatVariable
ref={addChatVariableRef} ref={addChatVariableRef}
variables={config?.variables} variables={chatVariables}
onChange={handleUpdateChatVariable} onChange={setChatVariables}
/> />
</div> </div>
); );