Merge branch 'refs/heads/develop' into fix/memory_bug_fix

# Conflicts:
#	api/app/services/user_memory_service.py
This commit is contained in:
lixinyue
2026-01-14 18:25:47 +08:00
31 changed files with 731 additions and 520 deletions

View File

@@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"])
@router.post("/{app_id}/workflow")
@cur_workspace_access_guard()
async def create_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
config: WorkflowConfigCreate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
config: WorkflowConfigCreate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""创建工作流配置
@@ -96,6 +96,7 @@ async def create_workflow_config(
msg=f"创建工作流配置失败: {str(e)}"
)
#
# @router.get("/{app_id}/workflow")
# async def get_workflow_config(
@@ -199,10 +200,10 @@ async def create_workflow_config(
@router.delete("/{app_id}/workflow")
async def delete_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""删除工作流配置
@@ -243,11 +244,11 @@ async def delete_workflow_config(
@router.post("/{app_id}/workflow/validate")
async def validate_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
):
"""验证工作流配置
@@ -312,12 +313,12 @@ async def validate_workflow_config(
@router.get("/{app_id}/workflow/executions")
async def get_workflow_executions(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0
):
"""获取工作流执行记录列表
@@ -365,10 +366,10 @@ async def get_workflow_executions(
@router.get("/workflow/executions/{execution_id}")
async def get_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""获取工作流执行详情
@@ -417,16 +418,14 @@ async def get_workflow_execution(
)
# ==================== 工作流执行 ====================
@router.post("/{app_id}/workflow/run")
async def run_workflow(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
request: WorkflowExecutionRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
request: WorkflowExecutionRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""执行工作流
@@ -487,22 +486,22 @@ async def run_workflow(
"""
try:
async for event in await service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True
):
# 提取事件类型和数据
event_type = event.get("event", "message")
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
# event: <type>
# data: <json>
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
except Exception as e:
logger.error(f"流式执行异常: {e}", exc_info=True)
# 发送错误事件
@@ -554,10 +553,10 @@ async def run_workflow(
@router.post("/workflow/executions/{execution_id}/cancel")
async def cancel_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""取消工作流执行
@@ -602,7 +601,7 @@ async def cancel_workflow_execution(
except BusinessException as e:
logger.warning(f"取消工作流执行失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
return fail(code=e.code, msg=e.message)
except Exception as e:
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
return fail(

View File

@@ -7,17 +7,18 @@ from dotenv import load_dotenv
load_dotenv()
class Settings:
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
# Neo4j Configuration (记忆系统数据库)
NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687")
NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "")
# Database configuration (Postgres)
DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1")
DB_PORT: int = int(os.getenv("DB_PORT", "5432"))
@@ -37,7 +38,7 @@ class Settings:
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
# ElasticSearch configuration
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
@@ -48,7 +49,7 @@ class Settings:
ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000"))
ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true"
ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10"))
# Xinference configuration
XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1")
@@ -57,17 +58,17 @@ class Settings:
LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true"
LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "")
LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "")
# LLM Request Configuration
LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0"))
LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2"))
# JWT Token Configuration
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random")
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
# Single Sign-On configuration
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
@@ -86,19 +87,19 @@ class Settings:
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
# Server Configuration
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
# ========================================================================
# Internal Configuration (not in .env, used by application code)
# ========================================================================
# Superuser settings (internal defaults)
FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com")
FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin")
FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password")
# Generic File Upload (internal)
GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads")
ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true"
@@ -123,7 +124,7 @@ class Settings:
LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5"))
LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true"
LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true"
# Sensitive Data Filtering
ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true"
@@ -142,7 +143,6 @@ class Settings:
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
@@ -150,15 +150,15 @@ class Settings:
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
# Memory Cache Regeneration Configuration
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
# Tool Management Configuration
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
@@ -167,7 +167,10 @@ class Settings:
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
# workflow config
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
def get_memory_output_path(self, filename: str = "") -> str:
"""
Get the full path for memory module output files.
@@ -182,7 +185,7 @@ class Settings:
if filename:
return str(base_path / filename)
return str(base_path)
def ensure_memory_output_dir(self) -> None:
"""
Ensure the memory output directory exists.

View File

@@ -425,15 +425,9 @@ async def Input_Summary(
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
search_service = get_context_resource(ctx, "search_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages) or ""
sessionid = sessionid.replace('call_id_', '')
@@ -539,31 +533,11 @@ async def Input_Summary(
)
retrieve_info, question, raw_results = "", query, []
# Return retrieved information directly without LLM processing
# Use the raw retrieved info as the answer
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
# Render template
system_prompt = await template_service.render_template(
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='input_summary',
query=query,
history=history,
retrieve_info=retrieve_info
)
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=RetrieveSummaryResponse
)
aimessages = structured.data.query_answer or "信息不足,无法回答"
except Exception as e:
logger.error(
f"Input_Summary: response_structured failed, using default answer: {e}",
exc_info=True
)
aimessages = "信息不足,无法回答"
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
# Emit intermediate output for frontend
return {

View File

@@ -10,9 +10,6 @@ from app.core.logging_config import get_business_logger
logger = get_business_logger()
# 为了兼容性,创建别名
# SchemaParser = OpenAPISchemaParser = None
class OpenAPISchemaParser:
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
@@ -213,7 +210,9 @@ class OpenAPISchemaParser:
if not isinstance(operation, dict):
continue
summary = operation.get("summary", "")
# 生成操作ID
operation_id = operation.get("operationId")
if not operation_id:
@@ -223,7 +222,7 @@ class OpenAPISchemaParser:
operations[operation_id] = {
"method": method.upper(),
"path": path,
"summary": operation.get("summary", ""),
"summary": summary if summary else operation_id,
"description": operation.get("description", ""),
"parameters": self._extract_parameters(operation),
"request_body": self._extract_request_body(operation),

View File

@@ -232,7 +232,7 @@ class LangchainAdapter:
# 添加验证约束
if param.enum:
# 枚举值约束
field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$"
field_kwargs["pattern"] = f"^({'|'.join(map(str, param.enum))})$"
if param.minimum is not None:
field_kwargs["ge"] = param.minimum
@@ -241,7 +241,7 @@ class LangchainAdapter:
field_kwargs["le"] = param.maximum
if param.pattern:
field_kwargs["regex"] = param.pattern
field_kwargs["pattern"] = param.pattern
fields[param.name] = Field(**field_kwargs)
annotations[param.name] = python_type

View File

@@ -27,20 +27,22 @@ class SimpleMCPClient:
# 确定连接类型
self.is_websocket = server_url.startswith(("ws://", "wss://"))
self.is_sse = "/sse" in server_url.lower()
# 连接状态
self._websocket = None
self._session = None
self._request_id = 0
self._pending_requests = {}
self._server_capabilities = {}
self._endpoint_url = None # SSE endpoint URL
self._sse_task = None
async def __aenter__(self):
"""异步上下文管理器入口"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.disconnect()
async def connect(self):
@@ -57,47 +59,157 @@ class SimpleMCPClient:
async def disconnect(self):
"""断开连接"""
try:
if self._sse_task:
self._sse_task.cancel()
if self._websocket:
await self._websocket.close()
self._websocket = None
if self._session:
await self._session.close()
self._session = None
except Exception as e:
logger.error(f"断开连接失败: {e}")
async def _connect_websocket(self):
"""WebSocket 连接"""
headers = self._build_headers()
self._websocket = await websockets.connect(
self.server_url,
extra_headers=headers,
timeout=self.timeout
)
# 启动消息处理
asyncio.create_task(self._handle_websocket_messages())
# 发送初始化消息
await self._send_initialize()
async def _connect_http(self):
"""HTTP 连接"""
headers = self._build_headers()
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(headers=headers, timeout=timeout)
self._session = aiohttp.ClientSession(
headers=headers,
timeout=timeout
)
# 对于 ModelScope MCP 服务,需要先发送初始化请求
if "modelscope.net" in self.server_url:
if self.is_sse:
await self._initialize_sse_session()
elif "modelscope.net" in self.server_url:
await self._initialize_modelscope_session()
async def _initialize_sse_session(self):
"""初始化 SSE MCP 会话 - 参考 Dify 实现"""
try:
# 建立 SSE 连接
response = await self._session.get(
self.server_url,
headers={"Accept": "text/event-stream"}
)
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
# 启动 SSE 读取任务
self._sse_task = asyncio.create_task(self._read_sse_stream(response))
# 等待获取 endpoint URL
for _ in range(10):
if self._endpoint_url:
break
await asyncio.sleep(1)
if not self._endpoint_url:
raise MCPConnectionError("未能获取 endpoint URL")
# 发送 initialize 请求到 endpoint
init_request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
}
}
init_response = await self._send_sse_request(init_request)
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
result = init_response.get("result", {})
self._server_capabilities = result.get("capabilities", {})
# 发送 initialized 通知
await self._send_sse_notification({"jsonrpc": "2.0", "method": "notifications/initialized"})
except aiohttp.ClientError as e:
raise MCPConnectionError(f"初始化连接失败: {e}")
async def _read_sse_stream(self, response):
"""读取 SSE 流"""
try:
async for line in response.content:
line = line.decode('utf-8').strip()
if line.startswith('event:'):
continue
if line.startswith('data:'):
data = line[5:].strip() # 去除 'data:' 后的空格
if not data or data == '[DONE]':
continue
try:
# 处理 endpoint 事件(相对路径或绝对路径)
if not self._endpoint_url:
# 如果是相对路径,拼接成完整 URL
if data.startswith('/'):
from urllib.parse import urlparse, urlunparse
parsed = urlparse(self.server_url)
self._endpoint_url = f"{parsed.scheme}://{parsed.netloc}{data}"
else:
self._endpoint_url = data
logger.info(f"获取到 endpoint URL: {self._endpoint_url}")
continue
# 处理 message 事件
message = json.loads(data)
request_id = message.get("id")
if request_id and request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(message)
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"SSE 流读取错误: {e}")
async def _send_sse_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""通过 SSE endpoint 发送请求"""
if not self._endpoint_url:
raise MCPConnectionError("endpoint URL 未初始化")
request_id = request["id"]
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
async with self._session.post(self._endpoint_url, json=request) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
return await asyncio.wait_for(future, timeout=self.timeout)
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise MCPConnectionError("请求超时")
async def _send_sse_notification(self, notification: Dict[str, Any]):
"""发送通知(无需响应)"""
if not self._endpoint_url:
raise MCPConnectionError("endpoint URL 未初始化")
async with self._session.post(self._endpoint_url, json=notification) as response:
if response.status != 200:
logger.warning(f"通知发送失败: {response.status}")
async def _initialize_modelscope_session(self):
"""初始化 ModelScope MCP 会话"""
init_request = {
@@ -107,18 +219,12 @@ class SimpleMCPClient:
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {
"name": "MemoryBear",
"version": "1.0.0"
}
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
}
}
try:
async with self._session.post(
self.server_url,
json=init_request
) as response:
async with self._session.post(self.server_url, json=init_request) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
@@ -127,21 +233,16 @@ class SimpleMCPClient:
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
# 获取 session ID
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
if session_id:
self._session.headers.update({"Mcp-Session-Id": session_id})
# 发送 initialized 通知
initialized_notification = {
"jsonrpc": "2.0",
"method": "notifications/initialized"
}
async with self._session.post(
self.server_url,
json=initialized_notification
) as notif_response:
async with self._session.post(self.server_url, json=initialized_notification):
pass
except aiohttp.ClientError as e:
@@ -149,12 +250,18 @@ class SimpleMCPClient:
def _build_headers(self) -> Dict[str, str]:
"""构建请求头"""
# 基础 headers
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream"
}
# 添加认证头
# 合并 connection_config 中的自定义 headers
custom_headers = self.connection_config.get("headers", {})
if custom_headers:
headers.update(custom_headers)
# 处理认证配置(认证 headers 优先级更高)
auth_config = self.connection_config.get("auth_config", {})
auth_type = self.connection_config.get("auth_type", "none")
@@ -178,7 +285,7 @@ class SimpleMCPClient:
return headers
async def _send_initialize(self):
"""发送初始化消息"""
"""发送初始化消息WebSocket"""
init_message = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
@@ -186,124 +293,90 @@ class SimpleMCPClient:
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {
"name": "MemoryBear",
"version": "1.0.0"
}
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
}
}
await self._websocket.send(json.dumps(init_message))
response = await self._websocket.recv()
response_data = json.loads(response)
# 等待初始化响应
response = await asyncio.wait_for(
self._websocket.recv(),
timeout=self.timeout
)
if "error" in response_data:
raise MCPConnectionError(f"初始化失败: {response_data['error']}")
init_response = json.loads(response)
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
result = response_data.get("result", {})
self._server_capabilities = result.get("capabilities", {})
await self._websocket.send(json.dumps({
"jsonrpc": "2.0",
"method": "notifications/initialized"
}))
async def list_tools(self) -> List[Dict[str, Any]]:
"""获取工具列表"""
request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/list"
}
if self.is_websocket:
await self._websocket.send(json.dumps(request))
response = await self._websocket.recv()
response_data = json.loads(response)
elif self.is_sse:
response_data = await self._send_sse_request(request)
else:
async with self._session.post(self.server_url, json=request) as response:
response_data = await response.json()
if "error" in response_data:
raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}")
result = response_data.get("result", {})
return result.get("tools", [])
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""调用工具"""
request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/call",
"params": {"name": tool_name, "arguments": arguments}
}
if self.is_websocket:
await self._websocket.send(json.dumps(request))
response = await self._websocket.recv()
response_data = json.loads(response)
elif self.is_sse:
response_data = await self._send_sse_request(request)
else:
async with self._session.post(self.server_url, json=request) as response:
response_data = await response.json()
if "error" in response_data:
error = response_data["error"]
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
return response_data.get("result", {})
def _get_request_id(self) -> int:
"""生成请求 ID"""
self._request_id += 1
return self._request_id
async def _handle_websocket_messages(self):
"""处理 WebSocket 消息"""
try:
while self._websocket and not self._websocket.closed:
try:
message = await self._websocket.recv()
data = json.loads(message)
# 处理响应
if "id" in data:
request_id = str(data["id"])
if request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(data)
except ConnectionClosed:
break
except Exception as e:
logger.error(f"处理WebSocket消息失败: {e}")
async for message in self._websocket:
data = json.loads(message)
request_id = data.get("id")
if request_id and request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(data)
except ConnectionClosed:
logger.info("WebSocket 连接已关闭")
except Exception as e:
logger.error(f"WebSocket消息处理异常: {e}")
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""调用工具"""
request_data = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": arguments
}
}
if self.is_websocket:
response = await self._send_websocket_request(request_data)
else:
response = await self._send_http_request(request_data)
if "error" in response:
error = response["error"]
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
return response.get("result", {})
async def list_tools(self) -> List[Dict[str, Any]]:
"""获取工具列表"""
request_data = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/list",
"params": {}
}
if self.is_websocket:
response = await self._send_websocket_request(request_data)
else:
response = await self._send_http_request(request_data)
if "error" in response:
error = response["error"]
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
result = response.get("result", {})
return result.get("tools", [])
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""发送WebSocket请求"""
request_id = str(request_data["id"])
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
await self._websocket.send(json.dumps(request_data))
response = await asyncio.wait_for(future, timeout=self.timeout)
return response
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""发送HTTP请求"""
try:
async with self._session.post(
self.server_url,
json=request_data
) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
return await response.json()
except aiohttp.ClientError as e:
raise MCPConnectionError(f"HTTP请求失败: {e}")
def _get_request_id(self) -> str:
"""获取请求ID"""
self._request_id += 1
return f"req_{self._request_id}_{int(time.time() * 1000)}"
logger.error(f"WebSocket 消息处理错误: {e}")

View File

@@ -74,6 +74,7 @@ class WorkflowExecutor:
初始化的工作流状态
"""
user_message = input_data.get("message") or ""
conversation_messages = input_data.get("conv_messages") or []
# 会话变量处理从配置文件获取变量定义列表转换为字典name -> default value
config_variables_list = self.workflow_config.get("variables") or []
@@ -114,7 +115,7 @@ class WorkflowExecutor:
}
return {
"messages": [('user', user_message)],
"messages": conversation_messages,
"variables": variables,
"node_outputs": {},
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)

View File

@@ -7,13 +7,13 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from operator import add
from typing import Any
from langchain_core.messages import AnyMessage, AIMessage
from langchain_core.messages import AIMessage
from langgraph.config import get_stream_writer
from typing_extensions import TypedDict, Annotated
from app.core.config import settings
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
"""
# List of messages (append mode)
messages: Annotated[list[tuple[str, str]], add]
messages: list[dict[str, str]]
# Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list
@@ -154,7 +154,7 @@ class BaseNode(ABC):
Returns:
超时时间
"""
return 60
return settings.WORKFLOW_NODE_TIMEOUT
# return self.error_handling.get("timeout", 60)
async def run(self, state: WorkflowState) -> dict[str, Any]:
@@ -203,6 +203,7 @@ class BaseNode(ABC):
# 返回包装后的输出和运行时变量
return {
**wrapped_output,
"messages": state["messages"],
"variables": state["variables"],
"runtime_vars": {
self.node_id: runtime_var
@@ -356,6 +357,7 @@ class BaseNode(ABC):
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = {
**final_output,
"messages": state["messages"],
"variables": state["variables"],
"runtime_vars": {
self.node_id: runtime_var

View File

@@ -6,7 +6,6 @@ End 节点实现
import logging
import re
import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import NodeType
@@ -38,7 +37,23 @@ class EndNode(BaseNode):
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
else:
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
])
output = "工作流已完成"
# 统计信息(用于日志)
@@ -166,6 +181,12 @@ class EndNode(BaseNode):
"chunk_index": 1,
"is_suffix": False
})
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
}
])
yield {"__final__": True, "result": output}
return
@@ -176,7 +197,6 @@ class EndNode(BaseNode):
source_node_id = edge.get("source")
# Check if the source node is an LLM node
for node in self.workflow_config.get("nodes", []):
print("="*50)
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
direct_upstream_llm_nodes.append(source_node_id)
@@ -216,12 +236,24 @@ class EndNode(BaseNode):
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
# yield completion marker
yield {"__final__": True, "result": output}
return
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
logger.info(
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
# Collect suffix parts
suffix_parts = []
@@ -258,6 +290,17 @@ class EndNode(BaseNode):
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": full_output
}
])
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
@@ -280,7 +323,8 @@ class EndNode(BaseNode):
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else:
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
# 统计信息
node_outputs = state.get("node_outputs", {})

View File

@@ -11,12 +11,12 @@ class MessageConfig(BaseModel):
"""消息配置"""
role: str = Field(
...,
default='user',
description="消息角色system, user, assistant"
)
content: str = Field(
...,
default="",
description="消息内容,支持模板变量,如:{{ sys.message }}"
)
@@ -30,6 +30,23 @@ class MessageConfig(BaseModel):
return v.lower()
class MemoryWindowSetting(BaseModel):
enable: bool = Field(
default=False,
description="启用记忆"
)
enable_window: bool = Field(
default=False,
description="启用记忆窗口"
)
window_size: int = Field(
default=20,
description="记忆窗口大小"
)
class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置
@@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig):
description="上下文"
)
memory: MemoryWindowSetting = Field(
...,
description="对话上下文窗口"
)
# 简单模式
prompt: str | None = Field(
default=None,

View File

@@ -85,28 +85,31 @@ class LLMNode(BaseNode):
"""
# 1. 处理消息格式(优先使用 messages
messages_config = self.config.get("messages")
messages_config = self.typed_config.messages
if messages_config:
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.get("role", "user").lower()
content_template = msg_config.get("content", "")
role = msg_config.role.lower()
content_template = msg_config.content
content_template = self._render_context(content_template, state)
content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象
if role == "system":
messages.append(SystemMessage(content=content))
messages.append({"role": "system", "content": content})
elif role in ["user", "human"]:
messages.append(HumanMessage(content=content))
messages.append({"role": "user", "content": content})
elif role in ["ai", "assistant"]:
messages.append(AIMessage(content=content))
messages.append({"role": "assistant", "content": content})
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append(HumanMessage(content=content))
messages.append({"role": "user", "content": content})
if self.typed_config.memory.enable:
# if self.typed_config.memory.enable_window:
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
prompt_or_messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
@@ -189,7 +192,7 @@ class LLMNode(BaseNode):
return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"messages": [
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
{"role": msg.get("role"), "content": msg.get("content", "")}
for msg in prompt_or_messages
] if isinstance(prompt_or_messages, list) else None,
"config": {

View File

@@ -3,8 +3,9 @@ from typing import Any
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.db import get_db_read, get_db_context
from app.db import get_db_read
from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task
class MemoryReadNode(BaseNode):
@@ -15,11 +16,8 @@ class MemoryReadNode(BaseNode):
async def execute(self, state: WorkflowState) -> Any:
self.typed_config = MemoryReadNodeConfig(**self.config)
with get_db_read() as db:
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state)
if not workspace_id:
raise RuntimeError("Workspace id is required")
if not end_user_id:
raise RuntimeError("End user id is required")
@@ -41,20 +39,17 @@ class MemoryWriteNode(BaseNode):
self.typed_config = MemoryWriteNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any:
with get_db_context() as db:
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state)
end_user_id = self.get_variable("sys.user_id", state)
if not workspace_id:
raise RuntimeError("Workspace id is required")
if not end_user_id:
raise RuntimeError("End user id is required")
if not end_user_id:
raise RuntimeError("End user id is required")
return await MemoryAgentService().write_memory(
group_id=end_user_id,
message=self._render_template(self.typed_config.message, state),
config_id=str(self.typed_config.config_id),
db=db,
storage_type="neo4j",
user_rag_memory_id=""
)
write_message_task.delay(
end_user_id,
self._render_template(self.typed_config.message, state),
str(self.typed_config.config_id),
"neo4j",
""
)
return "success"

View File

@@ -41,6 +41,7 @@ class ToolConfig(BaseModel):
tool_id: Optional[str] = Field(default=None, description="工具ID")
operation: Optional[str] = Field(default=None, description="工具特定配置")
class ToolOldConfig(BaseModel):
"""工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具")
@@ -348,6 +349,7 @@ class AppChatRequest(BaseModel):
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
stream: bool = Field(default=False, description="是否流式返回")
class DraftRunRequest(BaseModel):
"""试运行请求"""
message: str = Field(..., description="用户消息")

View File

@@ -14,6 +14,7 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.db import get_db, get_db_context
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
from app.schemas import DraftRunRequest
from app.services.tool_service import ToolService
from app.repositories.tool_repository import ToolRepository
from app.db import get_db
@@ -59,7 +60,7 @@ class AppChatService:
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.system_prompt
if variables:
@@ -210,7 +211,7 @@ class AppChatService:
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.system_prompt
if variables:
@@ -511,7 +512,6 @@ class AppChatService:
}
)
except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出
logger.debug("多 Agent 流式聊天被中断")
@@ -537,83 +537,19 @@ class AppChatService:
) -> Dict[str, Any]:
"""聊天(非流式)"""
workflow_service = WorkflowService(self.db)
input_data = {"message":message, "variables": variables,
"conversation_id": str(conversation_id)}
inconfig = workflow_service.get_workflow_config(app_id)
# 2. 创建执行记录
execution = workflow_service.create_execution(
workflow_config_id=inconfig.id,
app_id=app_id,
trigger_type="manual",
triggered_by=None,
conversation_id=conversation_id,
input_data=input_data
payload = DraftRunRequest(
message=message,
variables=variables,
conversation_id=str(conversation_id),
stream=True,
user_id=user_id
)
return await workflow_service.run(
app_id=app_id,
payload=payload,
config=config,
workspace_id=workspace_id,
)
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow
try:
# 更新状态为运行中
workflow_service.update_execution_status(execution.execution_id, "running")
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=user_id
)
# 更新执行结果
if result.get("status") == "completed":
workflow_service.update_execution_status(
execution.execution_id,
"completed",
output_data=result.get("node_outputs", {})
)
else:
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=result.get("error")
)
# 返回增强的响应结构
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"),
"token_usage": result.get("token_usage")
}
except Exception as e:
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
raise BusinessException(
code=BizCode.INTERNAL_ERROR,
message=f"工作流执行失败: {str(e)}"
)
async def workflow_chat_stream(
self,
@@ -632,62 +568,21 @@ class AppChatService:
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
workflow_service = WorkflowService(self.db)
input_data = {"message": message, "variables": variables,
"conversation_id": str(conversation_id)}
inconfig = workflow_service.get_workflow_config(app_id)
# 2. 创建执行记录
execution = workflow_service.create_execution(
workflow_config_id=inconfig.id,
app_id=app_id,
trigger_type="manual",
triggered_by=None,
conversation_id=conversation_id,
input_data=input_data
payload = DraftRunRequest(
message=message,
variables=variables,
conversation_id=str(conversation_id),
stream=True,
user_id=user_id
)
async for event in workflow_service.run_stream(
app_id=app_id,
payload=payload,
config=config,
workspace_id=workspace_id,
):
yield event
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
# 5. 流式执行工作流
try:
# 更新状态为运行中
workflow_service.update_execution_status(execution.execution_id, "running")
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件)
async for event in workflow_service._run_workflow_stream(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=user_id
):
# 直接转发 executor 的事件(已经是正确的格式)
yield event
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
# 发送错误事件
yield {
"event": "error",
"data": {
"execution_id": execution.execution_id,
"error": str(e)
}
}
# ==================== 依赖注入函数 ====================

View File

@@ -13,12 +13,14 @@ from typing import Any, Dict, List, Optional, Tuple
from app.core.logging_config import get_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.conversation_repository import ConversationRepository
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import type_mapping, EmotionType, EmotionSubject
from app.services.implicit_memory_service import ImplicitMemoryService
from app.services.memory_base_service import MemoryBaseService
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
@@ -1198,18 +1200,17 @@ async def analytics_memory_types(
end_user_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
统计9种记忆类型的数量和百分比
统计8种记忆类型的数量和百分比
计算规则:
1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity
2. 工作记忆 (WORKING_MEMORY) = chunk + entity
3. 短期记忆 (SHORT_TERM_MEMORY) = chunk
4. 长期记忆 (LONG_TERM_MEMORY) = entity
5. 性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity
7. 情记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
9. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
1. 感知记忆 (PERCEPTUAL_MEMORY) = 通过 MemoryPerceptualService.get_memory_count 获取的 total_count
2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取)
3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量
4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
5. 性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一
6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
7. 情记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
Args:
db: 数据库会话
@@ -1229,7 +1230,6 @@ async def analytics_memory_types(
- PERCEPTUAL_MEMORY: 感知记忆
- WORKING_MEMORY: 工作记忆
- SHORT_TERM_MEMORY: 短期记忆
- LONG_TERM_MEMORY: 长期记忆
- EXPLICIT_MEMORY: 显性记忆
- IMPLICIT_MEMORY: 隐性记忆
- EMOTIONAL_MEMORY: 情绪记忆
@@ -1239,40 +1239,78 @@ async def analytics_memory_types(
# 初始化基础服务
base_service = MemoryBaseService()
# 定义需要查询的基础节点类型
node_types = {
"Statement": "Statement",
"Entity": "ExtractedEntity",
"Chunk": "Chunk"
}
# 初始化感知记忆服务
perceptual_service = MemoryPerceptualService(db)
# 存储每种节点类型的计数
node_counts = {}
# 获取感知记忆数量
if end_user_id:
perceptual_stats = perceptual_service.get_memory_count(uuid.UUID(end_user_id))
perceptual_count = perceptual_stats.get("total", 0)
else:
perceptual_count = 0
# 查询每种节点类型的数量
for key, node_type in node_types.items():
if end_user_id:
query = f"""
MATCH (n:{node_type})
# 获取工作记忆数量(基于会话数量
work_count = 0
if end_user_id:
try:
conversation_repo = ConversationRepository(db)
conversations = conversation_repo.get_conversation_by_user_id(
user_id=uuid.UUID(end_user_id),
limit=100, # 获取更多会话以准确统计
is_activate=True
)
work_count = len(conversations)
logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取会话数量失败工作记忆数量设为0: {str(e)}")
work_count = 0
# 获取隐性记忆数量(基于 Statement 节点数量的三分之一)
implicit_count = 0
if end_user_id:
try:
# 查询 Statement 节点数量
query = """
MATCH (n:Statement)
WHERE n.group_id = $group_id
RETURN count(n) as count
"""
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
else:
query = f"""
MATCH (n:{node_type})
RETURN count(n) as count
"""
result = await _neo4j_connector.execute_query(query)
# 提取计数结果
count = result[0]["count"] if result and len(result) > 0 else 0
node_counts[key] = count
statement_count = result[0]["count"] if result and len(result) > 0 else 0
# 取三分之一作为隐性记忆数量
implicit_count = round(statement_count / 3)
logger.debug(f"隐性记忆数量Statement数量的1/3: {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取Statement数量失败隐性记忆数量设为0: {str(e)}")
implicit_count = 0
# 获取各节点类型的数量
statement_count = node_counts.get("Statement", 0)
entity_count = node_counts.get("Entity", 0)
chunk_count = node_counts.get("Chunk", 0)
# 原有的基于行为习惯的统计方式(已注释)
# implicit_count = 0
# if end_user_id:
# try:
# implicit_service = ImplicitMemoryService(db, end_user_id)
# behavior_habits = await implicit_service.get_behavior_habits(
# user_id=end_user_id
# )
# implicit_count = len(behavior_habits)
# logger.debug(f"隐性记忆数量(行为习惯数): {implicit_count} (end_user_id={end_user_id})")
# except Exception as e:
# logger.warning(f"获取行为习惯数量失败隐性记忆数量设为0: {str(e)}")
# implicit_count = 0
# 获取短期记忆数量(基于 /short_term 接口返回的问答对数量)
short_term_count = 0
if end_user_id:
try:
short_term_service = ShortService(end_user_id)
short_term_data = short_term_service.get_short_databasets()
# 统计 short_term 数组的长度
if short_term_data:
short_term_count = len(short_term_data)
logger.debug(f"短期记忆数量(问答对数): {short_term_count} (end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取短期记忆数量失败短期记忆数量设为0: {str(e)}")
short_term_count = 0
# 获取用户的遗忘阈值配置
forgetting_threshold = 0.3 # 默认值
@@ -1298,17 +1336,16 @@ async def analytics_memory_types(
# 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量
episodic_count = await base_service.get_episodic_memory_count(end_user_id)
explicit_count = await base_service.get_explicit_memory_count(end_user_id)
emotion_count = await base_service.get_emotional_memory_count(end_user_id, statement_count)
emotion_count = await base_service.get_emotional_memory_count(end_user_id, perceptual_count)
forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold)
# 按规则计算9种记忆类型的数量使用英文枚举作为key
# 按规则计算8种记忆类型的数量使用英文枚举作为key
memory_counts = {
"PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆
"WORKING_MEMORY": chunk_count + entity_count, # 工作记忆
"SHORT_TERM_MEMORY": chunk_count, # 短期记忆
"LONG_TERM_MEMORY": entity_count, # 长期记忆
"PERCEPTUAL_MEMORY": perceptual_count, # 感知记忆
"WORKING_MEMORY": work_count, # 工作记忆(基于会话数量)
"SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量)
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
"IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity)
"IMPLICIT_MEMORY": implicit_count, # 隐性记忆Statement数量的1/3
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
"EPISODIC_MEMORY": episodic_count, # 情景记忆
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)

View File

@@ -2,12 +2,11 @@
工作流服务层
"""
import datetime
import json
import logging
import uuid
import datetime
from typing import Any, Annotated, AsyncGenerator
from deprecated import deprecated
from fastapi import Depends
from sqlalchemy.orm import Session
@@ -16,15 +15,16 @@ from app.core.exceptions import BusinessException
from app.core.workflow.validator import validate_workflow_config
from app.db import get_db, get_db_context
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories.conversation_repository import MessageRepository
from app.models.conversation_model import Message
from app.repositories.end_user_repository import EndUserRepository
from app.services.multi_agent_service import convert_uuids_to_str
from app.repositories.workflow_repository import (
WorkflowConfigRepository,
WorkflowExecutionRepository,
WorkflowNodeExecutionRepository
)
from app.schemas import DraftRunRequest
from app.utils.sse_utils import format_sse_message
from app.services.multi_agent_service import convert_uuids_to_str
logger = logging.getLogger(__name__)
@@ -37,6 +37,7 @@ class WorkflowService:
self.config_repo = WorkflowConfigRepository(db)
self.execution_repo = WorkflowExecutionRepository(db)
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
self.message_repo = MessageRepository(db)
# ==================== 配置管理 ====================
@@ -418,14 +419,13 @@ class WorkflowService:
"""运行工作流
Args:
workspace_id:
config:
payload:
app_id: 应用 ID
input_data: 输入数据(包含 message 和 variables
triggered_by: 触发用户 ID
conversation_id: 会话 ID可选
stream: 是否流式返回
Returns:
执行结果(非流式)或生成器(流式)
执行结果(非流式)
Raises:
BusinessException: 配置不存在或执行失败时抛出
@@ -438,7 +438,8 @@ class WorkflowService:
code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id}
# 转换 user_id 为 UUID
triggered_by_uuid = None
@@ -461,7 +462,7 @@ class WorkflowService:
workflow_config_id=config.id,
app_id=app_id,
trigger_type="manual",
triggered_by=triggered_by_uuid,
triggered_by=None,
conversation_id=conversation_id_uuid,
input_data=input_data
)
@@ -500,8 +501,11 @@ class WorkflowService:
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
init_message_length = len(input_data.get("conv_messages", []))
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
@@ -517,6 +521,17 @@ class WorkflowService:
"completed",
output_data=result
)
final_messages = result.get("messages", [])[init_message_length:]
for message in final_messages:
message_obj = Message(
conversation_id=conversation_id_uuid,
role=message["role"],
content=message["content"],
)
self.message_repo.add_message(message_obj)
self.db.commit()
logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
else:
self.update_execution_status(
execution.execution_id,
@@ -529,6 +544,7 @@ class WorkflowService:
"execution_id": execution.execution_id,
"status": result.get("status"),
"variables": result.get("variables"),
"messages": result.get("messages"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
@@ -559,6 +575,7 @@ class WorkflowService:
"""运行工作流(流式)
Args:
workspace_id:
app_id: 应用 ID
payload: 请求对象(包含 message, variables, conversation_id 等)
config: 存储类型(可选)
@@ -601,7 +618,7 @@ class WorkflowService:
workflow_config_id=config.id,
app_id=app_id,
trigger_type="manual",
triggered_by=triggered_by_uuid,
triggered_by=None,
conversation_id=conversation_id_uuid,
input_data=input_data
)
@@ -638,17 +655,46 @@ class WorkflowService:
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
init_message_length = len(input_data.get("conv_messages", []))
from app.core.workflow.executor import execute_workflow_stream
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件)
async for event in self._run_workflow_stream(
async for event in execute_workflow_stream(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=end_user_id
):
# 直接转发 executor 的事件(已经是正确的格式)
if event.get("event") == "workflow_end":
status = event.get("data", {}).get("status")
if status == "completed":
self.update_execution_status(
execution.execution_id,
"completed",
output_data=event.get("data")
)
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
for message in final_messages:
message_obj = Message(
conversation_id=conversation_id_uuid,
role=message["role"],
content=message["content"],
)
self.message_repo.add_message(message_obj)
self.db.commit()
logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
elif status == "failed":
self.update_execution_status(
execution.execution_id,
"failed",
output_data=event.get("data")
)
else:
logger.error(f"unexpect workflow run status, status: {status}")
yield event
except Exception as e:
@@ -667,6 +713,8 @@ class WorkflowService:
}
}
@deprecated(reason="This method is deprecated. "
"Please use WorkflowService.run / run_stream instead.")
async def run_workflow(
self,
app_id: uuid.UUID,
@@ -819,6 +867,7 @@ class WorkflowService:
return clean_value(event)
@deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.")
async def _run_workflow_stream(
self,
workflow_config: dict[str, Any],

View File

@@ -136,7 +136,8 @@ dependencies = [
"markdown-to-json==2.1.1",
"valkey==6.0.2",
"python-calamine>=0.4.0",
"xlrd==2.0.2"
"xlrd==2.0.2",
"deprecated>=1.3.1",
]
[tool.pytest.ini_options]

View File

@@ -55,7 +55,7 @@ const ChatContent: FC<ChatContentProps> = ({
</div>
}
{/* 消息气泡框 */}
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-100 rb:wrap-break-word', contentClassNames, {
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-[520px] rb:wrap-break-word', contentClassNames, {
// 错误消息样式内容为null且非助手消息
'rb:border-[rgba(255,93,52,0.30)] rb:bg-[rgba(255,93,52,0.08)] rb:text-[#FF5D34]': errorDesc && item.role === 'assistant' && item.content === null,
// 助手消息样式
@@ -68,7 +68,7 @@ const ChatContent: FC<ChatContentProps> = ({
</div>
{/* 底部标签(如时间戳、用户名等) */}
{labelPosition === 'bottom' &&
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4 rb:font-regular">
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4 rb:font-regular rb:mt-2">
{labelFormat(item)}
</div>
}

View File

@@ -1265,6 +1265,7 @@ export const en = {
emotionLine: 'Emotion Changes Over Time',
interaction: 'Interaction Frequency & Relationship Stages',
timelines_memory: 'All',
Chunk: 'Chunk',
MemorySummary: 'Long-term Accumulation',
Statement: 'Emotional Memory',
ExtractedEntity: 'Episodic Memory',
@@ -1786,6 +1787,9 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
temperature: 'Temperature',
max_tokens: 'Max Tokens',
context: 'Context',
memory: 'Memory',
enable_window: 'Memory Window',
inner: 'Built-in',
},
start: {
variables: 'Input Fields',

View File

@@ -1343,6 +1343,7 @@ export const zh = {
emotionLine: '情绪随时间变化',
interaction: '互动频率 & 关系阶段',
timelines_memory: '全部',
Chunk: '工作记忆',
MemorySummary: '长期沉淀',
Statement: '情绪记忆',
ExtractedEntity: '情景记忆',
@@ -1883,6 +1884,9 @@ export const zh = {
temperature: '温度',
max_tokens: '最大令牌数',
context: '上下文',
memory: '记忆',
enable_window: '记忆窗口',
inner: '内置',
},
start: {
variables: '输入字段',

View File

@@ -176,6 +176,9 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
if (response?.knowledge_retrieval?.knowledge_bases?.length) {
getDefaultKnowledgeList(response)
}
if (response?.tools?.length) {
setToolList(response?.tools)
}
}).finally(() => {
setLoading(false)
})

View File

@@ -79,8 +79,6 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi
}
}, [data])
console.log('toolList', toolList)
const handleAddTool = () => {
toolModalRef.current?.handleOpen()
}

View File

@@ -259,9 +259,10 @@ const Conversation: FC = () => {
</div>
<div className="rb:relative rb:h-screen rb:px-4 rb:flex-[1_1_auto]">
<div className='rb:w-[760px] rb:h-screen rb:mx-auto rb:pt-10'>
<Chat
empty={<Empty url={AnalysisEmptyIcon} className="rb:h-full" subTitle={t('memoryConversation.emptyDesc')} />}
contentClassName="rb:h-[calc(100%-152px)]"
empty={<Empty url={BgImg} className="rb:h-full" size={[320,180]} subTitle={t('memoryConversation.emptyDesc')} />}
contentClassName="rb:h-[calc(100%-152px)] "
data={chatList}
streamLoading={streamLoading}
loading={loading}
@@ -290,6 +291,7 @@ const Conversation: FC = () => {
</Flex>
</Form>
</Chat>
</div>
</div>
</Flex>
)

View File

@@ -1,6 +1,5 @@
import React, { useState, useImperativeHandle, forwardRef, useRef } from 'react';
import { Button, Input, Space, Typography, Tooltip, message, List } from 'antd';
import { PlusOutlined, EditOutlined, DeleteOutlined } from '@ant-design/icons';
import { useState, useImperativeHandle, forwardRef, useRef } from 'react';
import { Button, Space, List } from 'antd';
import { useTranslation } from 'react-i18next';
import type { ChatVariable, AddChatVariableRef } from '../../types';
import type { ChatVariableModalRef } from './types'

View File

@@ -131,7 +131,7 @@ const EditableTable: React.FC<EditableTableProps> = ({
const AddButton = ({ block = false }: { block?: boolean }) => (
<Button
type={block ? "dashed" : "text"}
icon={<PlusOutlined />}
icon={block ? undefined : <PlusOutlined />}
onClick={() => add(createNewRow())}
size="small"
block={block}

View File

@@ -0,0 +1,69 @@
import { type FC } from "react";
import { useTranslation } from 'react-i18next'
import { Form, Row, Col, Divider, Switch, Slider } from 'antd'
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
import MessageEditor from '../MessageEditor'
const MemoryConfig: FC<{ options: Suggestion[]; parentName: string; }> = ({
options,
parentName
}) => {
const { t } = useTranslation()
const form = Form.useFormInstance();
const values = Form.useWatch([], form) || {}
console.log('MemoryConfig', values)
const handleChangeEnable = (value: boolean) => {
if (value) {
form.setFieldsValue({
memory: {
...form.getFieldValue(parentName),
enable_window: false,
window_size: 20,
messages: "{{sys.message}}"
}
})
}
}
return (
<>
{values?.memory?.enable && <>
<div className="rb:flex rb:items-center rb:justify-between rb:py-1.5 rb:px-2 rb:bg-[#F6F8FC] rb:rounded-md rb:mb-2">
{t('workflow.config.llm.memory')}
<span>{t('workflow.config.llm.inner')}</span>
</div>
<Form.Item layout="horizontal" name={[parentName, 'messages']}>
<MessageEditor
title="USER"
isArray={false}
parentName={[parentName, 'messages']}
options={options}
/>
</Form.Item>
<Divider />
</>}
<Form.Item layout="horizontal" name={[parentName, 'enable']} label={t('workflow.config.llm.memory')}>
<Switch onChange={handleChangeEnable} />
</Form.Item>
{values?.memory?.enable && <>
<Row className="rb:mb-3">
<Col span={10}>
<Form.Item layout="horizontal" name={[parentName, 'enable_window']} noStyle>
<Switch />
</Form.Item>
<span className="rb:ml-2">{t('workflow.config.llm.enable_window')}</span>
</Col>
<Col span={14}>
<Form.Item layout="horizontal" name={[parentName, 'window_size']} noStyle>
<Slider min={1} max={100} step={1} className="rb:my-0!" disabled={!values?.memory?.enable_window} />
</Form.Item>
</Col>
</Row>
</>}
</>
);
};
export default MemoryConfig;

View File

@@ -127,7 +127,7 @@ const MessageEditor: FC<MessageEditor> = ({
</Space>
);
})}
<Form.Item>
<Form.Item noStyle>
<Button type="dashed" onClick={() => handleAdd(add)} block>
+{t('workflow.addMessage')}
</Button>

View File

@@ -22,6 +22,7 @@ import ConditionList from './ConditionList'
import CycleVarsList from './CycleVarsList'
import AssignmentList from './AssignmentList'
import ToolConfig from './ToolConfig'
import MemoryConfig from './MemoryConfig'
// import { calculateVariableList } from './utils/variableListCalculator'
interface PropertiesProps {
@@ -1230,6 +1231,20 @@ const Properties: FC<PropertiesProps> = ({
</Form.Item>
)
}
if (config.type === 'memoryConfig') {
return (
<Form.Item
key={key}
name={key}
noStyle
>
<MemoryConfig
parentName={key}
options={getFilteredVariableList('llm')}
/>
</Form.Item>
)
}
return (
<Form.Item

View File

@@ -135,6 +135,14 @@ export const nodeLibrary: NodeLibrary[] = [
readonly: true
},
]
},
memory: {
type: 'memoryConfig',
defaultValue: {
enable: false,
enable_window: false,
window_size: 20
}
}
}
},
@@ -750,10 +758,6 @@ export const outputVariable: { [key: string]: OutputVariable } = {
{ name: "body", type: "string" },
{ name: "status_code", type: "number" },
],
error: [
{ name: "error_message", type: "string" },
{ name: "error_type", type: "string" },
]
},
'tool': {
default: [

View File

@@ -6,7 +6,7 @@ import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from '
import { register } from '@antv/x6-react-shape';
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs } from '../constant';
import type { WorkflowConfig, NodeProperties } from '../types';
import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types';
import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application'
import type { PortMetadata } from '@antv/x6/lib/model/port';
@@ -35,6 +35,8 @@ export interface UseWorkflowGraphReturn {
copyEvent: () => boolean | void;
parseEvent: () => boolean | void;
handleSave: (flag?: boolean) => Promise<unknown>;
chatVariables: ChatVariable[];
setChatVariables: React.Dispatch<React.SetStateAction<ChatVariable[]>>;
}
export const edge_color = '#155EEF';
@@ -54,6 +56,7 @@ export const useWorkflowGraph = ({
const [canRedo, setCanRedo] = useState(false);
const [isHandMode, setIsHandMode] = useState(false);
const [config, setConfig] = useState<WorkflowConfig | null>(null);
const [chatVariables, setChatVariables] = useState<ChatVariable[]>([])
useEffect(() => {
getConfig()
@@ -63,16 +66,15 @@ export const useWorkflowGraph = ({
getWorkflowConfig(id)
.then(res => {
const { variables, ...rest } = res as WorkflowConfig
setConfig({
...rest,
variables: variables.map(v => {
const { default: _, ...cleanV } = v
return {
...cleanV,
defaultValue: v.default ?? ''
}
})
const initChatVariables = variables.map(v => {
const { default: _, ...cleanV } = v
return {
...cleanV,
defaultValue: v.default ?? ''
}
})
setChatVariables(initChatVariables)
setConfig({ ...rest, variables: initChatVariables })
})
}
@@ -94,7 +96,17 @@ export const useWorkflowGraph = ({
if (nodeLibraryConfig?.config) {
Object.keys(nodeLibraryConfig.config).forEach(key => {
if (key === 'knowledge_retrieval' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
if (key === 'memory' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
const { memory, messages } = config as any;
if (memory?.enable && messages && messages.length > 0) {
const lastMessage = messages[messages.length - 1]
nodeLibraryConfig.config[key].defaultValue = {
...memory,
messages: lastMessage.content
}
nodeLibraryConfig.config.messages.defaultValue.splice(-1, 1)
}
} else if (key === 'knowledge_retrieval' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
const { query, ...rest } = config
nodeLibraryConfig.config[key].defaultValue = {
...rest
@@ -917,13 +929,13 @@ export const useWorkflowGraph = ({
const params = {
...config,
variables: config.variables.map(v => {
const { defaultValue, ...cleanV } = v
return {
...cleanV,
default: defaultValue ?? ''
}
}),
variables: chatVariables.map(v => {
const { defaultValue, ...cleanV } = v
return {
...cleanV,
default: defaultValue ?? ''
}
}),
nodes: nodes.map((node: Node) => {
const data = node.getData();
const position = node.getPosition();
@@ -931,7 +943,15 @@ export const useWorkflowGraph = ({
if (data.config) {
Object.keys(data.config).forEach(key => {
if (data.config[key] && 'defaultValue' in data.config[key] && key === 'group_variables') {
if (key === 'memory' && data.config[key] && 'defaultValue' in data.config[key]) {
const { messages, ...rest } = data.config[key].defaultValue
let memoryMessage = { role: 'USER', content: data.config[key].defaultValue.messages }
itemConfig = {
...itemConfig,
messages: rest.enable ? [...itemConfig.messages, memoryMessage] : itemConfig.messages,
memory: { ...rest },
}
} else if (data.config[key] && 'defaultValue' in data.config[key] && key === 'group_variables') {
let group_variables = data.config.group.defaultValue ? {} : data.config[key].defaultValue
if (data.config.group.defaultValue) {
data.config[key].defaultValue.map((vo: any) => {
@@ -1077,5 +1097,7 @@ export const useWorkflowGraph = ({
copyEvent,
parseEvent,
handleSave,
chatVariables,
setChatVariables
};
};

View File

@@ -8,7 +8,7 @@ import PortClickHandler from './components/PortClickHandler';
import { useWorkflowGraph } from './hooks/useWorkflowGraph';
import type { WorkflowRef } from '@/views/ApplicationConfig/types'
import Chat from './components/Chat/Chat';
import type { ChatRef, AddChatVariableRef, ChatVariable } from './types'
import type { ChatRef, AddChatVariableRef } from './types'
import arrowIcon from '@/assets/images/workflow/arrow.png'
import AddChatVariable from './components/AddChatVariable';
@@ -21,7 +21,6 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
// 使用自定义Hook初始化工作流图
const {
config,
setConfig,
graphRef,
selectedNode,
setSelectedNode,
@@ -38,6 +37,8 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
copyEvent,
parseEvent,
handleSave,
chatVariables,
setChatVariables
} = useWorkflowGraph({ containerRef, miniMapRef });
const onDragOver = (event: React.DragEvent) => {
@@ -52,15 +53,6 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
const addVariable = () => {
addChatVariableRef.current?.handleOpen()
}
const handleUpdateChatVariable = (variables: ChatVariable[]) => {
setConfig(prev => {
if (!prev) return null
return {
...prev,
variables
}
})
}
useImperativeHandle(ref, () => ({
handleSave,
@@ -125,8 +117,8 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
<AddChatVariable
ref={addChatVariableRef}
variables={config?.variables}
onChange={handleUpdateChatVariable}
variables={chatVariables}
onChange={setChatVariables}
/>
</div>
);