Merge branch 'develop' into fix/workflow_zy

This commit is contained in:
yingzhao
2026-01-16 12:30:59 +08:00
committed by GitHub
42 changed files with 1054 additions and 778 deletions

View File

@@ -11,15 +11,16 @@ from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models import User
from app.models.app_model import AppType, App
from app.models.app_model import AppType
from app.repositories import knowledge_repository
from app.repositories.end_user_repository import EndUserRepository
from app.schemas import app_schema
from app.schemas.response_schema import PageData, PageMeta
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
from app.schemas.workflow_schema import WorkflowConfigUpdate
from app.services import app_service, workspace_service
from app.services.agent_config_helper import enrich_agent_config
from app.services.app_service import AppService
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
from app.services.workflow_service import WorkflowService, get_workflow_service
router = APIRouter(prefix="/apps", tags=["Apps"])
@@ -405,6 +406,15 @@ async def draft_run(
# 只读操作,允许访问共享应用
service._validate_app_accessible(app, workspace_id)
if payload.user_id is None:
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id,
other_id=str(current_user.id),
original_user_id=str(current_user.id) # Save original user_id to other_id
)
payload.user_id = str(new_end_user.id)
# 处理会话ID创建或验证
conversation_id = await draft_service._ensure_conversation(
conversation_id=payload.conversation_id,

View File

@@ -74,7 +74,7 @@ def get_multi_agent_configs(
"app_id": str(app_id),
"default_model_config_id": None,
"model_parameters": None,
"orchestration_mode": "conditional",
"orchestration_mode": "supervisor",
"sub_agents": [],
"routing_rules": [],
"execution_config": {

View File

@@ -466,7 +466,7 @@ async def chat(
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
config= payload.agent_config,
config=agent_config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
@@ -565,11 +565,12 @@ async def chat(
config = workflow_config_4_app_release(release)
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=new_end_user.id, # 转换为字符串
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
@@ -601,7 +602,7 @@ async def chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=new_end_user.id, # 转换为字符串
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,

View File

@@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"])
@router.post("/{app_id}/workflow")
@cur_workspace_access_guard()
async def create_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
config: WorkflowConfigCreate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
config: WorkflowConfigCreate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""创建工作流配置
@@ -96,6 +96,7 @@ async def create_workflow_config(
msg=f"创建工作流配置失败: {str(e)}"
)
#
# @router.get("/{app_id}/workflow")
# async def get_workflow_config(
@@ -199,10 +200,10 @@ async def create_workflow_config(
@router.delete("/{app_id}/workflow")
async def delete_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""删除工作流配置
@@ -243,11 +244,11 @@ async def delete_workflow_config(
@router.post("/{app_id}/workflow/validate")
async def validate_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
):
"""验证工作流配置
@@ -312,12 +313,12 @@ async def validate_workflow_config(
@router.get("/{app_id}/workflow/executions")
async def get_workflow_executions(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0
):
"""获取工作流执行记录列表
@@ -365,10 +366,10 @@ async def get_workflow_executions(
@router.get("/workflow/executions/{execution_id}")
async def get_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""获取工作流执行详情
@@ -417,16 +418,14 @@ async def get_workflow_execution(
)
# ==================== 工作流执行 ====================
@router.post("/{app_id}/workflow/run")
async def run_workflow(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
request: WorkflowExecutionRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
request: WorkflowExecutionRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""执行工作流
@@ -487,22 +486,22 @@ async def run_workflow(
"""
try:
async for event in await service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True
):
# 提取事件类型和数据
event_type = event.get("event", "message")
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
# event: <type>
# data: <json>
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
except Exception as e:
logger.error(f"流式执行异常: {e}", exc_info=True)
# 发送错误事件
@@ -554,10 +553,10 @@ async def run_workflow(
@router.post("/workflow/executions/{execution_id}/cancel")
async def cancel_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""取消工作流执行
@@ -602,7 +601,7 @@ async def cancel_workflow_execution(
except BusinessException as e:
logger.warning(f"取消工作流执行失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
return fail(code=e.code, msg=e.message)
except Exception as e:
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
return fail(

View File

@@ -7,17 +7,18 @@ from dotenv import load_dotenv
load_dotenv()
class Settings:
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
# Neo4j Configuration (记忆系统数据库)
NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687")
NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "")
# Database configuration (Postgres)
DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1")
DB_PORT: int = int(os.getenv("DB_PORT", "5432"))
@@ -37,7 +38,7 @@ class Settings:
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
# ElasticSearch configuration
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
@@ -48,7 +49,7 @@ class Settings:
ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000"))
ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true"
ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10"))
# Xinference configuration
XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1")
@@ -57,17 +58,17 @@ class Settings:
LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true"
LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "")
LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "")
# LLM Request Configuration
LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0"))
LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2"))
# JWT Token Configuration
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random")
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
# Single Sign-On configuration
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
@@ -86,19 +87,19 @@ class Settings:
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
# Server Configuration
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
# ========================================================================
# Internal Configuration (not in .env, used by application code)
# ========================================================================
# Superuser settings (internal defaults)
FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com")
FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin")
FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password")
# Generic File Upload (internal)
GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads")
ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true"
@@ -123,7 +124,7 @@ class Settings:
LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5"))
LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true"
LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true"
# Sensitive Data Filtering
ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true"
@@ -142,7 +143,6 @@ class Settings:
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
@@ -150,15 +150,15 @@ class Settings:
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
# Memory Cache Regeneration Configuration
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
# Tool Management Configuration
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
@@ -167,7 +167,10 @@ class Settings:
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
# workflow config
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
def get_memory_output_path(self, filename: str = "") -> str:
"""
Get the full path for memory module output files.
@@ -182,7 +185,7 @@ class Settings:
if filename:
return str(base_path / filename)
return str(base_path)
def ensure_memory_output_dir(self) -> None:
"""
Ensure the memory output directory exists.

View File

@@ -110,24 +110,24 @@ HTTP_MAPPING = {
BizCode.TOKEN_EXPIRED: 401,
BizCode.TOKEN_BLACKLISTED: 401,
BizCode.FORBIDDEN: 403,
BizCode.TENANT_NOT_FOUND: 404,
BizCode.TENANT_NOT_FOUND: 400,
BizCode.WORKSPACE_NO_ACCESS: 403,
BizCode.NOT_FOUND: 404,
BizCode.NOT_FOUND: 400,
BizCode.USER_NOT_FOUND: 200,
BizCode.WORKSPACE_NOT_FOUND: 404,
BizCode.MODEL_NOT_FOUND: 404,
BizCode.KNOWLEDGE_NOT_FOUND: 404,
BizCode.DOCUMENT_NOT_FOUND: 404,
BizCode.FILE_NOT_FOUND: 404,
BizCode.APP_NOT_FOUND: 404,
BizCode.RELEASE_NOT_FOUND: 404,
BizCode.WORKSPACE_NOT_FOUND: 400,
BizCode.MODEL_NOT_FOUND: 400,
BizCode.KNOWLEDGE_NOT_FOUND: 400,
BizCode.DOCUMENT_NOT_FOUND: 400,
BizCode.FILE_NOT_FOUND: 400,
BizCode.APP_NOT_FOUND: 400,
BizCode.RELEASE_NOT_FOUND: 400,
BizCode.DUPLICATE_NAME: 409,
BizCode.RESOURCE_ALREADY_EXISTS: 409,
BizCode.VERSION_ALREADY_EXISTS: 409,
BizCode.STATE_CONFLICT: 409,
BizCode.PUBLISH_FAILED: 500,
BizCode.NO_DRAFT_TO_PUBLISH: 400,
BizCode.ROLLBACK_TARGET_NOT_FOUND: 404,
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,
BizCode.APP_TYPE_NOT_SUPPORTED: 400,
BizCode.AGENT_CONFIG_MISSING: 400,
BizCode.SHARE_DISABLED: 403,

View File

@@ -425,15 +425,9 @@ async def Input_Summary(
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
search_service = get_context_resource(ctx, "search_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages) or ""
sessionid = sessionid.replace('call_id_', '')
@@ -539,31 +533,11 @@ async def Input_Summary(
)
retrieve_info, question, raw_results = "", query, []
# Return retrieved information directly without LLM processing
# Use the raw retrieved info as the answer
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
# Render template
system_prompt = await template_service.render_template(
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='input_summary',
query=query,
history=history,
retrieve_info=retrieve_info
)
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=RetrieveSummaryResponse
)
aimessages = structured.data.query_answer or "信息不足,无法回答"
except Exception as e:
logger.error(
f"Input_Summary: response_structured failed, using default answer: {e}",
exc_info=True
)
aimessages = "信息不足,无法回答"
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
# Emit intermediate output for frontend
return {

View File

@@ -10,9 +10,6 @@ from app.core.logging_config import get_business_logger
logger = get_business_logger()
# 为了兼容性,创建别名
# SchemaParser = OpenAPISchemaParser = None
class OpenAPISchemaParser:
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
@@ -213,7 +210,9 @@ class OpenAPISchemaParser:
if not isinstance(operation, dict):
continue
summary = operation.get("summary", "")
# 生成操作ID
operation_id = operation.get("operationId")
if not operation_id:
@@ -223,7 +222,7 @@ class OpenAPISchemaParser:
operations[operation_id] = {
"method": method.upper(),
"path": path,
"summary": operation.get("summary", ""),
"summary": summary if summary else operation_id,
"description": operation.get("description", ""),
"parameters": self._extract_parameters(operation),
"request_body": self._extract_request_body(operation),

View File

@@ -232,7 +232,7 @@ class LangchainAdapter:
# 添加验证约束
if param.enum:
# 枚举值约束
field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$"
field_kwargs["pattern"] = f"^({'|'.join(map(str, param.enum))})$"
if param.minimum is not None:
field_kwargs["ge"] = param.minimum
@@ -241,7 +241,7 @@ class LangchainAdapter:
field_kwargs["le"] = param.maximum
if param.pattern:
field_kwargs["regex"] = param.pattern
field_kwargs["pattern"] = param.pattern
fields[param.name] = Field(**field_kwargs)
annotations[param.name] = python_type

View File

@@ -27,20 +27,22 @@ class SimpleMCPClient:
# 确定连接类型
self.is_websocket = server_url.startswith(("ws://", "wss://"))
self.is_sse = "/sse" in server_url.lower()
# 连接状态
self._websocket = None
self._session = None
self._request_id = 0
self._pending_requests = {}
self._server_capabilities = {}
self._endpoint_url = None # SSE endpoint URL
self._sse_task = None
async def __aenter__(self):
"""异步上下文管理器入口"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.disconnect()
async def connect(self):
@@ -57,47 +59,154 @@ class SimpleMCPClient:
async def disconnect(self):
"""断开连接"""
try:
if self._sse_task:
self._sse_task.cancel()
if self._websocket:
await self._websocket.close()
self._websocket = None
if self._session:
await self._session.close()
self._session = None
except Exception as e:
logger.error(f"断开连接失败: {e}")
async def _connect_websocket(self):
"""WebSocket 连接"""
headers = self._build_headers()
self._websocket = await websockets.connect(
self.server_url,
extra_headers=headers,
timeout=self.timeout
)
# 启动消息处理
asyncio.create_task(self._handle_websocket_messages())
# 发送初始化消息
await self._send_initialize()
async def _connect_http(self):
"""HTTP 连接"""
headers = self._build_headers()
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(headers=headers, timeout=timeout)
self._session = aiohttp.ClientSession(
headers=headers,
timeout=timeout
)
# 对于 ModelScope MCP 服务,需要先发送初始化请求
if "modelscope.net" in self.server_url:
if self.is_sse:
await self._initialize_sse_session()
elif "modelscope.net" in self.server_url:
await self._initialize_modelscope_session()
async def _initialize_sse_session(self):
"""初始化 SSE MCP 会话 - 参考 Dify 实现"""
try:
# 建立 SSE 连接
response = await self._session.get(self.server_url)
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
# 启动 SSE 读取任务
self._sse_task = asyncio.create_task(self._read_sse_stream(response))
# 等待获取 endpoint URL
for _ in range(10):
if self._endpoint_url:
break
await asyncio.sleep(1)
if not self._endpoint_url:
raise MCPConnectionError("未能获取 endpoint URL")
# 发送 initialize 请求到 endpoint
init_request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
}
}
init_response = await self._send_sse_request(init_request)
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
result = init_response.get("result", {})
self._server_capabilities = result.get("capabilities", {})
# 发送 initialized 通知
await self._send_sse_notification({"jsonrpc": "2.0", "method": "notifications/initialized"})
except aiohttp.ClientError as e:
raise MCPConnectionError(f"初始化连接失败: {e}")
async def _read_sse_stream(self, response):
"""读取 SSE 流"""
try:
async for line in response.content:
line = line.decode('utf-8').strip()
if line.startswith('event:'):
continue
if line.startswith('data:'):
data = line[5:].strip() # 去除 'data:' 后的空格
if not data or data == '[DONE]':
continue
try:
# 处理 endpoint 事件(相对路径或绝对路径)
if not self._endpoint_url:
# 如果是相对路径,拼接成完整 URL
if data.startswith('/'):
from urllib.parse import urlparse, urlunparse
parsed = urlparse(self.server_url)
self._endpoint_url = f"{parsed.scheme}://{parsed.netloc}{data}"
else:
self._endpoint_url = data
logger.info(f"获取到 endpoint URL: {self._endpoint_url}")
continue
# 处理 message 事件
message = json.loads(data)
request_id = message.get("id")
if request_id and request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(message)
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"SSE 流读取错误: {e}")
async def _send_sse_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""通过 SSE endpoint 发送请求"""
if not self._endpoint_url:
raise MCPConnectionError("endpoint URL 未初始化")
request_id = request["id"]
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
async with self._session.post(self._endpoint_url, json=request) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
return await asyncio.wait_for(future, timeout=self.timeout)
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise MCPConnectionError("请求超时")
async def _send_sse_notification(self, notification: Dict[str, Any]):
"""发送通知(无需响应)"""
if not self._endpoint_url:
raise MCPConnectionError("endpoint URL 未初始化")
async with self._session.post(self._endpoint_url, json=notification) as response:
if response.status != 200:
logger.warning(f"通知发送失败: {response.status}")
async def _initialize_modelscope_session(self):
"""初始化 ModelScope MCP 会话"""
init_request = {
@@ -107,18 +216,12 @@ class SimpleMCPClient:
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {
"name": "MemoryBear",
"version": "1.0.0"
}
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
}
}
try:
async with self._session.post(
self.server_url,
json=init_request
) as response:
async with self._session.post(self.server_url, json=init_request) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
@@ -127,21 +230,16 @@ class SimpleMCPClient:
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
# 获取 session ID
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
if session_id:
self._session.headers.update({"Mcp-Session-Id": session_id})
# 发送 initialized 通知
initialized_notification = {
"jsonrpc": "2.0",
"method": "notifications/initialized"
}
async with self._session.post(
self.server_url,
json=initialized_notification
) as notif_response:
async with self._session.post(self.server_url, json=initialized_notification):
pass
except aiohttp.ClientError as e:
@@ -149,12 +247,18 @@ class SimpleMCPClient:
def _build_headers(self) -> Dict[str, str]:
"""构建请求头"""
# 基础 headers
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream"
}
# 添加认证头
# 合并 connection_config 中的自定义 headers
custom_headers = self.connection_config.get("headers", {})
if custom_headers:
headers.update(custom_headers)
# 处理认证配置(认证 headers 优先级更高)
auth_config = self.connection_config.get("auth_config", {})
auth_type = self.connection_config.get("auth_type", "none")
@@ -178,7 +282,7 @@ class SimpleMCPClient:
return headers
async def _send_initialize(self):
"""发送初始化消息"""
"""发送初始化消息WebSocket"""
init_message = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
@@ -186,124 +290,90 @@ class SimpleMCPClient:
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {
"name": "MemoryBear",
"version": "1.0.0"
}
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
}
}
await self._websocket.send(json.dumps(init_message))
response = await self._websocket.recv()
response_data = json.loads(response)
# 等待初始化响应
response = await asyncio.wait_for(
self._websocket.recv(),
timeout=self.timeout
)
if "error" in response_data:
raise MCPConnectionError(f"初始化失败: {response_data['error']}")
init_response = json.loads(response)
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
result = response_data.get("result", {})
self._server_capabilities = result.get("capabilities", {})
await self._websocket.send(json.dumps({
"jsonrpc": "2.0",
"method": "notifications/initialized"
}))
async def list_tools(self) -> List[Dict[str, Any]]:
"""获取工具列表"""
request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/list"
}
if self.is_websocket:
await self._websocket.send(json.dumps(request))
response = await self._websocket.recv()
response_data = json.loads(response)
elif self.is_sse:
response_data = await self._send_sse_request(request)
else:
async with self._session.post(self.server_url, json=request) as response:
response_data = await response.json()
if "error" in response_data:
raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}")
result = response_data.get("result", {})
return result.get("tools", [])
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""调用工具"""
request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/call",
"params": {"name": tool_name, "arguments": arguments}
}
if self.is_websocket:
await self._websocket.send(json.dumps(request))
response = await self._websocket.recv()
response_data = json.loads(response)
elif self.is_sse:
response_data = await self._send_sse_request(request)
else:
async with self._session.post(self.server_url, json=request) as response:
response_data = await response.json()
if "error" in response_data:
error = response_data["error"]
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
return response_data.get("result", {})
def _get_request_id(self) -> int:
"""生成请求 ID"""
self._request_id += 1
return self._request_id
async def _handle_websocket_messages(self):
"""处理 WebSocket 消息"""
try:
while self._websocket and not self._websocket.closed:
try:
message = await self._websocket.recv()
data = json.loads(message)
# 处理响应
if "id" in data:
request_id = str(data["id"])
if request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(data)
except ConnectionClosed:
break
except Exception as e:
logger.error(f"处理WebSocket消息失败: {e}")
async for message in self._websocket:
data = json.loads(message)
request_id = data.get("id")
if request_id and request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(data)
except ConnectionClosed:
logger.info("WebSocket 连接已关闭")
except Exception as e:
logger.error(f"WebSocket消息处理异常: {e}")
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""调用工具"""
request_data = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": arguments
}
}
if self.is_websocket:
response = await self._send_websocket_request(request_data)
else:
response = await self._send_http_request(request_data)
if "error" in response:
error = response["error"]
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
return response.get("result", {})
async def list_tools(self) -> List[Dict[str, Any]]:
"""获取工具列表"""
request_data = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/list",
"params": {}
}
if self.is_websocket:
response = await self._send_websocket_request(request_data)
else:
response = await self._send_http_request(request_data)
if "error" in response:
error = response["error"]
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
result = response.get("result", {})
return result.get("tools", [])
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""发送WebSocket请求"""
request_id = str(request_data["id"])
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
await self._websocket.send(json.dumps(request_data))
response = await asyncio.wait_for(future, timeout=self.timeout)
return response
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""发送HTTP请求"""
try:
async with self._session.post(
self.server_url,
json=request_data
) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
return await response.json()
except aiohttp.ClientError as e:
raise MCPConnectionError(f"HTTP请求失败: {e}")
def _get_request_id(self) -> str:
"""获取请求ID"""
self._request_id += 1
return f"req_{self._request_id}_{int(time.time() * 1000)}"
logger.error(f"WebSocket 消息处理错误: {e}")

View File

@@ -74,6 +74,7 @@ class WorkflowExecutor:
初始化的工作流状态
"""
user_message = input_data.get("message") or ""
conversation_messages = input_data.get("conv_messages") or []
# 会话变量处理从配置文件获取变量定义列表转换为字典name -> default value
config_variables_list = self.workflow_config.get("variables") or []
@@ -114,7 +115,7 @@ class WorkflowExecutor:
}
return {
"messages": [('user', user_message)],
"messages": conversation_messages,
"variables": variables,
"node_outputs": {},
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)

View File

@@ -7,13 +7,13 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from operator import add
from typing import Any
from langchain_core.messages import AnyMessage, AIMessage
from langchain_core.messages import AIMessage
from langgraph.config import get_stream_writer
from typing_extensions import TypedDict, Annotated
from app.core.config import settings
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
"""
# List of messages (append mode)
messages: Annotated[list[tuple[str, str]], add]
messages: list[dict[str, str]]
# Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list
@@ -154,7 +154,7 @@ class BaseNode(ABC):
Returns:
超时时间
"""
return 60
return settings.WORKFLOW_NODE_TIMEOUT
# return self.error_handling.get("timeout", 60)
async def run(self, state: WorkflowState) -> dict[str, Any]:
@@ -203,6 +203,7 @@ class BaseNode(ABC):
# 返回包装后的输出和运行时变量
return {
**wrapped_output,
"messages": state["messages"],
"variables": state["variables"],
"runtime_vars": {
self.node_id: runtime_var
@@ -356,6 +357,7 @@ class BaseNode(ABC):
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = {
**final_output,
"messages": state["messages"],
"variables": state["variables"],
"runtime_vars": {
self.node_id: runtime_var

View File

@@ -6,7 +6,6 @@ End 节点实现
import logging
import re
import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import NodeType
@@ -38,7 +37,23 @@ class EndNode(BaseNode):
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
else:
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
])
output = "工作流已完成"
# 统计信息(用于日志)
@@ -166,6 +181,12 @@ class EndNode(BaseNode):
"chunk_index": 1,
"is_suffix": False
})
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
}
])
yield {"__final__": True, "result": output}
return
@@ -176,7 +197,6 @@ class EndNode(BaseNode):
source_node_id = edge.get("source")
# Check if the source node is an LLM node
for node in self.workflow_config.get("nodes", []):
print("="*50)
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
direct_upstream_llm_nodes.append(source_node_id)
@@ -216,12 +236,24 @@ class EndNode(BaseNode):
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
# yield completion marker
yield {"__final__": True, "result": output}
return
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
logger.info(
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
# Collect suffix parts
suffix_parts = []
@@ -258,6 +290,17 @@ class EndNode(BaseNode):
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": full_output
}
])
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
@@ -280,7 +323,8 @@ class EndNode(BaseNode):
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else:
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
# 统计信息
node_outputs = state.get("node_outputs", {})

View File

@@ -11,12 +11,12 @@ class MessageConfig(BaseModel):
"""消息配置"""
role: str = Field(
...,
default='user',
description="消息角色system, user, assistant"
)
content: str = Field(
...,
default="",
description="消息内容,支持模板变量,如:{{ sys.message }}"
)
@@ -30,6 +30,23 @@ class MessageConfig(BaseModel):
return v.lower()
class MemoryWindowSetting(BaseModel):
enable: bool = Field(
default=False,
description="启用记忆"
)
enable_window: bool = Field(
default=False,
description="启用记忆窗口"
)
window_size: int = Field(
default=20,
description="记忆窗口大小"
)
class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置
@@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig):
description="上下文"
)
memory: MemoryWindowSetting = Field(
...,
description="对话上下文窗口"
)
# 简单模式
prompt: str | None = Field(
default=None,

View File

@@ -85,28 +85,31 @@ class LLMNode(BaseNode):
"""
# 1. 处理消息格式(优先使用 messages
messages_config = self.config.get("messages")
messages_config = self.typed_config.messages
if messages_config:
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.get("role", "user").lower()
content_template = msg_config.get("content", "")
role = msg_config.role.lower()
content_template = msg_config.content
content_template = self._render_context(content_template, state)
content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象
if role == "system":
messages.append(SystemMessage(content=content))
messages.append({"role": "system", "content": content})
elif role in ["user", "human"]:
messages.append(HumanMessage(content=content))
messages.append({"role": "user", "content": content})
elif role in ["ai", "assistant"]:
messages.append(AIMessage(content=content))
messages.append({"role": "assistant", "content": content})
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append(HumanMessage(content=content))
messages.append({"role": "user", "content": content})
if self.typed_config.memory.enable:
# if self.typed_config.memory.enable_window:
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
prompt_or_messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
@@ -189,7 +192,7 @@ class LLMNode(BaseNode):
return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"messages": [
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
{"role": msg.get("role"), "content": msg.get("content", "")}
for msg in prompt_or_messages
] if isinstance(prompt_or_messages, list) else None,
"config": {

View File

@@ -3,8 +3,9 @@ from typing import Any
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.db import get_db_read, get_db_context
from app.db import get_db_read
from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task
class MemoryReadNode(BaseNode):
@@ -15,11 +16,8 @@ class MemoryReadNode(BaseNode):
async def execute(self, state: WorkflowState) -> Any:
self.typed_config = MemoryReadNodeConfig(**self.config)
with get_db_read() as db:
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state)
if not workspace_id:
raise RuntimeError("Workspace id is required")
if not end_user_id:
raise RuntimeError("End user id is required")
@@ -41,20 +39,17 @@ class MemoryWriteNode(BaseNode):
self.typed_config = MemoryWriteNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any:
with get_db_context() as db:
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state)
end_user_id = self.get_variable("sys.user_id", state)
if not workspace_id:
raise RuntimeError("Workspace id is required")
if not end_user_id:
raise RuntimeError("End user id is required")
if not end_user_id:
raise RuntimeError("End user id is required")
return await MemoryAgentService().write_memory(
group_id=end_user_id,
message=self._render_template(self.typed_config.message, state),
config_id=str(self.typed_config.config_id),
db=db,
storage_type="neo4j",
user_rag_memory_id=""
)
write_message_task.delay(
end_user_id,
self._render_template(self.typed_config.message, state),
str(self.typed_config.config_id),
"neo4j",
""
)
return "success"

View File

@@ -41,6 +41,7 @@ class ToolConfig(BaseModel):
tool_id: Optional[str] = Field(default=None, description="工具ID")
operation: Optional[str] = Field(default=None, description="工具特定配置")
class ToolOldConfig(BaseModel):
"""工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具")
@@ -348,6 +349,7 @@ class AppChatRequest(BaseModel):
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
stream: bool = Field(default=False, description="是否流式返回")
class DraftRunRequest(BaseModel):
"""试运行请求"""
message: str = Field(..., description="用户消息")

View File

@@ -1,9 +1,51 @@
"""
情景记忆的请求和响应模型
"""
from abc import ABC
from pydantic import BaseModel, Field
from typing import Optional
type_mapping = {
"Person": "人物实体节点",
"Organization": "组织实体节点",
"ORG": "组织实体节点",
"Location": "地点实体节点",
"LOC": "地点实体节点",
"Event": "事件实体节点",
"Concept": "概念实体节点",
"Time": "时间实体节点",
"Position": "职位实体节点",
"WorkRole": "职业实体节点",
"System": "系统实体节点",
"Policy": "政策实体节点",
"HistoricalPeriod": "历史时期实体节点",
"HistoricalState": "历史国家实体节点",
"HistoricalEvent": "历史事件实体节点",
"EconomicFactor": "经济因素实体节点",
"Condition": "条件实体节点",
"Numeric": "数值实体节点"
}
class EmotionType(ABC):
JOY_TYPE = "joy"
SURPRISE_TYPE = "surprise"
SANDROWNESS_TYPE = "sadness"
FEAR_TYPE = "fear"
ANGET_TYPE="anger"
NEUTRAL_TYPE="neutral"
EMOTION_MAPPING={
"joy":"愉快",
"surprise":"惊喜",
"sadness":"悲伤",
"fear":"恐惧",
"anger":"生气",
"neutral":"中性"
}
class EmotionSubject(ABC):
SUBJECT_MAPPING={
"self":"自己",
"other":"别人",
"object":"事物对象"
}
class EpisodicMemoryOverviewRequest(BaseModel):
"""情景记忆总览查询请求"""

View File

@@ -14,6 +14,7 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.db import get_db, get_db_context
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
from app.schemas import DraftRunRequest
from app.services.tool_service import ToolService
from app.repositories.tool_repository import ToolRepository
from app.db import get_db
@@ -59,7 +60,7 @@ class AppChatService:
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.system_prompt
if variables:
@@ -210,7 +211,7 @@ class AppChatService:
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.system_prompt
if variables:
@@ -511,7 +512,6 @@ class AppChatService:
}
)
except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出
logger.debug("多 Agent 流式聊天被中断")
@@ -537,83 +537,19 @@ class AppChatService:
) -> Dict[str, Any]:
"""聊天(非流式)"""
workflow_service = WorkflowService(self.db)
input_data = {"message":message, "variables": variables,
"conversation_id": str(conversation_id)}
inconfig = workflow_service.get_workflow_config(app_id)
# 2. 创建执行记录
execution = workflow_service.create_execution(
workflow_config_id=inconfig.id,
app_id=app_id,
trigger_type="manual",
triggered_by=None,
conversation_id=conversation_id,
input_data=input_data
payload = DraftRunRequest(
message=message,
variables=variables,
conversation_id=str(conversation_id),
stream=True,
user_id=user_id
)
return await workflow_service.run(
app_id=app_id,
payload=payload,
config=config,
workspace_id=workspace_id,
)
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow
try:
# 更新状态为运行中
workflow_service.update_execution_status(execution.execution_id, "running")
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=user_id
)
# 更新执行结果
if result.get("status") == "completed":
workflow_service.update_execution_status(
execution.execution_id,
"completed",
output_data=result.get("node_outputs", {})
)
else:
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=result.get("error")
)
# 返回增强的响应结构
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"),
"token_usage": result.get("token_usage")
}
except Exception as e:
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
raise BusinessException(
code=BizCode.INTERNAL_ERROR,
message=f"工作流执行失败: {str(e)}"
)
async def workflow_chat_stream(
self,
@@ -622,7 +558,7 @@ class AppChatService:
config: WorkflowConfig,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
user_id: Optional[str] = None,
user_id: str = None,
variables: Optional[Dict[str, Any]] = None,
web_search: bool = False,
memory: bool = True,
@@ -632,62 +568,21 @@ class AppChatService:
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
workflow_service = WorkflowService(self.db)
input_data = {"message": message, "variables": variables,
"conversation_id": str(conversation_id)}
inconfig = workflow_service.get_workflow_config(app_id)
# 2. 创建执行记录
execution = workflow_service.create_execution(
workflow_config_id=inconfig.id,
app_id=app_id,
trigger_type="manual",
triggered_by=None,
conversation_id=conversation_id,
input_data=input_data
payload = DraftRunRequest(
message=message,
variables=variables,
conversation_id=str(conversation_id),
stream=True,
user_id=user_id
)
async for event in workflow_service.run_stream(
app_id=app_id,
payload=payload,
config=config,
workspace_id=workspace_id,
):
yield event
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
# 5. 流式执行工作流
try:
# 更新状态为运行中
workflow_service.update_execution_status(execution.execution_id, "running")
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件)
async for event in workflow_service._run_workflow_stream(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=user_id
):
# 直接转发 executor 的事件(已经是正确的格式)
yield event
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
# 发送错误事件
yield {
"event": "error",
"data": {
"execution_id": execution.execution_id,
"error": str(e)
}
}
# ==================== 依赖注入函数 ====================

View File

@@ -516,8 +516,16 @@ class ConversationService:
conversation_messages = self.get_conversation_history(
conversation_id=conversation_id,
max_history=30
max_history=20
)
if len(conversation_messages) == 0:
return ConversationOut(
theme="",
question=[],
summary="",
takeaways=[],
info_score=0,
)
with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f:
system_prompt = f.read()
@@ -536,6 +544,7 @@ class ConversationService:
]
logger.info(f"Invoking LLM for conversation_id={conversation_id}")
model_resp = await llm.ainvoke(messages)
try:
if isinstance(model_resp.content, str):
result = json_repair.repair_json(model_resp.content, return_objects=True)

View File

@@ -245,7 +245,8 @@ class DraftRunService:
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
web_search: bool = True,
memory: bool = True
memory: bool = True,
sub_agent: bool = False
) -> Dict[str, Any]:
"""执行试运行(使用 LangChain Agent
@@ -435,7 +436,7 @@ class DraftRunService:
elapsed_time = time.time() - start_time
# 8. 保存会话消息
if agent_config.memory and agent_config.memory.get("enabled"):
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
await self._save_conversation_message(
conversation_id=conversation_id,
user_message=message,

View File

@@ -9,7 +9,7 @@ import os
import re
import time
import uuid
from threading import Lock
from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
@@ -51,9 +51,7 @@ _neo4j_connector = Neo4jConnector()
class MemoryAgentService:
"""Service for memory agent operations"""
def __init__(self):
self.user_locks: Dict[str, Lock] = {}
self.locks_lock = Lock()
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
@@ -83,12 +81,7 @@ class MemoryAgentService:
raise ValueError(f"写入失败: {messages}")
def get_group_lock(self, group_id: str) -> Lock:
"""Get lock for specific group to prevent concurrent processing"""
with self.locks_lock:
if group_id not in self.user_locks:
self.user_locks[group_id] = Lock()
return self.user_locks[group_id]
def extract_tool_call_info(self, event: Dict) -> bool:
"""Extract tool call information from event"""
@@ -417,241 +410,236 @@ class MemoryAgentService:
except ImportError:
audit_logger = None
# Get group lock to prevent concurrent processing
group_lock = self.get_group_lock(group_id)
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
with group_lock:
# Step 1: Load configuration from database only
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
service_name="MemoryAgentService"
group_id=group_id,
success=False,
duration=duration,
error=error_msg
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=False,
duration=duration,
error=error_msg
)
raise ValueError(error_msg)
raise ValueError(error_msg)
# Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 3: Initialize MCP client and execute read workflow
mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config)
# Step 3: Initialize MCP client and execute read workflow
mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config)
async with client.session('data_flow') as session:
session_start = time.time()
logger.debug("Connected to MCP Server: data_flow")
async with client.session('data_flow') as session:
session_start = time.time()
logger.debug("Connected to MCP Server: data_flow")
tools_start = time.time()
tools = await load_mcp_tools(session)
tools_time = time.time() - tools_start
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
outputs = []
intermediate_outputs = []
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
tools_start = time.time()
tools = await load_mcp_tools(session)
tools_time = time.time() - tools_start
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
# Pass memory_config to the graph workflow
graph_start = time.time()
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
graph_init_time = time.time() - graph_start
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
start = time.time()
config = {"configurable": {"thread_id": group_id}}
workflow_errors = [] # Track errors from workflow
event_count = 0
async for event in graph.astream(
{"messages": history, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
event_count += 1
event_start = time.time()
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
outputs = []
intermediate_outputs = []
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
for msg in messages:
msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "")
outputs.append({
"role": msg_role,
"content": msg_content
})
# Pass memory_config to the graph workflow
graph_start = time.time()
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
graph_init_time = time.time() - graph_start
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
# Extract intermediate outputs
if hasattr(msg, 'content'):
try:
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
content_to_parse = msg_content
if isinstance(msg_content, list):
for block in msg_content:
if isinstance(block, dict) and block.get('type') == 'text':
content_to_parse = block.get('text', '')
break
else:
continue # No text block found
start = time.time()
config = {"configurable": {"thread_id": group_id}}
workflow_errors = [] # Track errors from workflow
# Try to parse content as JSON
if isinstance(content_to_parse, str):
try:
parsed = json.loads(content_to_parse)
if isinstance(parsed, dict):
# Check for single intermediate output
if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate']
event_count = 0
async for event in graph.astream(
{"messages": history, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
event_count += 1
event_start = time.time()
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
for msg in messages:
msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "")
outputs.append({
"role": msg_role,
"content": msg_content
})
# Extract intermediate outputs
if hasattr(msg, 'content'):
try:
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
content_to_parse = msg_content
if isinstance(msg_content, list):
for block in msg_content:
if isinstance(block, dict) and block.get('type') == 'text':
content_to_parse = block.get('text', '')
break
else:
continue # No text block found
# Try to parse content as JSON
if isinstance(content_to_parse, str):
try:
parsed = json.loads(content_to_parse)
if isinstance(parsed, dict):
# Check for single intermediate output
if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed:
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except (json.JSONDecodeError, ValueError):
pass
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed:
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
event_time = time.time() - event_start
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except (json.JSONDecodeError, ValueError):
pass
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
event_time = time.time() - event_start
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
workflow_duration = time.time() - start
session_duration = time.time() - session_start
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
logger.info(f"[PERF] Total events processed: {event_count}")
# Extract final answer
final_answer = ""
for messages in outputs:
if messages['role'] == 'tool':
message = messages['content']
workflow_duration = time.time() - start
session_duration = time.time() - session_start
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
logger.info(f"[PERF] Total events processed: {event_count}")
# Extract final answer
final_answer = ""
for messages in outputs:
if messages['role'] == 'tool':
message = messages['content']
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(message, list):
# Extract text from MCP content blocks
for block in message:
if isinstance(block, dict) and block.get('type') == 'text':
message = block.get('text', '')
break
else:
continue # No text block found
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(message, list):
# Extract text from MCP content blocks
for block in message:
if isinstance(block, dict) and block.get('type') == 'text':
message = block.get('text', '')
break
else:
continue # No text block found
try:
parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict):
if parsed.get('status') == 'success':
summary_result = parsed.get('summary_result')
if summary_result:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
try:
parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict):
if parsed.get('status') == 'success':
summary_result = parsed.get('summary_result')
if summary_result:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
# 记录成功的操作
total_duration = time.time() - start_time
# 记录成功的操作
total_duration = time.time() - start_time
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
if audit_logger:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=False,
duration=total_duration,
error=error_details,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer),
"errors": workflow_errors
}
)
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
if audit_logger:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=True,
success=False,
duration=total_duration,
error=error_details,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer)
"has_answer": bool(final_answer),
"errors": workflow_errors
}
)
retrieved_content=[]
repo = ShortTermMemoryRepository(db)
if str(search_switch)!="2":
for intermediate in intermediate_outputs:
print(intermediate)
intermediate_type=intermediate['type']
if intermediate_type=="search_result":
query=intermediate['query']
raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
try:
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception:
statements=[]
statements=list(set(statements))
retrieved_content.append({query:statements})
if retrieved_content==[]:
retrieved_content=''
if '信息不足,无法回答。' != str(final_answer) :#and retrieved_content!=[]
# 使用 upsert 方法
repo.upsert(
end_user_id=end_user_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=True,
duration=total_duration,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer)
}
)
retrieved_content=[]
repo = ShortTermMemoryRepository(db)
if str(search_switch)!="2":
for intermediate in intermediate_outputs:
print(intermediate)
intermediate_type=intermediate['type']
if intermediate_type=="search_result":
query=intermediate['query']
raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
try:
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception:
statements=[]
statements=list(set(statements))
retrieved_content.append({query:statements})
if retrieved_content==[]:
retrieved_content=''
if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[]
# 使用 upsert 方法
repo.upsert(
end_user_id=end_user_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
return {
"answer": final_answer,
"intermediate_outputs": intermediate_outputs
}
return {
"answer": final_answer,
"intermediate_outputs": intermediate_outputs
}
def _create_intermediate_key(self, output: Dict) -> str:
"""
Create a unique key for an intermediate output to detect duplicates.

View File

@@ -15,6 +15,8 @@ from neo4j.time import DateTime as Neo4jDateTime
import json
from datetime import datetime
from app.schemas.memory_episodic_schema import EmotionType
logger = logging.getLogger(__name__)
class MemoryEntityService:
@@ -123,7 +125,7 @@ class MemoryEntityService:
extracted_entity_list = self._deduplicate_dict_list(extracted_entity_list)
# 合并所有数据并处理相同text的合并
all_timeline_data = memory_summary_list + statement_list + extracted_entity_list
all_timeline_data = memory_summary_list + statement_list
all_timeline_data = self._merge_same_text_items(all_timeline_data)
result = {
@@ -496,11 +498,11 @@ class MemoryEmotion:
length_data.append(emotion_intensity)
if emotion_type is not None and emotion_intensity is not None and formatted_created_at is not None:
# 使用(emotion_type, created_at)作为分组键
if emotion_type in {"joy", "surprise"}:
if emotion_type in {EmotionType.JOY_TYPE, EmotionType.SURPRISE_TYPE}:
emotion_type='positive'
elif emotion_type in {"sadness", "fear", "anger"}:
elif emotion_type in {EmotionType.SANDROWNESS_TYPE, EmotionType.FEAR_TYPE, EmotionType.ANGET_TYPE}:
emotion_type='negative'
elif emotion_type=='neutral':
elif emotion_type==EmotionType.NEUTRAL_TYPE:
emotion_type='neutral'
group_key = (emotion_type, formatted_created_at)
# 累加emotion_intensity
@@ -595,7 +597,7 @@ class MemoryInteraction:
group_id = ori_data[0]['group_id']
Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id)
if not Space_User:
return '不存在用户'
return []
user_id=Space_User[0]['id']
results = await self.connector.execute_query(Memory_Space_Associative, id=self.id,user_id=user_id)

View File

@@ -267,14 +267,14 @@ class MemoryForgetService:
elif node_type_label == 'memorysummary':
node_type_label = 'summary'
# 将 Neo4j DateTime 对象转换为时间戳
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
last_access_time = result['last_access_time']
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
# 确保 datetime 带有时区信息(假定为 UTC),避免 naive datetime 导致的时区偏差
if last_access_dt:
if last_access_dt.tzinfo is None:
last_access_dt = last_access_dt.replace(tzinfo=timezone.utc)
last_access_timestamp = int(last_access_dt.timestamp())
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
else:
last_access_timestamp = 0
@@ -520,7 +520,7 @@ class MemoryForgetService:
'average_activation_value': result['average_activation'],
'low_activation_nodes': result['low_activation_nodes'] or 0,
'forgetting_threshold': forgetting_threshold,
'timestamp': int(datetime.now().timestamp())
'timestamp': int(datetime.now().timestamp() * 1000)
}
else:
activation_metrics = {
@@ -530,7 +530,7 @@ class MemoryForgetService:
'average_activation_value': None,
'low_activation_nodes': 0,
'forgetting_threshold': forgetting_threshold,
'timestamp': int(datetime.now().timestamp())
'timestamp': int(datetime.now().timestamp() * 1000)
}
# 收集节点类型分布
@@ -620,7 +620,7 @@ class MemoryForgetService:
'merged_count': record.merged_count,
'average_activation': record.average_activation_value,
'total_nodes': record.total_nodes,
'execution_time': int(record.execution_time.timestamp())
'execution_time': int(record.execution_time.timestamp() * 1000)
})
api_logger.info(f"成功获取最近 {len(recent_trends)} 个日期的历史趋势数据")
@@ -661,7 +661,7 @@ class MemoryForgetService:
'node_distribution': node_distribution,
'recent_trends': recent_trends,
'pending_nodes': pending_nodes,
'timestamp': int(datetime.now().timestamp())
'timestamp': int(datetime.now().timestamp() * 1000)
}
api_logger.info(

View File

@@ -1327,7 +1327,8 @@ class MultiAgentOrchestrator:
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
sub_agent=True
)
return result

View File

@@ -13,10 +13,15 @@ from typing import Any, Dict, List, Optional, Tuple
from app.core.logging_config import get_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.conversation_repository import ConversationRepository
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService
from app.services.memory_base_service import MemoryBaseService
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
@@ -1196,18 +1201,17 @@ async def analytics_memory_types(
end_user_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
统计9种记忆类型的数量和百分比
统计8种记忆类型的数量和百分比
计算规则:
1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity
2. 工作记忆 (WORKING_MEMORY) = chunk + entity
3. 短期记忆 (SHORT_TERM_MEMORY) = chunk
4. 长期记忆 (LONG_TERM_MEMORY) = entity
5. 性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity
7. 情记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
9. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
1. 感知记忆 (PERCEPTUAL_MEMORY) = 通过 MemoryPerceptualService.get_memory_count 获取的 total_count
2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取)
3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量
4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
5. 性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一
6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
7. 情记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
Args:
db: 数据库会话
@@ -1227,7 +1231,6 @@ async def analytics_memory_types(
- PERCEPTUAL_MEMORY: 感知记忆
- WORKING_MEMORY: 工作记忆
- SHORT_TERM_MEMORY: 短期记忆
- LONG_TERM_MEMORY: 长期记忆
- EXPLICIT_MEMORY: 显性记忆
- IMPLICIT_MEMORY: 隐性记忆
- EMOTIONAL_MEMORY: 情绪记忆
@@ -1237,40 +1240,78 @@ async def analytics_memory_types(
# 初始化基础服务
base_service = MemoryBaseService()
# 定义需要查询的基础节点类型
node_types = {
"Statement": "Statement",
"Entity": "ExtractedEntity",
"Chunk": "Chunk"
}
# 初始化感知记忆服务
perceptual_service = MemoryPerceptualService(db)
# 存储每种节点类型的计数
node_counts = {}
# 获取感知记忆数量
if end_user_id:
perceptual_stats = perceptual_service.get_memory_count(uuid.UUID(end_user_id))
perceptual_count = perceptual_stats.get("total", 0)
else:
perceptual_count = 0
# 查询每种节点类型的数量
for key, node_type in node_types.items():
if end_user_id:
query = f"""
MATCH (n:{node_type})
# 获取工作记忆数量(基于会话数量
work_count = 0
if end_user_id:
try:
conversation_repo = ConversationRepository(db)
conversations = conversation_repo.get_conversation_by_user_id(
user_id=uuid.UUID(end_user_id),
limit=100, # 获取更多会话以准确统计
is_activate=True
)
work_count = len(conversations)
logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取会话数量失败工作记忆数量设为0: {str(e)}")
work_count = 0
# 获取隐性记忆数量(基于 Statement 节点数量的三分之一)
implicit_count = 0
if end_user_id:
try:
# 查询 Statement 节点数量
query = """
MATCH (n:Statement)
WHERE n.group_id = $group_id
RETURN count(n) as count
"""
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
else:
query = f"""
MATCH (n:{node_type})
RETURN count(n) as count
"""
result = await _neo4j_connector.execute_query(query)
# 提取计数结果
count = result[0]["count"] if result and len(result) > 0 else 0
node_counts[key] = count
statement_count = result[0]["count"] if result and len(result) > 0 else 0
# 取三分之一作为隐性记忆数量
implicit_count = round(statement_count / 3)
logger.debug(f"隐性记忆数量Statement数量的1/3: {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取Statement数量失败隐性记忆数量设为0: {str(e)}")
implicit_count = 0
# 获取各节点类型的数量
statement_count = node_counts.get("Statement", 0)
entity_count = node_counts.get("Entity", 0)
chunk_count = node_counts.get("Chunk", 0)
# 原有的基于行为习惯的统计方式(已注释)
# implicit_count = 0
# if end_user_id:
# try:
# implicit_service = ImplicitMemoryService(db, end_user_id)
# behavior_habits = await implicit_service.get_behavior_habits(
# user_id=end_user_id
# )
# implicit_count = len(behavior_habits)
# logger.debug(f"隐性记忆数量(行为习惯数): {implicit_count} (end_user_id={end_user_id})")
# except Exception as e:
# logger.warning(f"获取行为习惯数量失败隐性记忆数量设为0: {str(e)}")
# implicit_count = 0
# 获取短期记忆数量(基于 /short_term 接口返回的问答对数量)
short_term_count = 0
if end_user_id:
try:
short_term_service = ShortService(end_user_id)
short_term_data = short_term_service.get_short_databasets()
# 统计 short_term 数组的长度
if short_term_data:
short_term_count = len(short_term_data)
logger.debug(f"短期记忆数量(问答对数): {short_term_count} (end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"获取短期记忆数量失败短期记忆数量设为0: {str(e)}")
short_term_count = 0
# 获取用户的遗忘阈值配置
forgetting_threshold = 0.3 # 默认值
@@ -1296,17 +1337,16 @@ async def analytics_memory_types(
# 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量
episodic_count = await base_service.get_episodic_memory_count(end_user_id)
explicit_count = await base_service.get_explicit_memory_count(end_user_id)
emotion_count = await base_service.get_emotional_memory_count(end_user_id, statement_count)
emotion_count = await base_service.get_emotional_memory_count(end_user_id, perceptual_count)
forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold)
# 按规则计算9种记忆类型的数量使用英文枚举作为key
# 按规则计算8种记忆类型的数量使用英文枚举作为key
memory_counts = {
"PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆
"WORKING_MEMORY": chunk_count + entity_count, # 工作记忆
"SHORT_TERM_MEMORY": chunk_count, # 短期记忆
"LONG_TERM_MEMORY": entity_count, # 长期记忆
"PERCEPTUAL_MEMORY": perceptual_count, # 感知记忆
"WORKING_MEMORY": work_count, # 工作记忆(基于会话数量)
"SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量)
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
"IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity)
"IMPLICIT_MEMORY": implicit_count, # 隐性记忆Statement数量的1/3
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
"EPISODIC_MEMORY": episodic_count, # 情景记忆
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)
@@ -1332,7 +1372,7 @@ async def analytics_graph_data(
db: Session,
end_user_id: str,
node_types: Optional[List[str]] = None,
limit: int = 100,
limit: int = 130,
depth: int = 1,
center_node_id: Optional[str] = None
) -> Dict[str, Any]:
@@ -1416,12 +1456,14 @@ async def analytics_graph_data(
elementId(n) as id,
labels(n)[0] as label,
properties(n) as properties
LIMIT $limit
"""
node_params = {
"group_id": end_user_id,
# "limit": limit
"limit": limit
}
# 执行节点查询
node_results = await _neo4j_connector.execute_query(node_query, **node_params)
@@ -1576,10 +1618,15 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
for field in allowed_fields:
if field in properties:
value = properties[field]
if str(field) == 'entity_type':
value=type_mapping.get(value,'')
if str(field)=="emotion_type":
value=EmotionType.EMOTION_MAPPING.get(value)
if str(field)=="emotion_subject":
value=EmotionSubject.SUBJECT_MAPPING.get(value)
# 清理 Neo4j 特殊类型
filtered_props[field] = _clean_neo4j_value(value)
filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0]
print(filtered_props)
return filtered_props

View File

@@ -2,29 +2,28 @@
工作流服务层
"""
import datetime
import json
import logging
import uuid
import datetime
from typing import Any, Annotated, AsyncGenerator
from deprecated import deprecated
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.workflow.validator import validate_workflow_config
from app.db import get_db, get_db_context
from app.db import get_db
from app.models.conversation_model import Message
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories.end_user_repository import EndUserRepository
from app.services.multi_agent_service import convert_uuids_to_str
from app.repositories.conversation_repository import MessageRepository
from app.repositories.workflow_repository import (
WorkflowConfigRepository,
WorkflowExecutionRepository,
WorkflowNodeExecutionRepository
)
from app.schemas import DraftRunRequest
from app.utils.sse_utils import format_sse_message
from app.services.multi_agent_service import convert_uuids_to_str
logger = logging.getLogger(__name__)
@@ -37,6 +36,7 @@ class WorkflowService:
self.config_repo = WorkflowConfigRepository(db)
self.execution_repo = WorkflowExecutionRepository(db)
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
self.message_repo = MessageRepository(db)
# ==================== 配置管理 ====================
@@ -418,14 +418,13 @@ class WorkflowService:
"""运行工作流
Args:
workspace_id:
config:
payload:
app_id: 应用 ID
input_data: 输入数据(包含 message 和 variables
triggered_by: 触发用户 ID
conversation_id: 会话 ID可选
stream: 是否流式返回
Returns:
执行结果(非流式)或生成器(流式)
执行结果(非流式)
Raises:
BusinessException: 配置不存在或执行失败时抛出
@@ -438,7 +437,8 @@ class WorkflowService:
code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id}
# 转换 user_id 为 UUID
triggered_by_uuid = None
@@ -461,7 +461,7 @@ class WorkflowService:
workflow_config_id=config.id,
app_id=app_id,
trigger_type="manual",
triggered_by=triggered_by_uuid,
triggered_by=None,
conversation_id=conversation_id_uuid,
input_data=input_data
)
@@ -482,14 +482,6 @@ class WorkflowService:
try:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
with get_db_context() as db:
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id,
other_id=payload.user_id,
original_user_id=payload.user_id # Save original user_id to other_id
)
end_user_id = str(new_end_user.id)
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
@@ -500,14 +492,17 @@ class WorkflowService:
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
init_message_length = len(input_data.get("conv_messages", []))
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=end_user_id
user_id=payload.user_id
)
# 更新执行结果
@@ -517,6 +512,17 @@ class WorkflowService:
"completed",
output_data=result
)
final_messages = result.get("messages", [])[init_message_length:]
for message in final_messages:
message_obj = Message(
conversation_id=conversation_id_uuid,
role=message["role"],
content=message["content"],
)
self.message_repo.add_message(message_obj)
self.db.commit()
logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
else:
self.update_execution_status(
execution.execution_id,
@@ -529,6 +535,7 @@ class WorkflowService:
"execution_id": execution.execution_id,
"status": result.get("status"),
"variables": result.get("variables"),
"messages": result.get("messages"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
@@ -559,6 +566,7 @@ class WorkflowService:
"""运行工作流(流式)
Args:
workspace_id:
app_id: 应用 ID
payload: 请求对象(包含 message, variables, conversation_id 等)
config: 存储类型(可选)
@@ -601,7 +609,7 @@ class WorkflowService:
workflow_config_id=config.id,
app_id=app_id,
trigger_type="manual",
triggered_by=triggered_by_uuid,
triggered_by=None,
conversation_id=conversation_id_uuid,
input_data=input_data
)
@@ -621,14 +629,6 @@ class WorkflowService:
try:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
with get_db_context() as db:
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id,
other_id=payload.user_id,
original_user_id=payload.user_id # Save original user_id to other_id
)
end_user_id = str(new_end_user.id)
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
for exec_res in executions:
@@ -638,17 +638,46 @@ class WorkflowService:
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
init_message_length = len(input_data.get("conv_messages", []))
from app.core.workflow.executor import execute_workflow_stream
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件)
async for event in self._run_workflow_stream(
async for event in execute_workflow_stream(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=end_user_id
user_id=payload.user_id
):
# 直接转发 executor 的事件(已经是正确的格式)
if event.get("event") == "workflow_end":
status = event.get("data", {}).get("status")
if status == "completed":
self.update_execution_status(
execution.execution_id,
"completed",
output_data=event.get("data")
)
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
for message in final_messages:
message_obj = Message(
conversation_id=conversation_id_uuid,
role=message["role"],
content=message["content"],
)
self.message_repo.add_message(message_obj)
self.db.commit()
logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
elif status == "failed":
self.update_execution_status(
execution.execution_id,
"failed",
output_data=event.get("data")
)
else:
logger.error(f"unexpect workflow run status, status: {status}")
yield event
except Exception as e:
@@ -667,6 +696,8 @@ class WorkflowService:
}
}
@deprecated(reason="This method is deprecated. "
"Please use WorkflowService.run / run_stream instead.")
async def run_workflow(
self,
app_id: uuid.UUID,
@@ -819,6 +850,7 @@ class WorkflowService:
return clean_value(event)
@deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.")
async def _run_workflow_stream(
self,
workflow_config: dict[str, Any],

View File

@@ -8,9 +8,11 @@ import uuid
from typing import Dict, Any, Optional, Union
from datetime import datetime
from app.db import get_db_read
from app.models import AppRelease, WorkflowConfig
from app.models.agent_app_config_model import AgentConfig
from app.models.multi_agent_model import MultiAgentConfig
from app.repositories.workflow_repository import WorkflowConfigRepository
def model_parameters_to_dict(model_parameters: Any) -> Optional[Dict[str, Any]]:
@@ -24,18 +26,18 @@ def model_parameters_to_dict(model_parameters: Any) -> Optional[Dict[str, Any]]:
"""
if model_parameters is None:
return None
if isinstance(model_parameters, dict):
return model_parameters
# Pydantic v2
if hasattr(model_parameters, 'model_dump'):
return model_parameters.model_dump()
# Pydantic v1
if hasattr(model_parameters, 'dict'):
return model_parameters.dict()
# 其他情况尝试转换
try:
return dict(model_parameters)
@@ -54,17 +56,18 @@ def dict_to_model_parameters(data: Optional[Dict[str, Any]]) -> Optional[Any]:
"""
if data is None:
return None
from app.schemas import ModelParameters
if isinstance(data, ModelParameters):
return data
if isinstance(data, dict):
return ModelParameters(**data)
return None
class AgentConfigProxy:
"""Proxy class for AgentConfig (legacy compatibility)"""
@@ -78,8 +81,7 @@ class AgentConfigProxy:
self.default_model_config_id = release.default_model_config_id
def agent_config_4_app_release(release: AppRelease ) -> AgentConfig:
def agent_config_4_app_release(release: AppRelease) -> AgentConfig:
config_dict = release.config
agent_config = AgentConfig(
@@ -95,18 +97,17 @@ def agent_config_4_app_release(release: AppRelease ) -> AgentConfig:
return agent_config
def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig:
def multi_agent_config_4_app_release(release: AppRelease) -> MultiAgentConfig:
config_dict = release.config
agent_config = MultiAgentConfig(
app_id=release.app_id,
default_model_config_id=release.default_model_config_id,
model_parameters=config_dict.get("model_parameters"),
master_agent_id=config_dict.get("master_agent_id"),
master_agent_name=config_dict.get("master_agent_name"),
orchestration_mode=config_dict.get("orchestration_mode", "conditional"),
orchestration_mode=config_dict.get("orchestration_mode", "supervisor"),
sub_agents=config_dict.get("sub_agents", []),
routing_rules=config_dict.get("routing_rules"),
execution_config=config_dict.get("execution_config", {}),
@@ -116,24 +117,26 @@ def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig:
return agent_config
def workflow_config_4_app_release(release: AppRelease ) -> WorkflowConfig:
def workflow_config_4_app_release(release: AppRelease) -> WorkflowConfig:
config_dict = release.config
with get_db_read() as db:
source_config = WorkflowConfigRepository(db).get_by_app_id(release.app_id)
source_config_id = source_config.id
config = WorkflowConfig(
id=release.id,
id=source_config_id,
app_id=release.app_id,
nodes=config_dict.get("nodes", []),
edges=config_dict.get("edges", []),
variables=config_dict.get("variables", []),
execution_config=config_dict.get("execution_config", {}),
triggers=config_dict.get("triggers", [])
)
return config
def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uuid.UUID] = None):
"""Convert dict to MultiAgentConfig model object
@@ -149,7 +152,7 @@ def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uui
... "app_id": "uuid-here",
... "master_agent_id": "master-uuid",
... "master_agent_name": "Master Agent",
... "orchestration_mode": "conditional",
... "orchestration_mode": "supervisor",
... "sub_agents": [
... {"agent_id": "sub1-uuid", "name": "Sub Agent 1", "role": "specialist", "priority": 1},
... {"agent_id": "sub2-uuid", "name": "Sub Agent 2", "role": "specialist", "priority": 2}
@@ -186,7 +189,7 @@ def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uui
app_id=final_app_id,
master_agent_id=master_agent_id,
master_agent_name=config_dict.get("master_agent_name"),
orchestration_mode=config_dict.get("orchestration_mode", "conditional"),
orchestration_mode=config_dict.get("orchestration_mode", "supervisor"),
sub_agents=config_dict.get("sub_agents", []),
routing_rules=config_dict.get("routing_rules"),
execution_config=config_dict.get("execution_config", {}),
@@ -276,7 +279,8 @@ def agent_config_to_dict(agent_config) -> Dict[str, Any]:
"id": str(agent_config.id),
"app_id": str(agent_config.app_id),
"system_prompt": agent_config.system_prompt,
"default_model_config_id": str(agent_config.default_model_config_id) if agent_config.default_model_config_id else None,
"default_model_config_id": str(
agent_config.default_model_config_id) if agent_config.default_model_config_id else None,
"model_parameters": agent_config.model_parameters,
"knowledge_retrieval": agent_config.knowledge_retrieval,
"memory": agent_config.memory,
@@ -338,6 +342,3 @@ def workflow_config_to_dict(workflow_config) -> Dict[str, Any]:
"created_at": workflow_config.created_at.isoformat() if workflow_config.created_at else None,
"updated_at": workflow_config.updated_at.isoformat() if workflow_config.updated_at else None
}

View File

@@ -136,7 +136,8 @@ dependencies = [
"markdown-to-json==2.1.1",
"valkey==6.0.2",
"python-calamine>=0.4.0",
"xlrd==2.0.2"
"xlrd==2.0.2",
"deprecated>=1.3.1",
]
[tool.pytest.ini_options]

Binary file not shown.

After

Width:  |  Height:  |  Size: 185 KiB

View File

@@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>使用帮助备份</title>
<g id="v0.2.0" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="首页" transform="translate(-51, -358)" stroke="#5F6266">
<g id="使用帮助备份" transform="translate(51, 358)">
<g id="编组-35" transform="translate(2, 1.5)">
<path d="M6.13163525,1.97938144 L10.3064533,1.97938144 C11.2417733,1.97938144 12,2.70634106 12,3.6030912 L12,10.3762902 C12,11.2730404 11.2417733,12 10.3064533,12 L1.69354673,12 C0.758226699,12 0,11.2730404 0,10.3762902 L0,3.6030912 C0,2.70634106 0.758226699,1.97938144 1.69354673,1.97938144 L2.02448435,1.97938144 L2.02448435,1.97938144" id="路径"></path>
<path d="M3.52033177,0.78470905 L6.09032258,1.97938144 L6.09032258,1.97938144 L6.09032258,11.8762887 L2.51918436,10.2162282 C2.10022604,10.0214734 1.83225806,9.6014016 1.83225806,9.13938916 L1.83225806,1.86154804 C1.83225806,1.2057099 2.36391992,0.674048044 3.01975806,0.674048044 C3.19268295,0.674048044 3.36352144,0.711815028 3.52033177,0.78470905 Z" id="矩形" stroke-linejoin="round"></path>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>使用帮助</title>
<g id="v0.2.0" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="首页" transform="translate(-24, -358)" stroke="#212332">
<g id="使用帮助" transform="translate(24, 358)">
<g id="编组-35" transform="translate(2, 1.5)">
<path d="M6.13163525,1.97938144 L10.3064533,1.97938144 C11.2417733,1.97938144 12,2.70634106 12,3.6030912 L12,10.3762902 C12,11.2730404 11.2417733,12 10.3064533,12 L1.69354673,12 C0.758226699,12 0,11.2730404 0,10.3762902 L0,3.6030912 C0,2.70634106 0.758226699,1.97938144 1.69354673,1.97938144 L2.02448435,1.97938144 L2.02448435,1.97938144" id="路径"></path>
<path d="M3.52033177,0.78470905 L6.09032258,1.97938144 L6.09032258,1.97938144 L6.09032258,11.8762887 L2.51918436,10.2162282 C2.10022604,10.0214734 1.83225806,9.6014016 1.83225806,9.13938916 L1.83225806,1.86154804 C1.83225806,1.2057099 2.36391992,0.674048044 3.01975806,0.674048044 C3.19268295,0.674048044 3.36352144,0.711815028 3.52033177,0.78470905 Z" id="矩形" stroke-linejoin="round"></path>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@@ -55,7 +55,7 @@ const ChatContent: FC<ChatContentProps> = ({
</div>
}
{/* 消息气泡框 */}
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-100 rb:wrap-break-word', contentClassNames, {
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-[520px] rb:wrap-break-word', contentClassNames, {
// 错误消息样式内容为null且非助手消息
'rb:border-[rgba(255,93,52,0.30)] rb:bg-[rgba(255,93,52,0.08)] rb:text-[#FF5D34]': errorDesc && item.role === 'assistant' && item.content === null,
// 助手消息样式
@@ -68,7 +68,7 @@ const ChatContent: FC<ChatContentProps> = ({
</div>
{/* 底部标签(如时间戳、用户名等) */}
{labelPosition === 'bottom' &&
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4 rb:font-regular">
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4 rb:font-regular rb:mt-2">
{labelFormat(item)}
</div>
}

View File

@@ -71,6 +71,7 @@ export const en = {
stepTwoDescription: 'Here you can create and manage spaces to organize models and data for different use cases.Once your spaces are ready, head to User Management to invite members and manage access.👉 Click User Management in the left menu to continue.',
stepThree: 'This is User Management',
stepThreeDescription: 'Here you can create users, assign roles, and manage access for your team.Once users are set up, the basic configuration is complete and youre ready to start using the platform 🎉',
finishButtonText: 'Get Started',
},
menu: {
home: 'Home',
@@ -91,6 +92,7 @@ export const en = {
memberManagement: 'Member Management',
memorySummary: 'Memory Summary',
memoryConversation: 'Memory Validation',
helpCenter: 'Help Center',
memorySummaryHandlers: 'Memory Summary Handlers',
createMemorySummary: 'Create Memory Summary',
memoryManagement: 'Memory Management',
@@ -183,14 +185,15 @@ export const en = {
createNewMemorySummary: 'Create New Memory Entry',
createNewApplication: 'Create New Application',
createNewApplicationDesc: 'Create a new application for this space',
createNewApplicationDesc: 'Build an app in just 3 minutes with zero-code drag-and-drop.',
createNewKnowledge: 'Create New Knowledge',
createNewKnowledgeDesc: 'Create a new memory entry',
createNewKnowledgeDesc: 'Transform your data into a fully searchable, dedicated knowledge base in seconds.',
memoryConversation: 'Memory Conversation',
memoryConversationDesc: 'Create a new memory conversation',
memoryConversationDesc: 'The more you use it, the better AI understands you.',
helpCenter: 'Help Center',
helpCenterDesc: 'One-stop support to answer your questions and get you started fast.',
memorySummary: 'View Memory Summary',
memorySummaryDesc: 'View Memory Summary Report',
@@ -618,6 +621,7 @@ export const en = {
retrieve:'Retrieve',
processing: 'Processing',
processingMode: 'Processing Mode',
processMsg: 'Processing Message',
dataSize: 'Data Size',
createUpdateTime: 'Create/Update Time',
operation: 'Operation',
@@ -1449,6 +1453,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
},
memoryConversation: {
searchPlaceholder: 'Enter user ID...',
chatEmpty:'Is there anything I can help you with',
userID: 'User ID',
testMemoryConversation: 'Test Memory Conversation',
conversationContent: 'Conversation Content',

View File

@@ -71,6 +71,7 @@ export const zh = {
stepTwoDescription: '你可以在这里创建和管理不同的空间,把模型和数据组织到具体的使用场景中。空间创建完成后,可以去 User Management 邀请成员、分配权限,一起协作使用。👉 点击左侧 User Management 继续。',
stepThree: '这里是用户管理页',
stepThreeDescription: '你可以在这里创建用户、分配角色,并管理团队成员的访问权限。完成用户设置后,基础配置就准备好了,可以开始实际使用平台的各项功能了 🎉',
finishButtonText: '开始使用',
},
menu: {
home: '首页',
@@ -782,14 +783,15 @@ export const zh = {
createNewMemorySummary: '创建新记忆条目',
createNewApplication: '创建新应用',
createNewApplicationDesc: '创建新空间应用',
createNewApplicationDesc: '零代码拖拽3分钟创应用',
createNewKnowledge: '创建知识',
createNewKnowledgeDesc: '创建新记忆条目',
createNewKnowledge: '创建知识',
createNewKnowledgeDesc: '秒变可搜索的专属知识库',
memoryConversation: '记忆对话',
memoryConversationDesc: '记忆对话',
memoryConversationDesc: '让AI越用越懂你',
helpCenter: '帮助中心',
helpCenterDesc: '一站式解决疑问快速上手',
memorySummary: '查看记忆摘要',
memorySummaryDesc: '查看记忆摘要报告',
@@ -1524,6 +1526,7 @@ export const zh = {
deduplication_desc: '去重消歧完成,最终{{count}}个唯一实体'
},
memoryConversation: {
chatEmpty:'有什么我可以帮您的吗?',
searchPlaceholder: '输入用户ID...',
userID: '用户ID',
testMemoryConversation: '测试记忆对话',

View File

@@ -1,3 +1,11 @@
/*
* @Description:
* @Version: 0.0.1
* @Author: yujiangping
* @Date: 2026-01-05 17:22:23
* @LastEditors: yujiangping
* @LastEditTime: 2026-01-15 21:02:43
*/
import { create } from 'zustand'
import enUS from 'antd/locale/en_US';
import zhCN from 'antd/locale/zh_CN';
@@ -12,6 +20,28 @@ import { timezoneToAntdLocaleMap } from '@/utils/timezones';
dayjs.extend(utc);
dayjs.extend(timezone);
// 自定义中文 locale修改 Tour 组件的按钮文字
const customZhCN: Locale = {
...zhCN,
Tour: {
...zhCN.Tour,
Next: '下一步',
Previous: '上一步',
Finish: '立即体验',
},
};
// 自定义英文 locale修改 Tour 组件的按钮文字
const customEnUS: Locale = {
...enUS,
Tour: {
...enUS.Tour,
Next: 'Next',
Previous: 'Previous',
Finish: 'Try it now',
},
};
interface I18nState {
language: string;
@@ -23,7 +53,7 @@ interface I18nState {
const initialTimeZone = localStorage.getItem('timeZone') || 'Asia/Shanghai'
const initialLanguage = localStorage.getItem('language') || 'en'
const initialLocale = initialLanguage === 'en' ? enUS : zhCN
const initialLocale = initialLanguage === 'en' ? customEnUS : customZhCN
i18n.changeLanguage(initialLanguage)
export const useI18n = create<I18nState>((set, get) => ({
@@ -32,7 +62,7 @@ export const useI18n = create<I18nState>((set, get) => ({
timeZone: initialTimeZone,
changeLanguage: (language: string) => {
i18n.changeLanguage(language)
const localeName = timezoneToAntdLocaleMap[language] || enUS;
const localeName = language === 'en' ? customEnUS : customZhCN;
set({ language: language, locale: localeName })
},
changeTimeZone: (timeZone: string) => {

View File

@@ -11,6 +11,7 @@ import Empty from '@/components/Empty'
import { formatDateTime } from '@/utils/format';
import { randomString } from '@/utils/common'
import BgImg from '@/assets/images/conversation/bg.png'
import ChatEmpty from '@/assets/images/empty/chatEmpty.png'
import Chat from '@/components/Chat'
import type { ChatItem } from '@/components/Chat/types'
import ButtonCheckbox from '@/components/ButtonCheckbox'
@@ -259,9 +260,10 @@ const Conversation: FC = () => {
</div>
<div className="rb:relative rb:h-screen rb:px-4 rb:flex-[1_1_auto]">
<div className='rb:w-[760px] rb:h-screen rb:mx-auto rb:pt-10'>
<Chat
empty={<Empty url={AnalysisEmptyIcon} className="rb:h-full" subTitle={t('memoryConversation.emptyDesc')} />}
contentClassName="rb:h-[calc(100%-152px)]"
empty={<Empty url={ChatEmpty} className="rb:h-full" size={[320,180]} title={t('memoryConversation.chatEmpty')} subTitle={t('memoryConversation.emptyDesc')} />}
contentClassName="rb:h-[calc(100%-152px)] "
data={chatList}
streamLoading={streamLoading}
loading={loading}
@@ -290,6 +292,7 @@ const Conversation: FC = () => {
</Flex>
</Form>
</Chat>
</div>
</div>
</Flex>
)

View File

@@ -1,3 +1,11 @@
/*
* @Description:
* @Version: 0.0.1
* @Author: yujiangping
* @Date: 2026-01-05 17:22:23
* @LastEditors: yujiangping
* @LastEditTime: 2026-01-15 14:55:51
*/
import { type FC } from 'react'
import { useTranslation } from 'react-i18next'
import { useNavigate } from 'react-router-dom';
@@ -5,33 +13,49 @@ import Card from './Card';
import applicationIcon from '@/assets/images/menu/application_active.svg';
import knowledgeIcon from '@/assets/images/menu/knowledge_active.svg';
import memoryConversationIcon from '@/assets/images/menu/memoryConversation_active.svg';
import helpCenterIcon from '@/assets/images/menu/helpCenter_active.svg'
import arrowTopRight from '@/assets/images/home/arrow_top_right.svg';
const quickOperations = [
{ key: 'createNewApplication', url: '/application' },
{ key: 'createNewKnowledge', url: '/knowledge-base' },
{ key: 'memoryConversation', url: '/memory-conversation' },
{ key: 'helpCenter', url: '' },
]
const quickOperationIcons: {[key: string]: string | undefined} = {
createNewApplication: applicationIcon,
createNewKnowledge: knowledgeIcon,
memoryConversation: memoryConversationIcon,
helpCenter: helpCenterIcon
}
const QuickOperation:FC = () => {
const { t } = useTranslation()
const { t, i18n } = useTranslation()
const navigate = useNavigate();
const handleJump = (url: string | null) => {
if (url) {
navigate(url)
}else{
const currentLang = i18n.language;
const lang = currentLang === 'zh' ? 'zh' : 'en';
const helpUrl = `https://docs.redbearai.com/s/${lang}-memorybear`;
// 创建隐藏的 a 标签来避免弹窗拦截
const link = document.createElement('a');
link.href = helpUrl;
link.target = '_blank';
link.rel = 'noopener noreferrer';
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
}
}
return (
<Card
title={t('dashboard.quickOperation')}
>
<div className="rb:grid rb:grid-cols-3 rb:gap-[16px]">
<div className="rb:grid rb:grid-cols-4 rb:gap-[16px]">
{quickOperations.map(item => (
<div key={item.key} className="rb:rounded-[8px] rb:p-[20px_16px] rb:border-1 rb:border-[#DFE4ED] rb:cursor-pointer rb:hover:border-[#155EEF]" onClick={() => handleJump(item.url)}>
<div className="rb:flex rb:justify-between">

View File

@@ -1,3 +1,11 @@
/*
* @Description:
* @Version: 0.0.1
* @Author: yujiangping
* @Date: 2026-01-13 11:44:06
* @LastEditors: yujiangping
* @LastEditTime: 2026-01-15 20:59:57
*/
import React, { useState, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { useNavigate } from 'react-router-dom';

View File

@@ -47,7 +47,7 @@ const QuickActions: FC<QuickActionsProps> = ({ onNavigate }) => {
key: 'space-management',
icon: spaceIcon,
title: t('quickActions.spaceManagement'),
onClick: () => onNavigate?.('/spce')
onClick: () => onNavigate?.('/space')
},
// {
// key: 'workflow-orchestration',

View File

@@ -2,7 +2,7 @@
import { useEffect, useState, useRef, useCallback, type FC } from 'react';
import { useNavigate, useParams, useLocation } from 'react-router-dom';
import { useTranslation } from 'react-i18next';
import { Switch, Button, Dropdown, Space, Modal, message, Radio } from 'antd';
import { Switch, Button, Dropdown, Space, Modal, message, Radio, Tooltip } from 'antd';
import type { MenuProps } from 'antd';
import SearchInput from '@/components/SearchInput'
import Table, { type TableRef } from '@/components/Table'
@@ -564,6 +564,37 @@ const Private: FC = () => {
</span>
);
}
},{
title: t('knowledgeBase.processMsg'),
dataIndex: 'progress_msg',
key: 'progress_msg',
width: 320,
render: (value: string) => {
if (!value) return '-';
// 解析日志格式,将 \n 转换为换行
const formattedText = value.replace(/\\n/g, '\n');
return (
<Tooltip title={<pre style={{ margin: 0, whiteSpace: 'pre-wrap' }}>{formattedText}</pre>} placement="topLeft">
<div
style={{
maxWidth: '320px',
overflow: 'hidden',
textOverflow: 'ellipsis',
display: '-webkit-box',
WebkitLineClamp: 2,
WebkitBoxOrient: 'vertical',
lineHeight: '1.5',
whiteSpace: 'pre-wrap',
wordBreak: 'break-word'
}}
>
{formattedText}
</div>
</Tooltip>
);
}
},
{
title: t('knowledgeBase.processingMode'),

View File

@@ -292,7 +292,7 @@ const KnowledgeGraph: FC<KnowledgeGraphProps> = ({ data, loading = false }) => {
if (params.dataType === 'node') {
const node = params.data as KnowledgeNode
return `
<div>
<div class="rb:max-w-[560px]">
<div><strong>${node.entity_name}</strong></div>
<div>类型: ${node.entity_type}</div>
<div>重要度: ${(node.pagerank * 100).toFixed(2)}%</div>
@@ -301,10 +301,10 @@ const KnowledgeGraph: FC<KnowledgeGraphProps> = ({ data, loading = false }) => {
} else if (params.dataType === 'edge') {
const edge = params.data as KnowledgeEdge
return `
<div>
<div class="rb:max-w-[560px]">
<div><strong>关系</strong></div>
<div>权重: ${edge.weight}</div>
<div>${edge.description}</div>
<div class="rb:break-words rb:whitespace-pre-wrap">${edge.description}</div>
</div>
`
}