Merge branch 'refs/heads/develop' into fix/memory_bug_fix
# Conflicts: # api/app/services/user_memory_service.py
This commit is contained in:
@@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"])
|
||||
@router.post("/{app_id}/workflow")
|
||||
@cur_workspace_access_guard()
|
||||
async def create_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
config: WorkflowConfigCreate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
config: WorkflowConfigCreate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""创建工作流配置
|
||||
|
||||
@@ -96,6 +96,7 @@ async def create_workflow_config(
|
||||
msg=f"创建工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# @router.get("/{app_id}/workflow")
|
||||
# async def get_workflow_config(
|
||||
@@ -199,10 +200,10 @@ async def create_workflow_config(
|
||||
|
||||
@router.delete("/{app_id}/workflow")
|
||||
async def delete_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""删除工作流配置
|
||||
|
||||
@@ -243,11 +244,11 @@ async def delete_workflow_config(
|
||||
|
||||
@router.post("/{app_id}/workflow/validate")
|
||||
async def validate_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||
):
|
||||
"""验证工作流配置
|
||||
|
||||
@@ -312,12 +313,12 @@ async def validate_workflow_config(
|
||||
|
||||
@router.get("/{app_id}/workflow/executions")
|
||||
async def get_workflow_executions(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0
|
||||
):
|
||||
"""获取工作流执行记录列表
|
||||
|
||||
@@ -365,10 +366,10 @@ async def get_workflow_executions(
|
||||
|
||||
@router.get("/workflow/executions/{execution_id}")
|
||||
async def get_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""获取工作流执行详情
|
||||
|
||||
@@ -417,16 +418,14 @@ async def get_workflow_execution(
|
||||
)
|
||||
|
||||
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
@router.post("/{app_id}/workflow/run")
|
||||
async def run_workflow(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
request: WorkflowExecutionRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
request: WorkflowExecutionRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""执行工作流
|
||||
|
||||
@@ -487,22 +486,22 @@ async def run_workflow(
|
||||
"""
|
||||
try:
|
||||
async for event in await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
):
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
# event: <type>
|
||||
# data: <json>
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||
# 发送错误事件
|
||||
@@ -554,10 +553,10 @@ async def run_workflow(
|
||||
|
||||
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||
async def cancel_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""取消工作流执行
|
||||
|
||||
@@ -602,7 +601,7 @@ async def cancel_workflow_execution(
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
return fail(code=e.code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
|
||||
@@ -7,17 +7,18 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Settings:
|
||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||
# API Keys Configuration
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
|
||||
|
||||
|
||||
# Neo4j Configuration (记忆系统数据库)
|
||||
NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687")
|
||||
NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
|
||||
NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "")
|
||||
|
||||
|
||||
# Database configuration (Postgres)
|
||||
DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1")
|
||||
DB_PORT: int = int(os.getenv("DB_PORT", "5432"))
|
||||
@@ -37,7 +38,7 @@ class Settings:
|
||||
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
|
||||
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
||||
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
||||
|
||||
|
||||
# ElasticSearch configuration
|
||||
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
||||
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
|
||||
@@ -48,7 +49,7 @@ class Settings:
|
||||
ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000"))
|
||||
ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true"
|
||||
ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10"))
|
||||
|
||||
|
||||
# Xinference configuration
|
||||
XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1")
|
||||
|
||||
@@ -57,17 +58,17 @@ class Settings:
|
||||
LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true"
|
||||
LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "")
|
||||
LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "")
|
||||
|
||||
|
||||
# LLM Request Configuration
|
||||
LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0"))
|
||||
LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2"))
|
||||
|
||||
|
||||
# JWT Token Configuration
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random")
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
|
||||
|
||||
|
||||
# Single Sign-On configuration
|
||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||
|
||||
@@ -86,19 +87,19 @@ class Settings:
|
||||
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
|
||||
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
|
||||
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
|
||||
|
||||
|
||||
# Server Configuration
|
||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||
|
||||
# ========================================================================
|
||||
# Internal Configuration (not in .env, used by application code)
|
||||
# ========================================================================
|
||||
|
||||
|
||||
# Superuser settings (internal defaults)
|
||||
FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com")
|
||||
FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin")
|
||||
FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password")
|
||||
|
||||
|
||||
# Generic File Upload (internal)
|
||||
GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads")
|
||||
ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true"
|
||||
@@ -123,7 +124,7 @@ class Settings:
|
||||
LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5"))
|
||||
LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true"
|
||||
LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true"
|
||||
|
||||
|
||||
# Sensitive Data Filtering
|
||||
ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true"
|
||||
|
||||
@@ -142,7 +143,6 @@ class Settings:
|
||||
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
|
||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||
|
||||
|
||||
# Celery configuration (internal)
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
@@ -150,15 +150,15 @@ class Settings:
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
|
||||
|
||||
# Tool Management Configuration
|
||||
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
||||
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
|
||||
@@ -167,7 +167,10 @@ class Settings:
|
||||
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
||||
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module output files.
|
||||
@@ -182,7 +185,7 @@ class Settings:
|
||||
if filename:
|
||||
return str(base_path / filename)
|
||||
return str(base_path)
|
||||
|
||||
|
||||
def ensure_memory_output_dir(self) -> None:
|
||||
"""
|
||||
Ensure the memory output directory exists.
|
||||
|
||||
@@ -425,15 +425,9 @@ async def Input_Summary(
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
search_service = get_context_resource(ctx, "search_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
sessionid = sessionid.replace('call_id_', '')
|
||||
@@ -539,31 +533,11 @@ async def Input_Summary(
|
||||
)
|
||||
retrieve_info, question, raw_results = "", query, []
|
||||
|
||||
# Return retrieved information directly without LLM processing
|
||||
# Use the raw retrieved info as the answer
|
||||
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
|
||||
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Retrieve_Summary_prompt.jinja2',
|
||||
operation_name='input_summary',
|
||||
query=query,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=RetrieveSummaryResponse
|
||||
)
|
||||
aimessages = structured.data.query_answer or "信息不足,无法回答"
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary: response_structured failed, using default answer: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
|
||||
@@ -10,9 +10,6 @@ from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
# 为了兼容性,创建别名
|
||||
# SchemaParser = OpenAPISchemaParser = None
|
||||
|
||||
|
||||
class OpenAPISchemaParser:
|
||||
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
|
||||
@@ -213,7 +210,9 @@ class OpenAPISchemaParser:
|
||||
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
|
||||
summary = operation.get("summary", "")
|
||||
|
||||
# 生成操作ID
|
||||
operation_id = operation.get("operationId")
|
||||
if not operation_id:
|
||||
@@ -223,7 +222,7 @@ class OpenAPISchemaParser:
|
||||
operations[operation_id] = {
|
||||
"method": method.upper(),
|
||||
"path": path,
|
||||
"summary": operation.get("summary", ""),
|
||||
"summary": summary if summary else operation_id,
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": self._extract_parameters(operation),
|
||||
"request_body": self._extract_request_body(operation),
|
||||
|
||||
@@ -232,7 +232,7 @@ class LangchainAdapter:
|
||||
# 添加验证约束
|
||||
if param.enum:
|
||||
# 枚举值约束
|
||||
field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$"
|
||||
field_kwargs["pattern"] = f"^({'|'.join(map(str, param.enum))})$"
|
||||
|
||||
if param.minimum is not None:
|
||||
field_kwargs["ge"] = param.minimum
|
||||
@@ -241,7 +241,7 @@ class LangchainAdapter:
|
||||
field_kwargs["le"] = param.maximum
|
||||
|
||||
if param.pattern:
|
||||
field_kwargs["regex"] = param.pattern
|
||||
field_kwargs["pattern"] = param.pattern
|
||||
|
||||
fields[param.name] = Field(**field_kwargs)
|
||||
annotations[param.name] = python_type
|
||||
|
||||
@@ -27,20 +27,22 @@ class SimpleMCPClient:
|
||||
|
||||
# 确定连接类型
|
||||
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
||||
self.is_sse = "/sse" in server_url.lower()
|
||||
|
||||
# 连接状态
|
||||
self._websocket = None
|
||||
self._session = None
|
||||
self._request_id = 0
|
||||
self._pending_requests = {}
|
||||
self._server_capabilities = {}
|
||||
self._endpoint_url = None # SSE endpoint URL
|
||||
self._sse_task = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.disconnect()
|
||||
|
||||
async def connect(self):
|
||||
@@ -57,47 +59,157 @@ class SimpleMCPClient:
|
||||
async def disconnect(self):
|
||||
"""断开连接"""
|
||||
try:
|
||||
if self._sse_task:
|
||||
self._sse_task.cancel()
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开连接失败: {e}")
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""WebSocket 连接"""
|
||||
headers = self._build_headers()
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
extra_headers=headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
# 启动消息处理
|
||||
asyncio.create_task(self._handle_websocket_messages())
|
||||
|
||||
# 发送初始化消息
|
||||
await self._send_initialize()
|
||||
|
||||
async def _connect_http(self):
|
||||
"""HTTP 连接"""
|
||||
headers = self._build_headers()
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||
self._session = aiohttp.ClientSession(headers=headers, timeout=timeout)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
headers=headers,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 对于 ModelScope MCP 服务,需要先发送初始化请求
|
||||
if "modelscope.net" in self.server_url:
|
||||
if self.is_sse:
|
||||
await self._initialize_sse_session()
|
||||
elif "modelscope.net" in self.server_url:
|
||||
await self._initialize_modelscope_session()
|
||||
|
||||
async def _initialize_sse_session(self):
|
||||
"""初始化 SSE MCP 会话 - 参考 Dify 实现"""
|
||||
try:
|
||||
# 建立 SSE 连接
|
||||
response = await self._session.get(
|
||||
self.server_url,
|
||||
headers={"Accept": "text/event-stream"}
|
||||
)
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
||||
|
||||
# 启动 SSE 读取任务
|
||||
self._sse_task = asyncio.create_task(self._read_sse_stream(response))
|
||||
|
||||
# 等待获取 endpoint URL
|
||||
for _ in range(10):
|
||||
if self._endpoint_url:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if not self._endpoint_url:
|
||||
raise MCPConnectionError("未能获取 endpoint URL")
|
||||
|
||||
# 发送 initialize 请求到 endpoint
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
init_response = await self._send_sse_request(init_request)
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
result = init_response.get("result", {})
|
||||
self._server_capabilities = result.get("capabilities", {})
|
||||
|
||||
# 发送 initialized 通知
|
||||
await self._send_sse_notification({"jsonrpc": "2.0", "method": "notifications/initialized"})
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"初始化连接失败: {e}")
|
||||
|
||||
async def _read_sse_stream(self, response):
|
||||
"""读取 SSE 流"""
|
||||
try:
|
||||
async for line in response.content:
|
||||
line = line.decode('utf-8').strip()
|
||||
|
||||
if line.startswith('event:'):
|
||||
continue
|
||||
|
||||
if line.startswith('data:'):
|
||||
data = line[5:].strip() # 去除 'data:' 后的空格
|
||||
if not data or data == '[DONE]':
|
||||
continue
|
||||
|
||||
try:
|
||||
# 处理 endpoint 事件(相对路径或绝对路径)
|
||||
if not self._endpoint_url:
|
||||
# 如果是相对路径,拼接成完整 URL
|
||||
if data.startswith('/'):
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
parsed = urlparse(self.server_url)
|
||||
self._endpoint_url = f"{parsed.scheme}://{parsed.netloc}{data}"
|
||||
else:
|
||||
self._endpoint_url = data
|
||||
logger.info(f"获取到 endpoint URL: {self._endpoint_url}")
|
||||
continue
|
||||
|
||||
# 处理 message 事件
|
||||
message = json.loads(data)
|
||||
request_id = message.get("id")
|
||||
if request_id and request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"SSE 流读取错误: {e}")
|
||||
|
||||
async def _send_sse_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""通过 SSE endpoint 发送请求"""
|
||||
if not self._endpoint_url:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
request_id = request["id"]
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
async with self._session.post(self._endpoint_url, json=request) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await asyncio.wait_for(future, timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError("请求超时")
|
||||
|
||||
async def _send_sse_notification(self, notification: Dict[str, Any]):
|
||||
"""发送通知(无需响应)"""
|
||||
if not self._endpoint_url:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
async with self._session.post(self._endpoint_url, json=notification) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"通知发送失败: {response.status}")
|
||||
|
||||
async def _initialize_modelscope_session(self):
|
||||
"""初始化 ModelScope MCP 会话"""
|
||||
init_request = {
|
||||
@@ -107,18 +219,12 @@ class SimpleMCPClient:
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {
|
||||
"name": "MemoryBear",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=init_request
|
||||
) as response:
|
||||
async with self._session.post(self.server_url, json=init_request) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
||||
@@ -127,21 +233,16 @@ class SimpleMCPClient:
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
# 获取 session ID
|
||||
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
|
||||
if session_id:
|
||||
self._session.headers.update({"Mcp-Session-Id": session_id})
|
||||
|
||||
# 发送 initialized 通知
|
||||
initialized_notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
}
|
||||
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=initialized_notification
|
||||
) as notif_response:
|
||||
async with self._session.post(self.server_url, json=initialized_notification):
|
||||
pass
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
@@ -149,12 +250,18 @@ class SimpleMCPClient:
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
# 基础 headers
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream"
|
||||
}
|
||||
|
||||
# 添加认证头
|
||||
# 合并 connection_config 中的自定义 headers
|
||||
custom_headers = self.connection_config.get("headers", {})
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
# 处理认证配置(认证 headers 优先级更高)
|
||||
auth_config = self.connection_config.get("auth_config", {})
|
||||
auth_type = self.connection_config.get("auth_type", "none")
|
||||
|
||||
@@ -178,7 +285,7 @@ class SimpleMCPClient:
|
||||
return headers
|
||||
|
||||
async def _send_initialize(self):
|
||||
"""发送初始化消息"""
|
||||
"""发送初始化消息(WebSocket)"""
|
||||
init_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
@@ -186,124 +293,90 @@ class SimpleMCPClient:
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {
|
||||
"name": "MemoryBear",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(init_message))
|
||||
response = await self._websocket.recv()
|
||||
response_data = json.loads(response)
|
||||
|
||||
# 等待初始化响应
|
||||
response = await asyncio.wait_for(
|
||||
self._websocket.recv(),
|
||||
timeout=self.timeout
|
||||
)
|
||||
if "error" in response_data:
|
||||
raise MCPConnectionError(f"初始化失败: {response_data['error']}")
|
||||
|
||||
init_response = json.loads(response)
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
result = response_data.get("result", {})
|
||||
self._server_capabilities = result.get("capabilities", {})
|
||||
|
||||
await self._websocket.send(json.dumps({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
}))
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""获取工具列表"""
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
await self._websocket.send(json.dumps(request))
|
||||
response = await self._websocket.recv()
|
||||
response_data = json.loads(response)
|
||||
elif self.is_sse:
|
||||
response_data = await self._send_sse_request(request)
|
||||
else:
|
||||
async with self._session.post(self.server_url, json=request) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if "error" in response_data:
|
||||
raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}")
|
||||
|
||||
result = response_data.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
"""调用工具"""
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {"name": tool_name, "arguments": arguments}
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
await self._websocket.send(json.dumps(request))
|
||||
response = await self._websocket.recv()
|
||||
response_data = json.loads(response)
|
||||
elif self.is_sse:
|
||||
response_data = await self._send_sse_request(request)
|
||||
else:
|
||||
async with self._session.post(self.server_url, json=request) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if "error" in response_data:
|
||||
error = response_data["error"]
|
||||
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response_data.get("result", {})
|
||||
|
||||
def _get_request_id(self) -> int:
|
||||
"""生成请求 ID"""
|
||||
self._request_id += 1
|
||||
return self._request_id
|
||||
|
||||
async def _handle_websocket_messages(self):
|
||||
"""处理 WebSocket 消息"""
|
||||
try:
|
||||
while self._websocket and not self._websocket.closed:
|
||||
try:
|
||||
message = await self._websocket.recv()
|
||||
data = json.loads(message)
|
||||
|
||||
# 处理响应
|
||||
if "id" in data:
|
||||
request_id = str(data["id"])
|
||||
if request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
|
||||
except ConnectionClosed:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"处理WebSocket消息失败: {e}")
|
||||
|
||||
async for message in self._websocket:
|
||||
data = json.loads(message)
|
||||
request_id = data.get("id")
|
||||
if request_id and request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
except ConnectionClosed:
|
||||
logger.info("WebSocket 连接已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket消息处理异常: {e}")
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
"""调用工具"""
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
response = await self._send_websocket_request(request_data)
|
||||
else:
|
||||
response = await self._send_http_request(request_data)
|
||||
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response.get("result", {})
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""获取工具列表"""
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
response = await self._send_websocket_request(request_data)
|
||||
else:
|
||||
response = await self._send_http_request(request_data)
|
||||
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
result = response.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送WebSocket请求"""
|
||||
request_id = str(request_data["id"])
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
await self._websocket.send(json.dumps(request_data))
|
||||
response = await asyncio.wait_for(future, timeout=self.timeout)
|
||||
return response
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
|
||||
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送HTTP请求"""
|
||||
try:
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=request_data
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"HTTP请求失败: {e}")
|
||||
|
||||
def _get_request_id(self) -> str:
|
||||
"""获取请求ID"""
|
||||
self._request_id += 1
|
||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
||||
logger.error(f"WebSocket 消息处理错误: {e}")
|
||||
|
||||
@@ -74,6 +74,7 @@ class WorkflowExecutor:
|
||||
初始化的工作流状态
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
conversation_messages = input_data.get("conv_messages") or []
|
||||
|
||||
# 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value)
|
||||
config_variables_list = self.workflow_config.get("variables") or []
|
||||
@@ -114,7 +115,7 @@ class WorkflowExecutor:
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [('user', user_message)],
|
||||
"messages": conversation_messages,
|
||||
"variables": variables,
|
||||
"node_outputs": {},
|
||||
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
||||
|
||||
@@ -7,13 +7,13 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from operator import add
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AnyMessage, AIMessage
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.config import get_stream_writer
|
||||
from typing_extensions import TypedDict, Annotated
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[tuple[str, str]], add]
|
||||
messages: list[dict[str, str]]
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
@@ -154,7 +154,7 @@ class BaseNode(ABC):
|
||||
Returns:
|
||||
超时时间
|
||||
"""
|
||||
return 60
|
||||
return settings.WORKFLOW_NODE_TIMEOUT
|
||||
# return self.error_handling.get("timeout", 60)
|
||||
|
||||
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
||||
@@ -203,6 +203,7 @@ class BaseNode(ABC):
|
||||
# 返回包装后的输出和运行时变量
|
||||
return {
|
||||
**wrapped_output,
|
||||
"messages": state["messages"],
|
||||
"variables": state["variables"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
@@ -356,6 +357,7 @@ class BaseNode(ABC):
|
||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||
state_update = {
|
||||
**final_output,
|
||||
"messages": state["messages"],
|
||||
"variables": state["variables"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
|
||||
@@ -6,7 +6,6 @@ End 节点实现
|
||||
|
||||
import logging
|
||||
import re
|
||||
import asyncio
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
@@ -38,7 +37,23 @@ class EndNode(BaseNode):
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state, strict=False)
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": output
|
||||
}
|
||||
])
|
||||
else:
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
])
|
||||
output = "工作流已完成"
|
||||
|
||||
# 统计信息(用于日志)
|
||||
@@ -166,6 +181,12 @@ class EndNode(BaseNode):
|
||||
"chunk_index": 1,
|
||||
"is_suffix": False
|
||||
})
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
}
|
||||
])
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
@@ -176,7 +197,6 @@ class EndNode(BaseNode):
|
||||
source_node_id = edge.get("source")
|
||||
# Check if the source node is an LLM node
|
||||
for node in self.workflow_config.get("nodes", []):
|
||||
print("="*50)
|
||||
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
|
||||
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
|
||||
direct_upstream_llm_nodes.append(source_node_id)
|
||||
@@ -216,12 +236,24 @@ class EndNode(BaseNode):
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
||||
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": output
|
||||
}
|
||||
])
|
||||
|
||||
# yield completion marker
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
|
||||
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
||||
logger.info(
|
||||
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
||||
|
||||
# Collect suffix parts
|
||||
suffix_parts = []
|
||||
@@ -258,6 +290,17 @@ class EndNode(BaseNode):
|
||||
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
||||
full_output = self._render_template(output_template, state, strict=False)
|
||||
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_output
|
||||
}
|
||||
])
|
||||
|
||||
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
||||
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
||||
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
|
||||
@@ -280,7 +323,8 @@ class EndNode(BaseNode):
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
||||
else:
|
||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
|
||||
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
||||
|
||||
# 统计信息
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
|
||||
@@ -11,12 +11,12 @@ class MessageConfig(BaseModel):
|
||||
"""消息配置"""
|
||||
|
||||
role: str = Field(
|
||||
...,
|
||||
default='user',
|
||||
description="消息角色:system, user, assistant"
|
||||
)
|
||||
|
||||
content: str = Field(
|
||||
...,
|
||||
default="",
|
||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||
)
|
||||
|
||||
@@ -30,6 +30,23 @@ class MessageConfig(BaseModel):
|
||||
return v.lower()
|
||||
|
||||
|
||||
class MemoryWindowSetting(BaseModel):
|
||||
enable: bool = Field(
|
||||
default=False,
|
||||
description="启用记忆"
|
||||
)
|
||||
|
||||
enable_window: bool = Field(
|
||||
default=False,
|
||||
description="启用记忆窗口"
|
||||
)
|
||||
|
||||
window_size: int = Field(
|
||||
default=20,
|
||||
description="记忆窗口大小"
|
||||
)
|
||||
|
||||
|
||||
class LLMNodeConfig(BaseNodeConfig):
|
||||
"""LLM 节点配置
|
||||
|
||||
@@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
description="上下文"
|
||||
)
|
||||
|
||||
memory: MemoryWindowSetting = Field(
|
||||
...,
|
||||
description="对话上下文窗口"
|
||||
)
|
||||
|
||||
# 简单模式
|
||||
prompt: str | None = Field(
|
||||
default=None,
|
||||
|
||||
@@ -85,28 +85,31 @@ class LLMNode(BaseNode):
|
||||
"""
|
||||
|
||||
# 1. 处理消息格式(优先使用 messages)
|
||||
messages_config = self.config.get("messages")
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.get("role", "user").lower()
|
||||
content_template = msg_config.get("content", "")
|
||||
role = msg_config.role.lower()
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, state)
|
||||
content = self._render_template(content_template, state)
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append(SystemMessage(content=content))
|
||||
messages.append({"role": "system", "content": content})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append(HumanMessage(content=content))
|
||||
messages.append({"role": "user", "content": content})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append(AIMessage(content=content))
|
||||
messages.append({"role": "assistant", "content": content})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append(HumanMessage(content=content))
|
||||
messages.append({"role": "user", "content": content})
|
||||
|
||||
if self.typed_config.memory.enable:
|
||||
# if self.typed_config.memory.enable_window:
|
||||
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
||||
prompt_or_messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
@@ -189,7 +192,7 @@ class LLMNode(BaseNode):
|
||||
return {
|
||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||
"messages": [
|
||||
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
|
||||
{"role": msg.get("role"), "content": msg.get("content", "")}
|
||||
for msg in prompt_or_messages
|
||||
] if isinstance(prompt_or_messages, list) else None,
|
||||
"config": {
|
||||
|
||||
@@ -3,8 +3,9 @@ from typing import Any
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.db import get_db_read, get_db_context
|
||||
from app.db import get_db_read
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.tasks import write_message_task
|
||||
|
||||
|
||||
class MemoryReadNode(BaseNode):
|
||||
@@ -15,11 +16,8 @@ class MemoryReadNode(BaseNode):
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||
with get_db_read() as db:
|
||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
if not workspace_id:
|
||||
raise RuntimeError("Workspace id is required")
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
@@ -41,20 +39,17 @@ class MemoryWriteNode(BaseNode):
|
||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
with get_db_context() as db:
|
||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
if not workspace_id:
|
||||
raise RuntimeError("Workspace id is required")
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
return await MemoryAgentService().write_memory(
|
||||
group_id=end_user_id,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=str(self.typed_config.config_id),
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
user_rag_memory_id=""
|
||||
)
|
||||
write_message_task.delay(
|
||||
end_user_id,
|
||||
self._render_template(self.typed_config.message, state),
|
||||
str(self.typed_config.config_id),
|
||||
"neo4j",
|
||||
""
|
||||
)
|
||||
|
||||
return "success"
|
||||
|
||||
@@ -41,6 +41,7 @@ class ToolConfig(BaseModel):
|
||||
tool_id: Optional[str] = Field(default=None, description="工具ID")
|
||||
operation: Optional[str] = Field(default=None, description="工具特定配置")
|
||||
|
||||
|
||||
class ToolOldConfig(BaseModel):
|
||||
"""工具配置"""
|
||||
enabled: bool = Field(default=False, description="是否启用该工具")
|
||||
@@ -348,6 +349,7 @@ class AppChatRequest(BaseModel):
|
||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
|
||||
|
||||
class DraftRunRequest(BaseModel):
|
||||
"""试运行请求"""
|
||||
message: str = Field(..., description="用户消息")
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.services.tool_service import ToolService
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.db import get_db
|
||||
@@ -59,7 +60,7 @@ class AppChatService:
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = config.default_model_config_id
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.system_prompt
|
||||
if variables:
|
||||
@@ -210,7 +211,7 @@ class AppChatService:
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = config.default_model_config_id
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.system_prompt
|
||||
if variables:
|
||||
@@ -511,7 +512,6 @@ class AppChatService:
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
except (GeneratorExit, asyncio.CancelledError):
|
||||
# 生成器被关闭或任务被取消,正常退出
|
||||
logger.debug("多 Agent 流式聊天被中断")
|
||||
@@ -537,83 +537,19 @@ class AppChatService:
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
workflow_service = WorkflowService(self.db)
|
||||
|
||||
input_data = {"message":message, "variables": variables,
|
||||
"conversation_id": str(conversation_id)}
|
||||
inconfig = workflow_service.get_workflow_config(app_id)
|
||||
|
||||
# 2. 创建执行记录
|
||||
execution = workflow_service.create_execution(
|
||||
workflow_config_id=inconfig.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=None,
|
||||
conversation_id=conversation_id,
|
||||
input_data=input_data
|
||||
payload = DraftRunRequest(
|
||||
message=message,
|
||||
variables=variables,
|
||||
conversation_id=str(conversation_id),
|
||||
stream=True,
|
||||
user_id=user_id
|
||||
)
|
||||
return await workflow_service.run(
|
||||
app_id=app_id,
|
||||
payload=payload,
|
||||
config=config,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
|
||||
# 3. 构建工作流配置字典
|
||||
workflow_config_dict = {
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables,
|
||||
"execution_config": config.execution_config
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
|
||||
# 5. 执行工作流
|
||||
from app.core.workflow.executor import execute_workflow
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
workflow_service.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
result = await execute_workflow(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# 更新执行结果
|
||||
if result.get("status") == "completed":
|
||||
workflow_service.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=result.get("node_outputs", {})
|
||||
)
|
||||
else:
|
||||
workflow_service.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=result.get("error")
|
||||
)
|
||||
|
||||
# 返回增强的响应结构
|
||||
return {
|
||||
"execution_id": execution.execution_id,
|
||||
"status": result.get("status"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||
"error_message": result.get("error"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"token_usage": result.get("token_usage")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
workflow_service.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
raise BusinessException(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def workflow_chat_stream(
|
||||
self,
|
||||
@@ -632,62 +568,21 @@ class AppChatService:
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
workflow_service = WorkflowService(self.db)
|
||||
input_data = {"message": message, "variables": variables,
|
||||
"conversation_id": str(conversation_id)}
|
||||
inconfig = workflow_service.get_workflow_config(app_id)
|
||||
# 2. 创建执行记录
|
||||
execution = workflow_service.create_execution(
|
||||
workflow_config_id=inconfig.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=None,
|
||||
conversation_id=conversation_id,
|
||||
input_data=input_data
|
||||
payload = DraftRunRequest(
|
||||
message=message,
|
||||
variables=variables,
|
||||
conversation_id=str(conversation_id),
|
||||
stream=True,
|
||||
user_id=user_id
|
||||
)
|
||||
async for event in workflow_service.run_stream(
|
||||
app_id=app_id,
|
||||
payload=payload,
|
||||
config=config,
|
||||
workspace_id=workspace_id,
|
||||
):
|
||||
yield event
|
||||
|
||||
# 3. 构建工作流配置字典
|
||||
workflow_config_dict = {
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables,
|
||||
"execution_config": config.execution_config
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
|
||||
# 5. 流式执行工作流
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
workflow_service.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
|
||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
||||
async for event in workflow_service._run_workflow_stream(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=user_id
|
||||
):
|
||||
# 直接转发 executor 的事件(已经是正确的格式)
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
workflow_service.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
# 发送错误事件
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"execution_id": execution.execution_id,
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
|
||||
@@ -13,12 +13,14 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.conversation_repository import ConversationRepository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_episodic_schema import type_mapping, EmotionType, EmotionSubject
|
||||
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from app.services.memory_base_service import MemoryBaseService
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
from app.services.memory_short_service import ShortService
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -1198,18 +1200,17 @@ async def analytics_memory_types(
|
||||
end_user_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
统计9种记忆类型的数量和百分比
|
||||
统计8种记忆类型的数量和百分比
|
||||
|
||||
计算规则:
|
||||
1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity
|
||||
2. 工作记忆 (WORKING_MEMORY) = chunk + entity
|
||||
3. 短期记忆 (SHORT_TERM_MEMORY) = chunk
|
||||
4. 长期记忆 (LONG_TERM_MEMORY) = entity
|
||||
5. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
|
||||
6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity
|
||||
7. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
|
||||
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
|
||||
9. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
|
||||
1. 感知记忆 (PERCEPTUAL_MEMORY) = 通过 MemoryPerceptualService.get_memory_count 获取的 total_count
|
||||
2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取)
|
||||
3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量
|
||||
4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
|
||||
5. 隐性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一
|
||||
6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
|
||||
7. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
|
||||
8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
@@ -1229,7 +1230,6 @@ async def analytics_memory_types(
|
||||
- PERCEPTUAL_MEMORY: 感知记忆
|
||||
- WORKING_MEMORY: 工作记忆
|
||||
- SHORT_TERM_MEMORY: 短期记忆
|
||||
- LONG_TERM_MEMORY: 长期记忆
|
||||
- EXPLICIT_MEMORY: 显性记忆
|
||||
- IMPLICIT_MEMORY: 隐性记忆
|
||||
- EMOTIONAL_MEMORY: 情绪记忆
|
||||
@@ -1239,40 +1239,78 @@ async def analytics_memory_types(
|
||||
# 初始化基础服务
|
||||
base_service = MemoryBaseService()
|
||||
|
||||
# 定义需要查询的基础节点类型
|
||||
node_types = {
|
||||
"Statement": "Statement",
|
||||
"Entity": "ExtractedEntity",
|
||||
"Chunk": "Chunk"
|
||||
}
|
||||
# 初始化感知记忆服务
|
||||
perceptual_service = MemoryPerceptualService(db)
|
||||
|
||||
# 存储每种节点类型的计数
|
||||
node_counts = {}
|
||||
# 获取感知记忆数量
|
||||
if end_user_id:
|
||||
perceptual_stats = perceptual_service.get_memory_count(uuid.UUID(end_user_id))
|
||||
perceptual_count = perceptual_stats.get("total", 0)
|
||||
else:
|
||||
perceptual_count = 0
|
||||
|
||||
# 查询每种节点类型的数量
|
||||
for key, node_type in node_types.items():
|
||||
if end_user_id:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
# 获取工作记忆数量(基于会话数量)
|
||||
work_count = 0
|
||||
if end_user_id:
|
||||
try:
|
||||
conversation_repo = ConversationRepository(db)
|
||||
conversations = conversation_repo.get_conversation_by_user_id(
|
||||
user_id=uuid.UUID(end_user_id),
|
||||
limit=100, # 获取更多会话以准确统计
|
||||
is_activate=True
|
||||
)
|
||||
work_count = len(conversations)
|
||||
logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}")
|
||||
work_count = 0
|
||||
|
||||
# 获取隐性记忆数量(基于 Statement 节点数量的三分之一)
|
||||
implicit_count = 0
|
||||
if end_user_id:
|
||||
try:
|
||||
# 查询 Statement 节点数量
|
||||
query = """
|
||||
MATCH (n:Statement)
|
||||
WHERE n.group_id = $group_id
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
|
||||
else:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(query)
|
||||
|
||||
# 提取计数结果
|
||||
count = result[0]["count"] if result and len(result) > 0 else 0
|
||||
node_counts[key] = count
|
||||
statement_count = result[0]["count"] if result and len(result) > 0 else 0
|
||||
# 取三分之一作为隐性记忆数量
|
||||
implicit_count = round(statement_count / 3)
|
||||
logger.debug(f"隐性记忆数量(Statement数量的1/3): {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取Statement数量失败,隐性记忆数量设为0: {str(e)}")
|
||||
implicit_count = 0
|
||||
|
||||
# 获取各节点类型的数量
|
||||
statement_count = node_counts.get("Statement", 0)
|
||||
entity_count = node_counts.get("Entity", 0)
|
||||
chunk_count = node_counts.get("Chunk", 0)
|
||||
# 原有的基于行为习惯的统计方式(已注释)
|
||||
# implicit_count = 0
|
||||
# if end_user_id:
|
||||
# try:
|
||||
# implicit_service = ImplicitMemoryService(db, end_user_id)
|
||||
# behavior_habits = await implicit_service.get_behavior_habits(
|
||||
# user_id=end_user_id
|
||||
# )
|
||||
# implicit_count = len(behavior_habits)
|
||||
# logger.debug(f"隐性记忆数量(行为习惯数): {implicit_count} (end_user_id={end_user_id})")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"获取行为习惯数量失败,隐性记忆数量设为0: {str(e)}")
|
||||
# implicit_count = 0
|
||||
|
||||
# 获取短期记忆数量(基于 /short_term 接口返回的问答对数量)
|
||||
short_term_count = 0
|
||||
if end_user_id:
|
||||
try:
|
||||
short_term_service = ShortService(end_user_id)
|
||||
short_term_data = short_term_service.get_short_databasets()
|
||||
# 统计 short_term 数组的长度
|
||||
if short_term_data:
|
||||
short_term_count = len(short_term_data)
|
||||
logger.debug(f"短期记忆数量(问答对数): {short_term_count} (end_user_id={end_user_id})")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取短期记忆数量失败,短期记忆数量设为0: {str(e)}")
|
||||
short_term_count = 0
|
||||
|
||||
# 获取用户的遗忘阈值配置
|
||||
forgetting_threshold = 0.3 # 默认值
|
||||
@@ -1298,17 +1336,16 @@ async def analytics_memory_types(
|
||||
# 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量
|
||||
episodic_count = await base_service.get_episodic_memory_count(end_user_id)
|
||||
explicit_count = await base_service.get_explicit_memory_count(end_user_id)
|
||||
emotion_count = await base_service.get_emotional_memory_count(end_user_id, statement_count)
|
||||
emotion_count = await base_service.get_emotional_memory_count(end_user_id, perceptual_count)
|
||||
forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold)
|
||||
|
||||
# 按规则计算9种记忆类型的数量(使用英文枚举作为key)
|
||||
# 按规则计算8种记忆类型的数量(使用英文枚举作为key)
|
||||
memory_counts = {
|
||||
"PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆
|
||||
"WORKING_MEMORY": chunk_count + entity_count, # 工作记忆
|
||||
"SHORT_TERM_MEMORY": chunk_count, # 短期记忆
|
||||
"LONG_TERM_MEMORY": entity_count, # 长期记忆
|
||||
"PERCEPTUAL_MEMORY": perceptual_count, # 感知记忆
|
||||
"WORKING_MEMORY": work_count, # 工作记忆(基于会话数量)
|
||||
"SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量)
|
||||
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
|
||||
"IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity)
|
||||
"IMPLICIT_MEMORY": implicit_count, # 隐性记忆(Statement数量的1/3)
|
||||
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
|
||||
"EPISODIC_MEMORY": episodic_count, # 情景记忆
|
||||
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)
|
||||
|
||||
@@ -2,12 +2,11 @@
|
||||
工作流服务层
|
||||
"""
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Any, Annotated, AsyncGenerator
|
||||
|
||||
from deprecated import deprecated
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -16,15 +15,16 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||
from app.repositories.conversation_repository import MessageRepository
|
||||
from app.models.conversation_model import Message
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
from app.repositories.workflow_repository import (
|
||||
WorkflowConfigRepository,
|
||||
WorkflowExecutionRepository,
|
||||
WorkflowNodeExecutionRepository
|
||||
)
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,6 +37,7 @@ class WorkflowService:
|
||||
self.config_repo = WorkflowConfigRepository(db)
|
||||
self.execution_repo = WorkflowExecutionRepository(db)
|
||||
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
||||
self.message_repo = MessageRepository(db)
|
||||
|
||||
# ==================== 配置管理 ====================
|
||||
|
||||
@@ -418,14 +419,13 @@ class WorkflowService:
|
||||
"""运行工作流
|
||||
|
||||
Args:
|
||||
workspace_id:
|
||||
config:
|
||||
payload:
|
||||
app_id: 应用 ID
|
||||
input_data: 输入数据(包含 message 和 variables)
|
||||
triggered_by: 触发用户 ID
|
||||
conversation_id: 会话 ID(可选)
|
||||
stream: 是否流式返回
|
||||
|
||||
Returns:
|
||||
执行结果(非流式)或生成器(流式)
|
||||
执行结果(非流式)
|
||||
|
||||
Raises:
|
||||
BusinessException: 配置不存在或执行失败时抛出
|
||||
@@ -438,7 +438,8 @@ class WorkflowService:
|
||||
code=BizCode.CONFIG_MISSING,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id}
|
||||
|
||||
# 转换 user_id 为 UUID
|
||||
triggered_by_uuid = None
|
||||
@@ -461,7 +462,7 @@ class WorkflowService:
|
||||
workflow_config_id=config.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=triggered_by_uuid,
|
||||
triggered_by=None,
|
||||
conversation_id=conversation_id_uuid,
|
||||
input_data=input_data
|
||||
)
|
||||
@@ -500,8 +501,11 @@ class WorkflowService:
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
|
||||
result = await execute_workflow(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
@@ -517,6 +521,17 @@ class WorkflowService:
|
||||
"completed",
|
||||
output_data=result
|
||||
)
|
||||
final_messages = result.get("messages", [])[init_message_length:]
|
||||
for message in final_messages:
|
||||
message_obj = Message(
|
||||
conversation_id=conversation_id_uuid,
|
||||
role=message["role"],
|
||||
content=message["content"],
|
||||
)
|
||||
self.message_repo.add_message(message_obj)
|
||||
self.db.commit()
|
||||
logger.info(f"Workflow Run Success, "
|
||||
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||
else:
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
@@ -529,6 +544,7 @@ class WorkflowService:
|
||||
"execution_id": execution.execution_id,
|
||||
"status": result.get("status"),
|
||||
"variables": result.get("variables"),
|
||||
"messages": result.get("messages"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||
@@ -559,6 +575,7 @@ class WorkflowService:
|
||||
"""运行工作流(流式)
|
||||
|
||||
Args:
|
||||
workspace_id:
|
||||
app_id: 应用 ID
|
||||
payload: 请求对象(包含 message, variables, conversation_id 等)
|
||||
config: 存储类型(可选)
|
||||
@@ -601,7 +618,7 @@ class WorkflowService:
|
||||
workflow_config_id=config.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=triggered_by_uuid,
|
||||
triggered_by=None,
|
||||
conversation_id=conversation_id_uuid,
|
||||
input_data=input_data
|
||||
)
|
||||
@@ -638,17 +655,46 @@ class WorkflowService:
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
from app.core.workflow.executor import execute_workflow_stream
|
||||
|
||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
||||
async for event in self._run_workflow_stream(
|
||||
async for event in execute_workflow_stream(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=end_user_id
|
||||
):
|
||||
# 直接转发 executor 的事件(已经是正确的格式)
|
||||
if event.get("event") == "workflow_end":
|
||||
|
||||
status = event.get("data", {}).get("status")
|
||||
if status == "completed":
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=event.get("data")
|
||||
)
|
||||
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
|
||||
for message in final_messages:
|
||||
message_obj = Message(
|
||||
conversation_id=conversation_id_uuid,
|
||||
role=message["role"],
|
||||
content=message["content"],
|
||||
)
|
||||
self.message_repo.add_message(message_obj)
|
||||
self.db.commit()
|
||||
logger.info(f"Workflow Run Success, "
|
||||
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||
elif status == "failed":
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
output_data=event.get("data")
|
||||
)
|
||||
else:
|
||||
logger.error(f"unexpect workflow run status, status: {status}")
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
@@ -667,6 +713,8 @@ class WorkflowService:
|
||||
}
|
||||
}
|
||||
|
||||
@deprecated(reason="This method is deprecated. "
|
||||
"Please use WorkflowService.run / run_stream instead.")
|
||||
async def run_workflow(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
@@ -819,6 +867,7 @@ class WorkflowService:
|
||||
|
||||
return clean_value(event)
|
||||
|
||||
@deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.")
|
||||
async def _run_workflow_stream(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
|
||||
@@ -136,7 +136,8 @@ dependencies = [
|
||||
"markdown-to-json==2.1.1",
|
||||
"valkey==6.0.2",
|
||||
"python-calamine>=0.4.0",
|
||||
"xlrd==2.0.2"
|
||||
"xlrd==2.0.2",
|
||||
"deprecated>=1.3.1",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -55,7 +55,7 @@ const ChatContent: FC<ChatContentProps> = ({
|
||||
</div>
|
||||
}
|
||||
{/* 消息气泡框 */}
|
||||
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-100 rb:wrap-break-word', contentClassNames, {
|
||||
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-[520px] rb:wrap-break-word', contentClassNames, {
|
||||
// 错误消息样式(内容为null且非助手消息)
|
||||
'rb:border-[rgba(255,93,52,0.30)] rb:bg-[rgba(255,93,52,0.08)] rb:text-[#FF5D34]': errorDesc && item.role === 'assistant' && item.content === null,
|
||||
// 助手消息样式
|
||||
@@ -68,7 +68,7 @@ const ChatContent: FC<ChatContentProps> = ({
|
||||
</div>
|
||||
{/* 底部标签(如时间戳、用户名等) */}
|
||||
{labelPosition === 'bottom' &&
|
||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4 rb:font-regular">
|
||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4 rb:font-regular rb:mt-2">
|
||||
{labelFormat(item)}
|
||||
</div>
|
||||
}
|
||||
|
||||
@@ -1265,6 +1265,7 @@ export const en = {
|
||||
emotionLine: 'Emotion Changes Over Time',
|
||||
interaction: 'Interaction Frequency & Relationship Stages',
|
||||
timelines_memory: 'All',
|
||||
Chunk: 'Chunk',
|
||||
MemorySummary: 'Long-term Accumulation',
|
||||
Statement: 'Emotional Memory',
|
||||
ExtractedEntity: 'Episodic Memory',
|
||||
@@ -1786,6 +1787,9 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
temperature: 'Temperature',
|
||||
max_tokens: 'Max Tokens',
|
||||
context: 'Context',
|
||||
memory: 'Memory',
|
||||
enable_window: 'Memory Window',
|
||||
inner: 'Built-in',
|
||||
},
|
||||
start: {
|
||||
variables: 'Input Fields',
|
||||
|
||||
@@ -1343,6 +1343,7 @@ export const zh = {
|
||||
emotionLine: '情绪随时间变化',
|
||||
interaction: '互动频率 & 关系阶段',
|
||||
timelines_memory: '全部',
|
||||
Chunk: '工作记忆',
|
||||
MemorySummary: '长期沉淀',
|
||||
Statement: '情绪记忆',
|
||||
ExtractedEntity: '情景记忆',
|
||||
@@ -1883,6 +1884,9 @@ export const zh = {
|
||||
temperature: '温度',
|
||||
max_tokens: '最大令牌数',
|
||||
context: '上下文',
|
||||
memory: '记忆',
|
||||
enable_window: '记忆窗口',
|
||||
inner: '内置',
|
||||
},
|
||||
start: {
|
||||
variables: '输入字段',
|
||||
|
||||
@@ -176,6 +176,9 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
if (response?.knowledge_retrieval?.knowledge_bases?.length) {
|
||||
getDefaultKnowledgeList(response)
|
||||
}
|
||||
if (response?.tools?.length) {
|
||||
setToolList(response?.tools)
|
||||
}
|
||||
}).finally(() => {
|
||||
setLoading(false)
|
||||
})
|
||||
|
||||
@@ -79,8 +79,6 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi
|
||||
}
|
||||
}, [data])
|
||||
|
||||
console.log('toolList', toolList)
|
||||
|
||||
const handleAddTool = () => {
|
||||
toolModalRef.current?.handleOpen()
|
||||
}
|
||||
|
||||
@@ -259,9 +259,10 @@ const Conversation: FC = () => {
|
||||
</div>
|
||||
|
||||
<div className="rb:relative rb:h-screen rb:px-4 rb:flex-[1_1_auto]">
|
||||
<div className='rb:w-[760px] rb:h-screen rb:mx-auto rb:pt-10'>
|
||||
<Chat
|
||||
empty={<Empty url={AnalysisEmptyIcon} className="rb:h-full" subTitle={t('memoryConversation.emptyDesc')} />}
|
||||
contentClassName="rb:h-[calc(100%-152px)]"
|
||||
empty={<Empty url={BgImg} className="rb:h-full" size={[320,180]} subTitle={t('memoryConversation.emptyDesc')} />}
|
||||
contentClassName="rb:h-[calc(100%-152px)] "
|
||||
data={chatList}
|
||||
streamLoading={streamLoading}
|
||||
loading={loading}
|
||||
@@ -290,6 +291,7 @@ const Conversation: FC = () => {
|
||||
</Flex>
|
||||
</Form>
|
||||
</Chat>
|
||||
</div>
|
||||
</div>
|
||||
</Flex>
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import React, { useState, useImperativeHandle, forwardRef, useRef } from 'react';
|
||||
import { Button, Input, Space, Typography, Tooltip, message, List } from 'antd';
|
||||
import { PlusOutlined, EditOutlined, DeleteOutlined } from '@ant-design/icons';
|
||||
import { useState, useImperativeHandle, forwardRef, useRef } from 'react';
|
||||
import { Button, Space, List } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { ChatVariable, AddChatVariableRef } from '../../types';
|
||||
import type { ChatVariableModalRef } from './types'
|
||||
|
||||
@@ -131,7 +131,7 @@ const EditableTable: React.FC<EditableTableProps> = ({
|
||||
const AddButton = ({ block = false }: { block?: boolean }) => (
|
||||
<Button
|
||||
type={block ? "dashed" : "text"}
|
||||
icon={<PlusOutlined />}
|
||||
icon={block ? undefined : <PlusOutlined />}
|
||||
onClick={() => add(createNewRow())}
|
||||
size="small"
|
||||
block={block}
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
import { type FC } from "react";
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Form, Row, Col, Divider, Switch, Slider } from 'antd'
|
||||
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||
import MessageEditor from '../MessageEditor'
|
||||
|
||||
const MemoryConfig: FC<{ options: Suggestion[]; parentName: string; }> = ({
|
||||
options,
|
||||
parentName
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const form = Form.useFormInstance();
|
||||
const values = Form.useWatch([], form) || {}
|
||||
|
||||
console.log('MemoryConfig', values)
|
||||
|
||||
const handleChangeEnable = (value: boolean) => {
|
||||
if (value) {
|
||||
form.setFieldsValue({
|
||||
memory: {
|
||||
...form.getFieldValue(parentName),
|
||||
enable_window: false,
|
||||
window_size: 20,
|
||||
messages: "{{sys.message}}"
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{values?.memory?.enable && <>
|
||||
<div className="rb:flex rb:items-center rb:justify-between rb:py-1.5 rb:px-2 rb:bg-[#F6F8FC] rb:rounded-md rb:mb-2">
|
||||
{t('workflow.config.llm.memory')}
|
||||
<span>{t('workflow.config.llm.inner')}</span>
|
||||
</div>
|
||||
<Form.Item layout="horizontal" name={[parentName, 'messages']}>
|
||||
<MessageEditor
|
||||
title="USER"
|
||||
isArray={false}
|
||||
parentName={[parentName, 'messages']}
|
||||
options={options}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Divider />
|
||||
</>}
|
||||
<Form.Item layout="horizontal" name={[parentName, 'enable']} label={t('workflow.config.llm.memory')}>
|
||||
<Switch onChange={handleChangeEnable} />
|
||||
</Form.Item>
|
||||
{values?.memory?.enable && <>
|
||||
<Row className="rb:mb-3">
|
||||
<Col span={10}>
|
||||
<Form.Item layout="horizontal" name={[parentName, 'enable_window']} noStyle>
|
||||
<Switch />
|
||||
</Form.Item>
|
||||
<span className="rb:ml-2">{t('workflow.config.llm.enable_window')}</span>
|
||||
</Col>
|
||||
<Col span={14}>
|
||||
<Form.Item layout="horizontal" name={[parentName, 'window_size']} noStyle>
|
||||
<Slider min={1} max={100} step={1} className="rb:my-0!" disabled={!values?.memory?.enable_window} />
|
||||
</Form.Item>
|
||||
</Col>
|
||||
</Row>
|
||||
</>}
|
||||
</>
|
||||
);
|
||||
};
|
||||
export default MemoryConfig;
|
||||
@@ -127,7 +127,7 @@ const MessageEditor: FC<MessageEditor> = ({
|
||||
</Space>
|
||||
);
|
||||
})}
|
||||
<Form.Item>
|
||||
<Form.Item noStyle>
|
||||
<Button type="dashed" onClick={() => handleAdd(add)} block>
|
||||
+{t('workflow.addMessage')}
|
||||
</Button>
|
||||
|
||||
@@ -22,6 +22,7 @@ import ConditionList from './ConditionList'
|
||||
import CycleVarsList from './CycleVarsList'
|
||||
import AssignmentList from './AssignmentList'
|
||||
import ToolConfig from './ToolConfig'
|
||||
import MemoryConfig from './MemoryConfig'
|
||||
// import { calculateVariableList } from './utils/variableListCalculator'
|
||||
|
||||
interface PropertiesProps {
|
||||
@@ -1230,6 +1231,20 @@ const Properties: FC<PropertiesProps> = ({
|
||||
</Form.Item>
|
||||
)
|
||||
}
|
||||
if (config.type === 'memoryConfig') {
|
||||
return (
|
||||
<Form.Item
|
||||
key={key}
|
||||
name={key}
|
||||
noStyle
|
||||
>
|
||||
<MemoryConfig
|
||||
parentName={key}
|
||||
options={getFilteredVariableList('llm')}
|
||||
/>
|
||||
</Form.Item>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<Form.Item
|
||||
|
||||
@@ -135,6 +135,14 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
readonly: true
|
||||
},
|
||||
]
|
||||
},
|
||||
memory: {
|
||||
type: 'memoryConfig',
|
||||
defaultValue: {
|
||||
enable: false,
|
||||
enable_window: false,
|
||||
window_size: 20
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -750,10 +758,6 @@ export const outputVariable: { [key: string]: OutputVariable } = {
|
||||
{ name: "body", type: "string" },
|
||||
{ name: "status_code", type: "number" },
|
||||
],
|
||||
error: [
|
||||
{ name: "error_message", type: "string" },
|
||||
{ name: "error_type", type: "string" },
|
||||
]
|
||||
},
|
||||
'tool': {
|
||||
default: [
|
||||
|
||||
@@ -6,7 +6,7 @@ import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from '
|
||||
import { register } from '@antv/x6-react-shape';
|
||||
|
||||
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs } from '../constant';
|
||||
import type { WorkflowConfig, NodeProperties } from '../types';
|
||||
import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types';
|
||||
import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application'
|
||||
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
||||
|
||||
@@ -35,6 +35,8 @@ export interface UseWorkflowGraphReturn {
|
||||
copyEvent: () => boolean | void;
|
||||
parseEvent: () => boolean | void;
|
||||
handleSave: (flag?: boolean) => Promise<unknown>;
|
||||
chatVariables: ChatVariable[];
|
||||
setChatVariables: React.Dispatch<React.SetStateAction<ChatVariable[]>>;
|
||||
}
|
||||
|
||||
export const edge_color = '#155EEF';
|
||||
@@ -54,6 +56,7 @@ export const useWorkflowGraph = ({
|
||||
const [canRedo, setCanRedo] = useState(false);
|
||||
const [isHandMode, setIsHandMode] = useState(false);
|
||||
const [config, setConfig] = useState<WorkflowConfig | null>(null);
|
||||
const [chatVariables, setChatVariables] = useState<ChatVariable[]>([])
|
||||
|
||||
useEffect(() => {
|
||||
getConfig()
|
||||
@@ -63,16 +66,15 @@ export const useWorkflowGraph = ({
|
||||
getWorkflowConfig(id)
|
||||
.then(res => {
|
||||
const { variables, ...rest } = res as WorkflowConfig
|
||||
setConfig({
|
||||
...rest,
|
||||
variables: variables.map(v => {
|
||||
const { default: _, ...cleanV } = v
|
||||
return {
|
||||
...cleanV,
|
||||
defaultValue: v.default ?? ''
|
||||
}
|
||||
})
|
||||
const initChatVariables = variables.map(v => {
|
||||
const { default: _, ...cleanV } = v
|
||||
return {
|
||||
...cleanV,
|
||||
defaultValue: v.default ?? ''
|
||||
}
|
||||
})
|
||||
setChatVariables(initChatVariables)
|
||||
setConfig({ ...rest, variables: initChatVariables })
|
||||
})
|
||||
}
|
||||
|
||||
@@ -94,7 +96,17 @@ export const useWorkflowGraph = ({
|
||||
|
||||
if (nodeLibraryConfig?.config) {
|
||||
Object.keys(nodeLibraryConfig.config).forEach(key => {
|
||||
if (key === 'knowledge_retrieval' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
|
||||
if (key === 'memory' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
|
||||
const { memory, messages } = config as any;
|
||||
if (memory?.enable && messages && messages.length > 0) {
|
||||
const lastMessage = messages[messages.length - 1]
|
||||
nodeLibraryConfig.config[key].defaultValue = {
|
||||
...memory,
|
||||
messages: lastMessage.content
|
||||
}
|
||||
nodeLibraryConfig.config.messages.defaultValue.splice(-1, 1)
|
||||
}
|
||||
} else if (key === 'knowledge_retrieval' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) {
|
||||
const { query, ...rest } = config
|
||||
nodeLibraryConfig.config[key].defaultValue = {
|
||||
...rest
|
||||
@@ -917,13 +929,13 @@ export const useWorkflowGraph = ({
|
||||
|
||||
const params = {
|
||||
...config,
|
||||
variables: config.variables.map(v => {
|
||||
const { defaultValue, ...cleanV } = v
|
||||
return {
|
||||
...cleanV,
|
||||
default: defaultValue ?? ''
|
||||
}
|
||||
}),
|
||||
variables: chatVariables.map(v => {
|
||||
const { defaultValue, ...cleanV } = v
|
||||
return {
|
||||
...cleanV,
|
||||
default: defaultValue ?? ''
|
||||
}
|
||||
}),
|
||||
nodes: nodes.map((node: Node) => {
|
||||
const data = node.getData();
|
||||
const position = node.getPosition();
|
||||
@@ -931,7 +943,15 @@ export const useWorkflowGraph = ({
|
||||
|
||||
if (data.config) {
|
||||
Object.keys(data.config).forEach(key => {
|
||||
if (data.config[key] && 'defaultValue' in data.config[key] && key === 'group_variables') {
|
||||
if (key === 'memory' && data.config[key] && 'defaultValue' in data.config[key]) {
|
||||
const { messages, ...rest } = data.config[key].defaultValue
|
||||
let memoryMessage = { role: 'USER', content: data.config[key].defaultValue.messages }
|
||||
itemConfig = {
|
||||
...itemConfig,
|
||||
messages: rest.enable ? [...itemConfig.messages, memoryMessage] : itemConfig.messages,
|
||||
memory: { ...rest },
|
||||
}
|
||||
} else if (data.config[key] && 'defaultValue' in data.config[key] && key === 'group_variables') {
|
||||
let group_variables = data.config.group.defaultValue ? {} : data.config[key].defaultValue
|
||||
if (data.config.group.defaultValue) {
|
||||
data.config[key].defaultValue.map((vo: any) => {
|
||||
@@ -1077,5 +1097,7 @@ export const useWorkflowGraph = ({
|
||||
copyEvent,
|
||||
parseEvent,
|
||||
handleSave,
|
||||
chatVariables,
|
||||
setChatVariables
|
||||
};
|
||||
};
|
||||
|
||||
@@ -8,7 +8,7 @@ import PortClickHandler from './components/PortClickHandler';
|
||||
import { useWorkflowGraph } from './hooks/useWorkflowGraph';
|
||||
import type { WorkflowRef } from '@/views/ApplicationConfig/types'
|
||||
import Chat from './components/Chat/Chat';
|
||||
import type { ChatRef, AddChatVariableRef, ChatVariable } from './types'
|
||||
import type { ChatRef, AddChatVariableRef } from './types'
|
||||
import arrowIcon from '@/assets/images/workflow/arrow.png'
|
||||
import AddChatVariable from './components/AddChatVariable';
|
||||
|
||||
@@ -21,7 +21,6 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
|
||||
// 使用自定义Hook初始化工作流图
|
||||
const {
|
||||
config,
|
||||
setConfig,
|
||||
graphRef,
|
||||
selectedNode,
|
||||
setSelectedNode,
|
||||
@@ -38,6 +37,8 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
|
||||
copyEvent,
|
||||
parseEvent,
|
||||
handleSave,
|
||||
chatVariables,
|
||||
setChatVariables
|
||||
} = useWorkflowGraph({ containerRef, miniMapRef });
|
||||
|
||||
const onDragOver = (event: React.DragEvent) => {
|
||||
@@ -52,15 +53,6 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
|
||||
const addVariable = () => {
|
||||
addChatVariableRef.current?.handleOpen()
|
||||
}
|
||||
const handleUpdateChatVariable = (variables: ChatVariable[]) => {
|
||||
setConfig(prev => {
|
||||
if (!prev) return null
|
||||
return {
|
||||
...prev,
|
||||
variables
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleSave,
|
||||
@@ -125,8 +117,8 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
|
||||
|
||||
<AddChatVariable
|
||||
ref={addChatVariableRef}
|
||||
variables={config?.variables}
|
||||
onChange={handleUpdateChatVariable}
|
||||
variables={chatVariables}
|
||||
onChange={setChatVariables}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user