Merge branch 'develop' into fix/workflow_zy
This commit is contained in:
@@ -11,15 +11,16 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models import User
|
||||
from app.models.app_model import AppType, App
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas import app_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||
from app.services import app_service, workspace_service
|
||||
from app.services.agent_config_helper import enrich_agent_config
|
||||
from app.services.app_service import AppService
|
||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
@@ -405,6 +406,15 @@ async def draft_run(
|
||||
# 只读操作,允许访问共享应用
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
other_id=str(current_user.id),
|
||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
# 处理会话ID(创建或验证)
|
||||
conversation_id = await draft_service._ensure_conversation(
|
||||
conversation_id=payload.conversation_id,
|
||||
|
||||
@@ -74,7 +74,7 @@ def get_multi_agent_configs(
|
||||
"app_id": str(app_id),
|
||||
"default_model_config_id": None,
|
||||
"model_parameters": None,
|
||||
"orchestration_mode": "conditional",
|
||||
"orchestration_mode": "supervisor",
|
||||
"sub_agents": [],
|
||||
"routing_rules": [],
|
||||
"execution_config": {
|
||||
|
||||
@@ -466,7 +466,7 @@ async def chat(
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config= payload.agent_config,
|
||||
config=agent_config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
@@ -565,11 +565,12 @@ async def chat(
|
||||
config = workflow_config_4_app_release(release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
@@ -601,7 +602,7 @@ async def chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
|
||||
@@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"])
|
||||
@router.post("/{app_id}/workflow")
|
||||
@cur_workspace_access_guard()
|
||||
async def create_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
config: WorkflowConfigCreate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
config: WorkflowConfigCreate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""创建工作流配置
|
||||
|
||||
@@ -96,6 +96,7 @@ async def create_workflow_config(
|
||||
msg=f"创建工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# @router.get("/{app_id}/workflow")
|
||||
# async def get_workflow_config(
|
||||
@@ -199,10 +200,10 @@ async def create_workflow_config(
|
||||
|
||||
@router.delete("/{app_id}/workflow")
|
||||
async def delete_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""删除工作流配置
|
||||
|
||||
@@ -243,11 +244,11 @@ async def delete_workflow_config(
|
||||
|
||||
@router.post("/{app_id}/workflow/validate")
|
||||
async def validate_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||
):
|
||||
"""验证工作流配置
|
||||
|
||||
@@ -312,12 +313,12 @@ async def validate_workflow_config(
|
||||
|
||||
@router.get("/{app_id}/workflow/executions")
|
||||
async def get_workflow_executions(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0
|
||||
):
|
||||
"""获取工作流执行记录列表
|
||||
|
||||
@@ -365,10 +366,10 @@ async def get_workflow_executions(
|
||||
|
||||
@router.get("/workflow/executions/{execution_id}")
|
||||
async def get_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""获取工作流执行详情
|
||||
|
||||
@@ -417,16 +418,14 @@ async def get_workflow_execution(
|
||||
)
|
||||
|
||||
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
@router.post("/{app_id}/workflow/run")
|
||||
async def run_workflow(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
request: WorkflowExecutionRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
request: WorkflowExecutionRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""执行工作流
|
||||
|
||||
@@ -487,22 +486,22 @@ async def run_workflow(
|
||||
"""
|
||||
try:
|
||||
async for event in await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
):
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
# event: <type>
|
||||
# data: <json>
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||
# 发送错误事件
|
||||
@@ -554,10 +553,10 @@ async def run_workflow(
|
||||
|
||||
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||
async def cancel_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""取消工作流执行
|
||||
|
||||
@@ -602,7 +601,7 @@ async def cancel_workflow_execution(
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
return fail(code=e.code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
|
||||
@@ -7,17 +7,18 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Settings:
|
||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||
# API Keys Configuration
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
|
||||
|
||||
|
||||
# Neo4j Configuration (记忆系统数据库)
|
||||
NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687")
|
||||
NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
|
||||
NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "")
|
||||
|
||||
|
||||
# Database configuration (Postgres)
|
||||
DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1")
|
||||
DB_PORT: int = int(os.getenv("DB_PORT", "5432"))
|
||||
@@ -37,7 +38,7 @@ class Settings:
|
||||
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
|
||||
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
||||
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
||||
|
||||
|
||||
# ElasticSearch configuration
|
||||
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
||||
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
|
||||
@@ -48,7 +49,7 @@ class Settings:
|
||||
ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000"))
|
||||
ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true"
|
||||
ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10"))
|
||||
|
||||
|
||||
# Xinference configuration
|
||||
XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1")
|
||||
|
||||
@@ -57,17 +58,17 @@ class Settings:
|
||||
LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true"
|
||||
LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "")
|
||||
LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "")
|
||||
|
||||
|
||||
# LLM Request Configuration
|
||||
LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0"))
|
||||
LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2"))
|
||||
|
||||
|
||||
# JWT Token Configuration
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random")
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
|
||||
|
||||
|
||||
# Single Sign-On configuration
|
||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||
|
||||
@@ -86,19 +87,19 @@ class Settings:
|
||||
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
|
||||
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
|
||||
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
|
||||
|
||||
|
||||
# Server Configuration
|
||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||
|
||||
# ========================================================================
|
||||
# Internal Configuration (not in .env, used by application code)
|
||||
# ========================================================================
|
||||
|
||||
|
||||
# Superuser settings (internal defaults)
|
||||
FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com")
|
||||
FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin")
|
||||
FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password")
|
||||
|
||||
|
||||
# Generic File Upload (internal)
|
||||
GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads")
|
||||
ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true"
|
||||
@@ -123,7 +124,7 @@ class Settings:
|
||||
LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5"))
|
||||
LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true"
|
||||
LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true"
|
||||
|
||||
|
||||
# Sensitive Data Filtering
|
||||
ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true"
|
||||
|
||||
@@ -142,7 +143,6 @@ class Settings:
|
||||
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
|
||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||
|
||||
|
||||
# Celery configuration (internal)
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
@@ -150,15 +150,15 @@ class Settings:
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
|
||||
|
||||
# Tool Management Configuration
|
||||
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
||||
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
|
||||
@@ -167,7 +167,10 @@ class Settings:
|
||||
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
||||
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module output files.
|
||||
@@ -182,7 +185,7 @@ class Settings:
|
||||
if filename:
|
||||
return str(base_path / filename)
|
||||
return str(base_path)
|
||||
|
||||
|
||||
def ensure_memory_output_dir(self) -> None:
|
||||
"""
|
||||
Ensure the memory output directory exists.
|
||||
|
||||
@@ -110,24 +110,24 @@ HTTP_MAPPING = {
|
||||
BizCode.TOKEN_EXPIRED: 401,
|
||||
BizCode.TOKEN_BLACKLISTED: 401,
|
||||
BizCode.FORBIDDEN: 403,
|
||||
BizCode.TENANT_NOT_FOUND: 404,
|
||||
BizCode.TENANT_NOT_FOUND: 400,
|
||||
BizCode.WORKSPACE_NO_ACCESS: 403,
|
||||
BizCode.NOT_FOUND: 404,
|
||||
BizCode.NOT_FOUND: 400,
|
||||
BizCode.USER_NOT_FOUND: 200,
|
||||
BizCode.WORKSPACE_NOT_FOUND: 404,
|
||||
BizCode.MODEL_NOT_FOUND: 404,
|
||||
BizCode.KNOWLEDGE_NOT_FOUND: 404,
|
||||
BizCode.DOCUMENT_NOT_FOUND: 404,
|
||||
BizCode.FILE_NOT_FOUND: 404,
|
||||
BizCode.APP_NOT_FOUND: 404,
|
||||
BizCode.RELEASE_NOT_FOUND: 404,
|
||||
BizCode.WORKSPACE_NOT_FOUND: 400,
|
||||
BizCode.MODEL_NOT_FOUND: 400,
|
||||
BizCode.KNOWLEDGE_NOT_FOUND: 400,
|
||||
BizCode.DOCUMENT_NOT_FOUND: 400,
|
||||
BizCode.FILE_NOT_FOUND: 400,
|
||||
BizCode.APP_NOT_FOUND: 400,
|
||||
BizCode.RELEASE_NOT_FOUND: 400,
|
||||
BizCode.DUPLICATE_NAME: 409,
|
||||
BizCode.RESOURCE_ALREADY_EXISTS: 409,
|
||||
BizCode.VERSION_ALREADY_EXISTS: 409,
|
||||
BizCode.STATE_CONFLICT: 409,
|
||||
BizCode.PUBLISH_FAILED: 500,
|
||||
BizCode.NO_DRAFT_TO_PUBLISH: 400,
|
||||
BizCode.ROLLBACK_TARGET_NOT_FOUND: 404,
|
||||
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,
|
||||
BizCode.APP_TYPE_NOT_SUPPORTED: 400,
|
||||
BizCode.AGENT_CONFIG_MISSING: 400,
|
||||
BizCode.SHARE_DISABLED: 403,
|
||||
|
||||
@@ -425,15 +425,9 @@ async def Input_Summary(
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
search_service = get_context_resource(ctx, "search_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
sessionid = sessionid.replace('call_id_', '')
|
||||
@@ -539,31 +533,11 @@ async def Input_Summary(
|
||||
)
|
||||
retrieve_info, question, raw_results = "", query, []
|
||||
|
||||
# Return retrieved information directly without LLM processing
|
||||
# Use the raw retrieved info as the answer
|
||||
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
|
||||
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Retrieve_Summary_prompt.jinja2',
|
||||
operation_name='input_summary',
|
||||
query=query,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=RetrieveSummaryResponse
|
||||
)
|
||||
aimessages = structured.data.query_answer or "信息不足,无法回答"
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary: response_structured failed, using default answer: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
|
||||
@@ -10,9 +10,6 @@ from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
# 为了兼容性,创建别名
|
||||
# SchemaParser = OpenAPISchemaParser = None
|
||||
|
||||
|
||||
class OpenAPISchemaParser:
|
||||
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
|
||||
@@ -213,7 +210,9 @@ class OpenAPISchemaParser:
|
||||
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
|
||||
summary = operation.get("summary", "")
|
||||
|
||||
# 生成操作ID
|
||||
operation_id = operation.get("operationId")
|
||||
if not operation_id:
|
||||
@@ -223,7 +222,7 @@ class OpenAPISchemaParser:
|
||||
operations[operation_id] = {
|
||||
"method": method.upper(),
|
||||
"path": path,
|
||||
"summary": operation.get("summary", ""),
|
||||
"summary": summary if summary else operation_id,
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": self._extract_parameters(operation),
|
||||
"request_body": self._extract_request_body(operation),
|
||||
|
||||
@@ -232,7 +232,7 @@ class LangchainAdapter:
|
||||
# 添加验证约束
|
||||
if param.enum:
|
||||
# 枚举值约束
|
||||
field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$"
|
||||
field_kwargs["pattern"] = f"^({'|'.join(map(str, param.enum))})$"
|
||||
|
||||
if param.minimum is not None:
|
||||
field_kwargs["ge"] = param.minimum
|
||||
@@ -241,7 +241,7 @@ class LangchainAdapter:
|
||||
field_kwargs["le"] = param.maximum
|
||||
|
||||
if param.pattern:
|
||||
field_kwargs["regex"] = param.pattern
|
||||
field_kwargs["pattern"] = param.pattern
|
||||
|
||||
fields[param.name] = Field(**field_kwargs)
|
||||
annotations[param.name] = python_type
|
||||
|
||||
@@ -27,20 +27,22 @@ class SimpleMCPClient:
|
||||
|
||||
# 确定连接类型
|
||||
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
||||
self.is_sse = "/sse" in server_url.lower()
|
||||
|
||||
# 连接状态
|
||||
self._websocket = None
|
||||
self._session = None
|
||||
self._request_id = 0
|
||||
self._pending_requests = {}
|
||||
self._server_capabilities = {}
|
||||
self._endpoint_url = None # SSE endpoint URL
|
||||
self._sse_task = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.disconnect()
|
||||
|
||||
async def connect(self):
|
||||
@@ -57,47 +59,154 @@ class SimpleMCPClient:
|
||||
async def disconnect(self):
|
||||
"""断开连接"""
|
||||
try:
|
||||
if self._sse_task:
|
||||
self._sse_task.cancel()
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开连接失败: {e}")
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""WebSocket 连接"""
|
||||
headers = self._build_headers()
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
extra_headers=headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
# 启动消息处理
|
||||
asyncio.create_task(self._handle_websocket_messages())
|
||||
|
||||
# 发送初始化消息
|
||||
await self._send_initialize()
|
||||
|
||||
async def _connect_http(self):
|
||||
"""HTTP 连接"""
|
||||
headers = self._build_headers()
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||
self._session = aiohttp.ClientSession(headers=headers, timeout=timeout)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
headers=headers,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 对于 ModelScope MCP 服务,需要先发送初始化请求
|
||||
if "modelscope.net" in self.server_url:
|
||||
if self.is_sse:
|
||||
await self._initialize_sse_session()
|
||||
elif "modelscope.net" in self.server_url:
|
||||
await self._initialize_modelscope_session()
|
||||
|
||||
async def _initialize_sse_session(self):
|
||||
"""初始化 SSE MCP 会话 - 参考 Dify 实现"""
|
||||
try:
|
||||
# 建立 SSE 连接
|
||||
response = await self._session.get(self.server_url)
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
||||
|
||||
# 启动 SSE 读取任务
|
||||
self._sse_task = asyncio.create_task(self._read_sse_stream(response))
|
||||
|
||||
# 等待获取 endpoint URL
|
||||
for _ in range(10):
|
||||
if self._endpoint_url:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if not self._endpoint_url:
|
||||
raise MCPConnectionError("未能获取 endpoint URL")
|
||||
|
||||
# 发送 initialize 请求到 endpoint
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
init_response = await self._send_sse_request(init_request)
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
result = init_response.get("result", {})
|
||||
self._server_capabilities = result.get("capabilities", {})
|
||||
|
||||
# 发送 initialized 通知
|
||||
await self._send_sse_notification({"jsonrpc": "2.0", "method": "notifications/initialized"})
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"初始化连接失败: {e}")
|
||||
|
||||
async def _read_sse_stream(self, response):
|
||||
"""读取 SSE 流"""
|
||||
try:
|
||||
async for line in response.content:
|
||||
line = line.decode('utf-8').strip()
|
||||
|
||||
if line.startswith('event:'):
|
||||
continue
|
||||
|
||||
if line.startswith('data:'):
|
||||
data = line[5:].strip() # 去除 'data:' 后的空格
|
||||
if not data or data == '[DONE]':
|
||||
continue
|
||||
|
||||
try:
|
||||
# 处理 endpoint 事件(相对路径或绝对路径)
|
||||
if not self._endpoint_url:
|
||||
# 如果是相对路径,拼接成完整 URL
|
||||
if data.startswith('/'):
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
parsed = urlparse(self.server_url)
|
||||
self._endpoint_url = f"{parsed.scheme}://{parsed.netloc}{data}"
|
||||
else:
|
||||
self._endpoint_url = data
|
||||
logger.info(f"获取到 endpoint URL: {self._endpoint_url}")
|
||||
continue
|
||||
|
||||
# 处理 message 事件
|
||||
message = json.loads(data)
|
||||
request_id = message.get("id")
|
||||
if request_id and request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"SSE 流读取错误: {e}")
|
||||
|
||||
async def _send_sse_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""通过 SSE endpoint 发送请求"""
|
||||
if not self._endpoint_url:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
request_id = request["id"]
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
async with self._session.post(self._endpoint_url, json=request) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await asyncio.wait_for(future, timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError("请求超时")
|
||||
|
||||
async def _send_sse_notification(self, notification: Dict[str, Any]):
|
||||
"""发送通知(无需响应)"""
|
||||
if not self._endpoint_url:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
async with self._session.post(self._endpoint_url, json=notification) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"通知发送失败: {response.status}")
|
||||
|
||||
async def _initialize_modelscope_session(self):
|
||||
"""初始化 ModelScope MCP 会话"""
|
||||
init_request = {
|
||||
@@ -107,18 +216,12 @@ class SimpleMCPClient:
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {
|
||||
"name": "MemoryBear",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=init_request
|
||||
) as response:
|
||||
async with self._session.post(self.server_url, json=init_request) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
||||
@@ -127,21 +230,16 @@ class SimpleMCPClient:
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
# 获取 session ID
|
||||
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
|
||||
if session_id:
|
||||
self._session.headers.update({"Mcp-Session-Id": session_id})
|
||||
|
||||
# 发送 initialized 通知
|
||||
initialized_notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
}
|
||||
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=initialized_notification
|
||||
) as notif_response:
|
||||
async with self._session.post(self.server_url, json=initialized_notification):
|
||||
pass
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
@@ -149,12 +247,18 @@ class SimpleMCPClient:
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
# 基础 headers
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream"
|
||||
}
|
||||
|
||||
# 添加认证头
|
||||
# 合并 connection_config 中的自定义 headers
|
||||
custom_headers = self.connection_config.get("headers", {})
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
# 处理认证配置(认证 headers 优先级更高)
|
||||
auth_config = self.connection_config.get("auth_config", {})
|
||||
auth_type = self.connection_config.get("auth_type", "none")
|
||||
|
||||
@@ -178,7 +282,7 @@ class SimpleMCPClient:
|
||||
return headers
|
||||
|
||||
async def _send_initialize(self):
|
||||
"""发送初始化消息"""
|
||||
"""发送初始化消息(WebSocket)"""
|
||||
init_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
@@ -186,124 +290,90 @@ class SimpleMCPClient:
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {
|
||||
"name": "MemoryBear",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(init_message))
|
||||
response = await self._websocket.recv()
|
||||
response_data = json.loads(response)
|
||||
|
||||
# 等待初始化响应
|
||||
response = await asyncio.wait_for(
|
||||
self._websocket.recv(),
|
||||
timeout=self.timeout
|
||||
)
|
||||
if "error" in response_data:
|
||||
raise MCPConnectionError(f"初始化失败: {response_data['error']}")
|
||||
|
||||
init_response = json.loads(response)
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
result = response_data.get("result", {})
|
||||
self._server_capabilities = result.get("capabilities", {})
|
||||
|
||||
await self._websocket.send(json.dumps({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
}))
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""获取工具列表"""
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
await self._websocket.send(json.dumps(request))
|
||||
response = await self._websocket.recv()
|
||||
response_data = json.loads(response)
|
||||
elif self.is_sse:
|
||||
response_data = await self._send_sse_request(request)
|
||||
else:
|
||||
async with self._session.post(self.server_url, json=request) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if "error" in response_data:
|
||||
raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}")
|
||||
|
||||
result = response_data.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
"""调用工具"""
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {"name": tool_name, "arguments": arguments}
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
await self._websocket.send(json.dumps(request))
|
||||
response = await self._websocket.recv()
|
||||
response_data = json.loads(response)
|
||||
elif self.is_sse:
|
||||
response_data = await self._send_sse_request(request)
|
||||
else:
|
||||
async with self._session.post(self.server_url, json=request) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if "error" in response_data:
|
||||
error = response_data["error"]
|
||||
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response_data.get("result", {})
|
||||
|
||||
def _get_request_id(self) -> int:
|
||||
"""生成请求 ID"""
|
||||
self._request_id += 1
|
||||
return self._request_id
|
||||
|
||||
async def _handle_websocket_messages(self):
|
||||
"""处理 WebSocket 消息"""
|
||||
try:
|
||||
while self._websocket and not self._websocket.closed:
|
||||
try:
|
||||
message = await self._websocket.recv()
|
||||
data = json.loads(message)
|
||||
|
||||
# 处理响应
|
||||
if "id" in data:
|
||||
request_id = str(data["id"])
|
||||
if request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
|
||||
except ConnectionClosed:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"处理WebSocket消息失败: {e}")
|
||||
|
||||
async for message in self._websocket:
|
||||
data = json.loads(message)
|
||||
request_id = data.get("id")
|
||||
if request_id and request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
except ConnectionClosed:
|
||||
logger.info("WebSocket 连接已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket消息处理异常: {e}")
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
"""调用工具"""
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
response = await self._send_websocket_request(request_data)
|
||||
else:
|
||||
response = await self._send_http_request(request_data)
|
||||
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response.get("result", {})
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""获取工具列表"""
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
response = await self._send_websocket_request(request_data)
|
||||
else:
|
||||
response = await self._send_http_request(request_data)
|
||||
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
result = response.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送WebSocket请求"""
|
||||
request_id = str(request_data["id"])
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
await self._websocket.send(json.dumps(request_data))
|
||||
response = await asyncio.wait_for(future, timeout=self.timeout)
|
||||
return response
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
|
||||
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送HTTP请求"""
|
||||
try:
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=request_data
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"HTTP请求失败: {e}")
|
||||
|
||||
def _get_request_id(self) -> str:
|
||||
"""获取请求ID"""
|
||||
self._request_id += 1
|
||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
||||
logger.error(f"WebSocket 消息处理错误: {e}")
|
||||
|
||||
@@ -74,6 +74,7 @@ class WorkflowExecutor:
|
||||
初始化的工作流状态
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
conversation_messages = input_data.get("conv_messages") or []
|
||||
|
||||
# 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value)
|
||||
config_variables_list = self.workflow_config.get("variables") or []
|
||||
@@ -114,7 +115,7 @@ class WorkflowExecutor:
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [('user', user_message)],
|
||||
"messages": conversation_messages,
|
||||
"variables": variables,
|
||||
"node_outputs": {},
|
||||
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
||||
|
||||
@@ -7,13 +7,13 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from operator import add
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AnyMessage, AIMessage
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.config import get_stream_writer
|
||||
from typing_extensions import TypedDict, Annotated
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[tuple[str, str]], add]
|
||||
messages: list[dict[str, str]]
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
@@ -154,7 +154,7 @@ class BaseNode(ABC):
|
||||
Returns:
|
||||
超时时间
|
||||
"""
|
||||
return 60
|
||||
return settings.WORKFLOW_NODE_TIMEOUT
|
||||
# return self.error_handling.get("timeout", 60)
|
||||
|
||||
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
||||
@@ -203,6 +203,7 @@ class BaseNode(ABC):
|
||||
# 返回包装后的输出和运行时变量
|
||||
return {
|
||||
**wrapped_output,
|
||||
"messages": state["messages"],
|
||||
"variables": state["variables"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
@@ -356,6 +357,7 @@ class BaseNode(ABC):
|
||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||
state_update = {
|
||||
**final_output,
|
||||
"messages": state["messages"],
|
||||
"variables": state["variables"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
|
||||
@@ -6,7 +6,6 @@ End 节点实现
|
||||
|
||||
import logging
|
||||
import re
|
||||
import asyncio
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
@@ -38,7 +37,23 @@ class EndNode(BaseNode):
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state, strict=False)
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": output
|
||||
}
|
||||
])
|
||||
else:
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
])
|
||||
output = "工作流已完成"
|
||||
|
||||
# 统计信息(用于日志)
|
||||
@@ -166,6 +181,12 @@ class EndNode(BaseNode):
|
||||
"chunk_index": 1,
|
||||
"is_suffix": False
|
||||
})
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
}
|
||||
])
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
@@ -176,7 +197,6 @@ class EndNode(BaseNode):
|
||||
source_node_id = edge.get("source")
|
||||
# Check if the source node is an LLM node
|
||||
for node in self.workflow_config.get("nodes", []):
|
||||
print("="*50)
|
||||
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
|
||||
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
|
||||
direct_upstream_llm_nodes.append(source_node_id)
|
||||
@@ -216,12 +236,24 @@ class EndNode(BaseNode):
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
||||
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": output
|
||||
}
|
||||
])
|
||||
|
||||
# yield completion marker
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
|
||||
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
||||
logger.info(
|
||||
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
||||
|
||||
# Collect suffix parts
|
||||
suffix_parts = []
|
||||
@@ -258,6 +290,17 @@ class EndNode(BaseNode):
|
||||
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
||||
full_output = self._render_template(output_template, state, strict=False)
|
||||
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_output
|
||||
}
|
||||
])
|
||||
|
||||
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
||||
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
||||
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
|
||||
@@ -280,7 +323,8 @@ class EndNode(BaseNode):
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
||||
else:
|
||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
|
||||
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
||||
|
||||
# 统计信息
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
|
||||
@@ -11,12 +11,12 @@ class MessageConfig(BaseModel):
|
||||
"""消息配置"""
|
||||
|
||||
role: str = Field(
|
||||
...,
|
||||
default='user',
|
||||
description="消息角色:system, user, assistant"
|
||||
)
|
||||
|
||||
content: str = Field(
|
||||
...,
|
||||
default="",
|
||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||
)
|
||||
|
||||
@@ -30,6 +30,23 @@ class MessageConfig(BaseModel):
|
||||
return v.lower()
|
||||
|
||||
|
||||
class MemoryWindowSetting(BaseModel):
|
||||
enable: bool = Field(
|
||||
default=False,
|
||||
description="启用记忆"
|
||||
)
|
||||
|
||||
enable_window: bool = Field(
|
||||
default=False,
|
||||
description="启用记忆窗口"
|
||||
)
|
||||
|
||||
window_size: int = Field(
|
||||
default=20,
|
||||
description="记忆窗口大小"
|
||||
)
|
||||
|
||||
|
||||
class LLMNodeConfig(BaseNodeConfig):
|
||||
"""LLM 节点配置
|
||||
|
||||
@@ -48,6 +65,11 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
description="上下文"
|
||||
)
|
||||
|
||||
memory: MemoryWindowSetting = Field(
|
||||
...,
|
||||
description="对话上下文窗口"
|
||||
)
|
||||
|
||||
# 简单模式
|
||||
prompt: str | None = Field(
|
||||
default=None,
|
||||
|
||||
@@ -85,28 +85,31 @@ class LLMNode(BaseNode):
|
||||
"""
|
||||
|
||||
# 1. 处理消息格式(优先使用 messages)
|
||||
messages_config = self.config.get("messages")
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.get("role", "user").lower()
|
||||
content_template = msg_config.get("content", "")
|
||||
role = msg_config.role.lower()
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, state)
|
||||
content = self._render_template(content_template, state)
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append(SystemMessage(content=content))
|
||||
messages.append({"role": "system", "content": content})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append(HumanMessage(content=content))
|
||||
messages.append({"role": "user", "content": content})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append(AIMessage(content=content))
|
||||
messages.append({"role": "assistant", "content": content})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append(HumanMessage(content=content))
|
||||
messages.append({"role": "user", "content": content})
|
||||
|
||||
if self.typed_config.memory.enable:
|
||||
# if self.typed_config.memory.enable_window:
|
||||
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
||||
prompt_or_messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
@@ -189,7 +192,7 @@ class LLMNode(BaseNode):
|
||||
return {
|
||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||
"messages": [
|
||||
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
|
||||
{"role": msg.get("role"), "content": msg.get("content", "")}
|
||||
for msg in prompt_or_messages
|
||||
] if isinstance(prompt_or_messages, list) else None,
|
||||
"config": {
|
||||
|
||||
@@ -3,8 +3,9 @@ from typing import Any
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.db import get_db_read, get_db_context
|
||||
from app.db import get_db_read
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.tasks import write_message_task
|
||||
|
||||
|
||||
class MemoryReadNode(BaseNode):
|
||||
@@ -15,11 +16,8 @@ class MemoryReadNode(BaseNode):
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||
with get_db_read() as db:
|
||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
if not workspace_id:
|
||||
raise RuntimeError("Workspace id is required")
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
@@ -41,20 +39,17 @@ class MemoryWriteNode(BaseNode):
|
||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
with get_db_context() as db:
|
||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
if not workspace_id:
|
||||
raise RuntimeError("Workspace id is required")
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
return await MemoryAgentService().write_memory(
|
||||
group_id=end_user_id,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=str(self.typed_config.config_id),
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
user_rag_memory_id=""
|
||||
)
|
||||
write_message_task.delay(
|
||||
end_user_id,
|
||||
self._render_template(self.typed_config.message, state),
|
||||
str(self.typed_config.config_id),
|
||||
"neo4j",
|
||||
""
|
||||
)
|
||||
|
||||
return "success"
|
||||
|
||||
@@ -41,6 +41,7 @@ class ToolConfig(BaseModel):
|
||||
tool_id: Optional[str] = Field(default=None, description="工具ID")
|
||||
operation: Optional[str] = Field(default=None, description="工具特定配置")
|
||||
|
||||
|
||||
class ToolOldConfig(BaseModel):
|
||||
"""工具配置"""
|
||||
enabled: bool = Field(default=False, description="是否启用该工具")
|
||||
@@ -348,6 +349,7 @@ class AppChatRequest(BaseModel):
|
||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
|
||||
|
||||
class DraftRunRequest(BaseModel):
|
||||
"""试运行请求"""
|
||||
message: str = Field(..., description="用户消息")
|
||||
|
||||
@@ -1,9 +1,51 @@
|
||||
"""
|
||||
情景记忆的请求和响应模型
|
||||
"""
|
||||
from abc import ABC
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
type_mapping = {
|
||||
"Person": "人物实体节点",
|
||||
"Organization": "组织实体节点",
|
||||
"ORG": "组织实体节点",
|
||||
"Location": "地点实体节点",
|
||||
"LOC": "地点实体节点",
|
||||
"Event": "事件实体节点",
|
||||
"Concept": "概念实体节点",
|
||||
"Time": "时间实体节点",
|
||||
"Position": "职位实体节点",
|
||||
"WorkRole": "职业实体节点",
|
||||
"System": "系统实体节点",
|
||||
"Policy": "政策实体节点",
|
||||
"HistoricalPeriod": "历史时期实体节点",
|
||||
"HistoricalState": "历史国家实体节点",
|
||||
"HistoricalEvent": "历史事件实体节点",
|
||||
"EconomicFactor": "经济因素实体节点",
|
||||
"Condition": "条件实体节点",
|
||||
"Numeric": "数值实体节点"
|
||||
}
|
||||
class EmotionType(ABC):
|
||||
JOY_TYPE = "joy"
|
||||
SURPRISE_TYPE = "surprise"
|
||||
SANDROWNESS_TYPE = "sadness"
|
||||
FEAR_TYPE = "fear"
|
||||
ANGET_TYPE="anger"
|
||||
NEUTRAL_TYPE="neutral"
|
||||
EMOTION_MAPPING={
|
||||
"joy":"愉快",
|
||||
"surprise":"惊喜",
|
||||
"sadness":"悲伤",
|
||||
"fear":"恐惧",
|
||||
"anger":"生气",
|
||||
"neutral":"中性"
|
||||
}
|
||||
class EmotionSubject(ABC):
|
||||
SUBJECT_MAPPING={
|
||||
"self":"自己",
|
||||
"other":"别人",
|
||||
"object":"事物对象"
|
||||
}
|
||||
|
||||
class EpisodicMemoryOverviewRequest(BaseModel):
|
||||
"""情景记忆总览查询请求"""
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.services.tool_service import ToolService
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.db import get_db
|
||||
@@ -59,7 +60,7 @@ class AppChatService:
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = config.default_model_config_id
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.system_prompt
|
||||
if variables:
|
||||
@@ -210,7 +211,7 @@ class AppChatService:
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = config.default_model_config_id
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.system_prompt
|
||||
if variables:
|
||||
@@ -511,7 +512,6 @@ class AppChatService:
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
except (GeneratorExit, asyncio.CancelledError):
|
||||
# 生成器被关闭或任务被取消,正常退出
|
||||
logger.debug("多 Agent 流式聊天被中断")
|
||||
@@ -537,83 +537,19 @@ class AppChatService:
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
workflow_service = WorkflowService(self.db)
|
||||
|
||||
input_data = {"message":message, "variables": variables,
|
||||
"conversation_id": str(conversation_id)}
|
||||
inconfig = workflow_service.get_workflow_config(app_id)
|
||||
|
||||
# 2. 创建执行记录
|
||||
execution = workflow_service.create_execution(
|
||||
workflow_config_id=inconfig.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=None,
|
||||
conversation_id=conversation_id,
|
||||
input_data=input_data
|
||||
payload = DraftRunRequest(
|
||||
message=message,
|
||||
variables=variables,
|
||||
conversation_id=str(conversation_id),
|
||||
stream=True,
|
||||
user_id=user_id
|
||||
)
|
||||
return await workflow_service.run(
|
||||
app_id=app_id,
|
||||
payload=payload,
|
||||
config=config,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
|
||||
# 3. 构建工作流配置字典
|
||||
workflow_config_dict = {
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables,
|
||||
"execution_config": config.execution_config
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
|
||||
# 5. 执行工作流
|
||||
from app.core.workflow.executor import execute_workflow
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
workflow_service.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
result = await execute_workflow(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# 更新执行结果
|
||||
if result.get("status") == "completed":
|
||||
workflow_service.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=result.get("node_outputs", {})
|
||||
)
|
||||
else:
|
||||
workflow_service.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=result.get("error")
|
||||
)
|
||||
|
||||
# 返回增强的响应结构
|
||||
return {
|
||||
"execution_id": execution.execution_id,
|
||||
"status": result.get("status"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||
"error_message": result.get("error"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"token_usage": result.get("token_usage")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
workflow_service.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
raise BusinessException(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def workflow_chat_stream(
|
||||
self,
|
||||
@@ -622,7 +558,7 @@ class AppChatService:
|
||||
config: WorkflowConfig,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str] = None,
|
||||
user_id: str = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
@@ -632,62 +568,21 @@ class AppChatService:
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
workflow_service = WorkflowService(self.db)
|
||||
input_data = {"message": message, "variables": variables,
|
||||
"conversation_id": str(conversation_id)}
|
||||
inconfig = workflow_service.get_workflow_config(app_id)
|
||||
# 2. 创建执行记录
|
||||
execution = workflow_service.create_execution(
|
||||
workflow_config_id=inconfig.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=None,
|
||||
conversation_id=conversation_id,
|
||||
input_data=input_data
|
||||
payload = DraftRunRequest(
|
||||
message=message,
|
||||
variables=variables,
|
||||
conversation_id=str(conversation_id),
|
||||
stream=True,
|
||||
user_id=user_id
|
||||
)
|
||||
async for event in workflow_service.run_stream(
|
||||
app_id=app_id,
|
||||
payload=payload,
|
||||
config=config,
|
||||
workspace_id=workspace_id,
|
||||
):
|
||||
yield event
|
||||
|
||||
# 3. 构建工作流配置字典
|
||||
workflow_config_dict = {
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables,
|
||||
"execution_config": config.execution_config
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
|
||||
# 5. 流式执行工作流
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
workflow_service.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
|
||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
||||
async for event in workflow_service._run_workflow_stream(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=user_id
|
||||
):
|
||||
# 直接转发 executor 的事件(已经是正确的格式)
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
workflow_service.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
# 发送错误事件
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"execution_id": execution.execution_id,
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
|
||||
@@ -516,8 +516,16 @@ class ConversationService:
|
||||
|
||||
conversation_messages = self.get_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=30
|
||||
max_history=20
|
||||
)
|
||||
if len(conversation_messages) == 0:
|
||||
return ConversationOut(
|
||||
theme="",
|
||||
question=[],
|
||||
summary="",
|
||||
takeaways=[],
|
||||
info_score=0,
|
||||
)
|
||||
|
||||
with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f:
|
||||
system_prompt = f.read()
|
||||
@@ -536,6 +544,7 @@ class ConversationService:
|
||||
]
|
||||
logger.info(f"Invoking LLM for conversation_id={conversation_id}")
|
||||
model_resp = await llm.ainvoke(messages)
|
||||
|
||||
try:
|
||||
if isinstance(model_resp.content, str):
|
||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||
|
||||
@@ -245,7 +245,8 @@ class DraftRunService:
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
web_search: bool = True,
|
||||
memory: bool = True
|
||||
memory: bool = True,
|
||||
sub_agent: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""执行试运行(使用 LangChain Agent)
|
||||
|
||||
@@ -435,7 +436,7 @@ class DraftRunService:
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 8. 保存会话消息
|
||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||
await self._save_conversation_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=message,
|
||||
|
||||
@@ -9,7 +9,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from threading import Lock
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import redis
|
||||
@@ -51,9 +51,7 @@ _neo4j_connector = Neo4jConnector()
|
||||
class MemoryAgentService:
|
||||
"""Service for memory agent operations"""
|
||||
|
||||
def __init__(self):
|
||||
self.user_locks: Dict[str, Lock] = {}
|
||||
self.locks_lock = Lock()
|
||||
|
||||
|
||||
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
|
||||
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
@@ -83,12 +81,7 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(f"写入失败: {messages}")
|
||||
|
||||
def get_group_lock(self, group_id: str) -> Lock:
|
||||
"""Get lock for specific group to prevent concurrent processing"""
|
||||
with self.locks_lock:
|
||||
if group_id not in self.user_locks:
|
||||
self.user_locks[group_id] = Lock()
|
||||
return self.user_locks[group_id]
|
||||
|
||||
|
||||
def extract_tool_call_info(self, event: Dict) -> bool:
|
||||
"""Extract tool call information from event"""
|
||||
@@ -417,241 +410,236 @@ class MemoryAgentService:
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
|
||||
# Get group lock to prevent concurrent processing
|
||||
group_lock = self.get_group_lock(group_id)
|
||||
try:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
with group_lock:
|
||||
# Step 1: Load configuration from database only
|
||||
try:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
# Step 2: Prepare history
|
||||
history.append({"role": "user", "content": message})
|
||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||
|
||||
# Step 2: Prepare history
|
||||
history.append({"role": "user", "content": message})
|
||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||
# Step 3: Initialize MCP client and execute read workflow
|
||||
mcp_config = get_mcp_server_config()
|
||||
client = MultiServerMCPClient(mcp_config)
|
||||
|
||||
# Step 3: Initialize MCP client and execute read workflow
|
||||
mcp_config = get_mcp_server_config()
|
||||
client = MultiServerMCPClient(mcp_config)
|
||||
async with client.session('data_flow') as session:
|
||||
session_start = time.time()
|
||||
logger.debug("Connected to MCP Server: data_flow")
|
||||
|
||||
async with client.session('data_flow') as session:
|
||||
session_start = time.time()
|
||||
logger.debug("Connected to MCP Server: data_flow")
|
||||
|
||||
tools_start = time.time()
|
||||
tools = await load_mcp_tools(session)
|
||||
tools_time = time.time() - tools_start
|
||||
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
|
||||
|
||||
outputs = []
|
||||
intermediate_outputs = []
|
||||
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
|
||||
tools_start = time.time()
|
||||
tools = await load_mcp_tools(session)
|
||||
tools_time = time.time() - tools_start
|
||||
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
|
||||
|
||||
# Pass memory_config to the graph workflow
|
||||
graph_start = time.time()
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
graph_init_time = time.time() - graph_start
|
||||
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
|
||||
|
||||
start = time.time()
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
event_count = 0
|
||||
async for event in graph.astream(
|
||||
{"messages": history, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
config=config
|
||||
):
|
||||
event_count += 1
|
||||
event_start = time.time()
|
||||
messages = event.get('messages')
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
outputs = []
|
||||
intermediate_outputs = []
|
||||
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
|
||||
|
||||
for msg in messages:
|
||||
msg_content = msg.content
|
||||
msg_role = msg.__class__.__name__.lower().replace("message", "")
|
||||
outputs.append({
|
||||
"role": msg_role,
|
||||
"content": msg_content
|
||||
})
|
||||
# Pass memory_config to the graph workflow
|
||||
graph_start = time.time()
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
graph_init_time = time.time() - graph_start
|
||||
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
|
||||
|
||||
# Extract intermediate outputs
|
||||
if hasattr(msg, 'content'):
|
||||
try:
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
content_to_parse = msg_content
|
||||
if isinstance(msg_content, list):
|
||||
for block in msg_content:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
content_to_parse = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
start = time.time()
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
# Try to parse content as JSON
|
||||
if isinstance(content_to_parse, str):
|
||||
try:
|
||||
parsed = json.loads(content_to_parse)
|
||||
if isinstance(parsed, dict):
|
||||
# Check for single intermediate output
|
||||
if '_intermediate' in parsed:
|
||||
intermediate_data = parsed['_intermediate']
|
||||
event_count = 0
|
||||
async for event in graph.astream(
|
||||
{"messages": history, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
config=config
|
||||
):
|
||||
event_count += 1
|
||||
event_start = time.time()
|
||||
messages = event.get('messages')
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
|
||||
for msg in messages:
|
||||
msg_content = msg.content
|
||||
msg_role = msg.__class__.__name__.lower().replace("message", "")
|
||||
outputs.append({
|
||||
"role": msg_role,
|
||||
"content": msg_content
|
||||
})
|
||||
|
||||
# Extract intermediate outputs
|
||||
if hasattr(msg, 'content'):
|
||||
try:
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
content_to_parse = msg_content
|
||||
if isinstance(msg_content, list):
|
||||
for block in msg_content:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
content_to_parse = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
|
||||
# Try to parse content as JSON
|
||||
if isinstance(content_to_parse, str):
|
||||
try:
|
||||
parsed = json.loads(content_to_parse)
|
||||
if isinstance(parsed, dict):
|
||||
# Check for single intermediate output
|
||||
if '_intermediate' in parsed:
|
||||
intermediate_data = parsed['_intermediate']
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
|
||||
# Check for multiple intermediate outputs (from Retrieve)
|
||||
if '_intermediates' in parsed:
|
||||
for intermediate_data in parsed['_intermediates']:
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract intermediate output: {e}")
|
||||
|
||||
# Check for multiple intermediate outputs (from Retrieve)
|
||||
if '_intermediates' in parsed:
|
||||
for intermediate_data in parsed['_intermediates']:
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
event_time = time.time() - event_start
|
||||
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract intermediate output: {e}")
|
||||
|
||||
event_time = time.time() - event_start
|
||||
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
|
||||
workflow_duration = time.time() - start
|
||||
session_duration = time.time() - session_start
|
||||
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
|
||||
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
|
||||
logger.info(f"[PERF] Total events processed: {event_count}")
|
||||
# Extract final answer
|
||||
final_answer = ""
|
||||
for messages in outputs:
|
||||
if messages['role'] == 'tool':
|
||||
message = messages['content']
|
||||
|
||||
workflow_duration = time.time() - start
|
||||
session_duration = time.time() - session_start
|
||||
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
|
||||
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
|
||||
logger.info(f"[PERF] Total events processed: {event_count}")
|
||||
# Extract final answer
|
||||
final_answer = ""
|
||||
for messages in outputs:
|
||||
if messages['role'] == 'tool':
|
||||
message = messages['content']
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
if isinstance(message, list):
|
||||
# Extract text from MCP content blocks
|
||||
for block in message:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
message = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
if isinstance(message, list):
|
||||
# Extract text from MCP content blocks
|
||||
for block in message:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
message = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
try:
|
||||
parsed = json.loads(message) if isinstance(message, str) else message
|
||||
if isinstance(parsed, dict):
|
||||
if parsed.get('status') == 'success':
|
||||
summary_result = parsed.get('summary_result')
|
||||
if summary_result:
|
||||
final_answer = summary_result
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
try:
|
||||
parsed = json.loads(message) if isinstance(message, str) else message
|
||||
if isinstance(parsed, dict):
|
||||
if parsed.get('status') == 'success':
|
||||
summary_result = parsed.get('summary_result')
|
||||
if summary_result:
|
||||
final_answer = summary_result
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
# 记录成功的操作
|
||||
total_duration = time.time() - start_time
|
||||
|
||||
# 记录成功的操作
|
||||
total_duration = time.time() - start_time
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.warning(f"Read workflow completed with errors: {error_details}")
|
||||
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.warning(f"Read workflow completed with errors: {error_details}")
|
||||
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=total_duration,
|
||||
error=error_details,
|
||||
details={
|
||||
"search_switch": search_switch,
|
||||
"history_length": len(history),
|
||||
"intermediate_outputs_count": len(intermediate_outputs),
|
||||
"has_answer": bool(final_answer),
|
||||
"errors": workflow_errors
|
||||
}
|
||||
)
|
||||
|
||||
# Raise error if no answer was produced
|
||||
if not final_answer:
|
||||
raise ValueError(f"Read workflow failed: {error_details}")
|
||||
|
||||
if audit_logger and not workflow_errors:
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=True,
|
||||
success=False,
|
||||
duration=total_duration,
|
||||
error=error_details,
|
||||
details={
|
||||
"search_switch": search_switch,
|
||||
"history_length": len(history),
|
||||
"intermediate_outputs_count": len(intermediate_outputs),
|
||||
"has_answer": bool(final_answer)
|
||||
"has_answer": bool(final_answer),
|
||||
"errors": workflow_errors
|
||||
}
|
||||
)
|
||||
retrieved_content=[]
|
||||
repo = ShortTermMemoryRepository(db)
|
||||
if str(search_switch)!="2":
|
||||
for intermediate in intermediate_outputs:
|
||||
print(intermediate)
|
||||
intermediate_type=intermediate['type']
|
||||
if intermediate_type=="search_result":
|
||||
query=intermediate['query']
|
||||
raw_results=intermediate['raw_results']
|
||||
reranked_results=raw_results.get('reranked_results',[])
|
||||
try:
|
||||
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||
except Exception:
|
||||
statements=[]
|
||||
statements=list(set(statements))
|
||||
retrieved_content.append({query:statements})
|
||||
if retrieved_content==[]:
|
||||
retrieved_content=''
|
||||
if '信息不足,无法回答。' != str(final_answer) :#and retrieved_content!=[]
|
||||
# 使用 upsert 方法
|
||||
repo.upsert(
|
||||
end_user_id=end_user_id, # 确保这个变量在作用域内
|
||||
messages=ori_message,
|
||||
aimessages=final_answer,
|
||||
retrieved_content=retrieved_content,
|
||||
search_switch=str(search_switch)
|
||||
)
|
||||
print("写入成功")
|
||||
|
||||
# Raise error if no answer was produced
|
||||
if not final_answer:
|
||||
raise ValueError(f"Read workflow failed: {error_details}")
|
||||
|
||||
if audit_logger and not workflow_errors:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=True,
|
||||
duration=total_duration,
|
||||
details={
|
||||
"search_switch": search_switch,
|
||||
"history_length": len(history),
|
||||
"intermediate_outputs_count": len(intermediate_outputs),
|
||||
"has_answer": bool(final_answer)
|
||||
}
|
||||
)
|
||||
retrieved_content=[]
|
||||
repo = ShortTermMemoryRepository(db)
|
||||
if str(search_switch)!="2":
|
||||
for intermediate in intermediate_outputs:
|
||||
print(intermediate)
|
||||
intermediate_type=intermediate['type']
|
||||
if intermediate_type=="search_result":
|
||||
query=intermediate['query']
|
||||
raw_results=intermediate['raw_results']
|
||||
reranked_results=raw_results.get('reranked_results',[])
|
||||
try:
|
||||
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||
except Exception:
|
||||
statements=[]
|
||||
statements=list(set(statements))
|
||||
retrieved_content.append({query:statements})
|
||||
if retrieved_content==[]:
|
||||
retrieved_content=''
|
||||
if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[]
|
||||
# 使用 upsert 方法
|
||||
repo.upsert(
|
||||
end_user_id=end_user_id, # 确保这个变量在作用域内
|
||||
messages=ori_message,
|
||||
aimessages=final_answer,
|
||||
retrieved_content=retrieved_content,
|
||||
search_switch=str(search_switch)
|
||||
)
|
||||
print("写入成功")
|
||||
|
||||
|
||||
|
||||
return {
|
||||
"answer": final_answer,
|
||||
"intermediate_outputs": intermediate_outputs
|
||||
}
|
||||
|
||||
return {
|
||||
"answer": final_answer,
|
||||
"intermediate_outputs": intermediate_outputs
|
||||
}
|
||||
|
||||
def _create_intermediate_key(self, output: Dict) -> str:
|
||||
"""
|
||||
Create a unique key for an intermediate output to detect duplicates.
|
||||
|
||||
@@ -15,6 +15,8 @@ from neo4j.time import DateTime as Neo4jDateTime
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from app.schemas.memory_episodic_schema import EmotionType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MemoryEntityService:
|
||||
@@ -123,7 +125,7 @@ class MemoryEntityService:
|
||||
extracted_entity_list = self._deduplicate_dict_list(extracted_entity_list)
|
||||
|
||||
# 合并所有数据并处理相同text的合并
|
||||
all_timeline_data = memory_summary_list + statement_list + extracted_entity_list
|
||||
all_timeline_data = memory_summary_list + statement_list
|
||||
all_timeline_data = self._merge_same_text_items(all_timeline_data)
|
||||
|
||||
result = {
|
||||
@@ -496,11 +498,11 @@ class MemoryEmotion:
|
||||
length_data.append(emotion_intensity)
|
||||
if emotion_type is not None and emotion_intensity is not None and formatted_created_at is not None:
|
||||
# 使用(emotion_type, created_at)作为分组键
|
||||
if emotion_type in {"joy", "surprise"}:
|
||||
if emotion_type in {EmotionType.JOY_TYPE, EmotionType.SURPRISE_TYPE}:
|
||||
emotion_type='positive'
|
||||
elif emotion_type in {"sadness", "fear", "anger"}:
|
||||
elif emotion_type in {EmotionType.SANDROWNESS_TYPE, EmotionType.FEAR_TYPE, EmotionType.ANGET_TYPE}:
|
||||
emotion_type='negative'
|
||||
elif emotion_type=='neutral':
|
||||
elif emotion_type==EmotionType.NEUTRAL_TYPE:
|
||||
emotion_type='neutral'
|
||||
group_key = (emotion_type, formatted_created_at)
|
||||
# 累加emotion_intensity
|
||||
@@ -595,7 +597,7 @@ class MemoryInteraction:
|
||||
group_id = ori_data[0]['group_id']
|
||||
Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id)
|
||||
if not Space_User:
|
||||
return '不存在用户'
|
||||
return []
|
||||
user_id=Space_User[0]['id']
|
||||
|
||||
results = await self.connector.execute_query(Memory_Space_Associative, id=self.id,user_id=user_id)
|
||||
|
||||
@@ -267,14 +267,14 @@ class MemoryForgetService:
|
||||
elif node_type_label == 'memorysummary':
|
||||
node_type_label = 'summary'
|
||||
|
||||
# 将 Neo4j DateTime 对象转换为时间戳
|
||||
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
|
||||
last_access_time = result['last_access_time']
|
||||
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
|
||||
# 确保 datetime 带有时区信息(假定为 UTC),避免 naive datetime 导致的时区偏差
|
||||
if last_access_dt:
|
||||
if last_access_dt.tzinfo is None:
|
||||
last_access_dt = last_access_dt.replace(tzinfo=timezone.utc)
|
||||
last_access_timestamp = int(last_access_dt.timestamp())
|
||||
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
|
||||
else:
|
||||
last_access_timestamp = 0
|
||||
|
||||
@@ -520,7 +520,7 @@ class MemoryForgetService:
|
||||
'average_activation_value': result['average_activation'],
|
||||
'low_activation_nodes': result['low_activation_nodes'] or 0,
|
||||
'forgetting_threshold': forgetting_threshold,
|
||||
'timestamp': int(datetime.now().timestamp())
|
||||
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
else:
|
||||
activation_metrics = {
|
||||
@@ -530,7 +530,7 @@ class MemoryForgetService:
|
||||
'average_activation_value': None,
|
||||
'low_activation_nodes': 0,
|
||||
'forgetting_threshold': forgetting_threshold,
|
||||
'timestamp': int(datetime.now().timestamp())
|
||||
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
|
||||
# 收集节点类型分布
|
||||
@@ -620,7 +620,7 @@ class MemoryForgetService:
|
||||
'merged_count': record.merged_count,
|
||||
'average_activation': record.average_activation_value,
|
||||
'total_nodes': record.total_nodes,
|
||||
'execution_time': int(record.execution_time.timestamp())
|
||||
'execution_time': int(record.execution_time.timestamp() * 1000)
|
||||
})
|
||||
|
||||
api_logger.info(f"成功获取最近 {len(recent_trends)} 个日期的历史趋势数据")
|
||||
@@ -661,7 +661,7 @@ class MemoryForgetService:
|
||||
'node_distribution': node_distribution,
|
||||
'recent_trends': recent_trends,
|
||||
'pending_nodes': pending_nodes,
|
||||
'timestamp': int(datetime.now().timestamp())
|
||||
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
|
||||
api_logger.info(
|
||||
|
||||
@@ -1327,7 +1327,8 @@ class MultiAgentOrchestrator:
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
sub_agent=True
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -13,10 +13,15 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.conversation_repository import ConversationRepository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from app.services.memory_base_service import MemoryBaseService
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
from app.services.memory_short_service import ShortService
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -1196,18 +1201,17 @@ async def analytics_memory_types(
|
||||
end_user_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
统计9种记忆类型的数量和百分比
|
||||
统计8种记忆类型的数量和百分比
|
||||
|
||||
计算规则:
|
||||
1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity
|
||||
2. 工作记忆 (WORKING_MEMORY) = chunk + entity
|
||||
3. 短期记忆 (SHORT_TERM_MEMORY) = chunk
|
||||
4. 长期记忆 (LONG_TERM_MEMORY) = entity
|
||||
5. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
|
||||
6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity
|
||||
7. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
|
||||
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
|
||||
9. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
|
||||
1. 感知记忆 (PERCEPTUAL_MEMORY) = 通过 MemoryPerceptualService.get_memory_count 获取的 total_count
|
||||
2. 工作记忆 (WORKING_MEMORY) = 会话数量(通过 ConversationRepository.get_conversation_by_user_id 获取)
|
||||
3. 短期记忆 (SHORT_TERM_MEMORY) = /short_term 接口返回的问答对数量
|
||||
4. 显性记忆 (EXPLICIT_MEMORY) = 情景记忆 + 语义记忆(通过 MemoryBaseService.get_explicit_memory_count 获取)
|
||||
5. 隐性记忆 (IMPLICIT_MEMORY) = Statement 节点数量的三分之一
|
||||
6. 情绪记忆 (EMOTIONAL_MEMORY) = 情绪标签统计总数(通过 MemoryBaseService.get_emotional_memory_count 获取)
|
||||
7. 情景记忆 (EPISODIC_MEMORY) = memory_summary(通过 MemoryBaseService.get_episodic_memory_count 获取)
|
||||
8. 遗忘记忆 (FORGET_MEMORY) = 激活值低于阈值的节点数(通过 MemoryBaseService.get_forget_memory_count 获取)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
@@ -1227,7 +1231,6 @@ async def analytics_memory_types(
|
||||
- PERCEPTUAL_MEMORY: 感知记忆
|
||||
- WORKING_MEMORY: 工作记忆
|
||||
- SHORT_TERM_MEMORY: 短期记忆
|
||||
- LONG_TERM_MEMORY: 长期记忆
|
||||
- EXPLICIT_MEMORY: 显性记忆
|
||||
- IMPLICIT_MEMORY: 隐性记忆
|
||||
- EMOTIONAL_MEMORY: 情绪记忆
|
||||
@@ -1237,40 +1240,78 @@ async def analytics_memory_types(
|
||||
# 初始化基础服务
|
||||
base_service = MemoryBaseService()
|
||||
|
||||
# 定义需要查询的基础节点类型
|
||||
node_types = {
|
||||
"Statement": "Statement",
|
||||
"Entity": "ExtractedEntity",
|
||||
"Chunk": "Chunk"
|
||||
}
|
||||
# 初始化感知记忆服务
|
||||
perceptual_service = MemoryPerceptualService(db)
|
||||
|
||||
# 存储每种节点类型的计数
|
||||
node_counts = {}
|
||||
# 获取感知记忆数量
|
||||
if end_user_id:
|
||||
perceptual_stats = perceptual_service.get_memory_count(uuid.UUID(end_user_id))
|
||||
perceptual_count = perceptual_stats.get("total", 0)
|
||||
else:
|
||||
perceptual_count = 0
|
||||
|
||||
# 查询每种节点类型的数量
|
||||
for key, node_type in node_types.items():
|
||||
if end_user_id:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
# 获取工作记忆数量(基于会话数量)
|
||||
work_count = 0
|
||||
if end_user_id:
|
||||
try:
|
||||
conversation_repo = ConversationRepository(db)
|
||||
conversations = conversation_repo.get_conversation_by_user_id(
|
||||
user_id=uuid.UUID(end_user_id),
|
||||
limit=100, # 获取更多会话以准确统计
|
||||
is_activate=True
|
||||
)
|
||||
work_count = len(conversations)
|
||||
logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}")
|
||||
work_count = 0
|
||||
|
||||
# 获取隐性记忆数量(基于 Statement 节点数量的三分之一)
|
||||
implicit_count = 0
|
||||
if end_user_id:
|
||||
try:
|
||||
# 查询 Statement 节点数量
|
||||
query = """
|
||||
MATCH (n:Statement)
|
||||
WHERE n.group_id = $group_id
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
|
||||
else:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(query)
|
||||
|
||||
# 提取计数结果
|
||||
count = result[0]["count"] if result and len(result) > 0 else 0
|
||||
node_counts[key] = count
|
||||
statement_count = result[0]["count"] if result and len(result) > 0 else 0
|
||||
# 取三分之一作为隐性记忆数量
|
||||
implicit_count = round(statement_count / 3)
|
||||
logger.debug(f"隐性记忆数量(Statement数量的1/3): {implicit_count} (Statement总数={statement_count}, end_user_id={end_user_id})")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取Statement数量失败,隐性记忆数量设为0: {str(e)}")
|
||||
implicit_count = 0
|
||||
|
||||
# 获取各节点类型的数量
|
||||
statement_count = node_counts.get("Statement", 0)
|
||||
entity_count = node_counts.get("Entity", 0)
|
||||
chunk_count = node_counts.get("Chunk", 0)
|
||||
# 原有的基于行为习惯的统计方式(已注释)
|
||||
# implicit_count = 0
|
||||
# if end_user_id:
|
||||
# try:
|
||||
# implicit_service = ImplicitMemoryService(db, end_user_id)
|
||||
# behavior_habits = await implicit_service.get_behavior_habits(
|
||||
# user_id=end_user_id
|
||||
# )
|
||||
# implicit_count = len(behavior_habits)
|
||||
# logger.debug(f"隐性记忆数量(行为习惯数): {implicit_count} (end_user_id={end_user_id})")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"获取行为习惯数量失败,隐性记忆数量设为0: {str(e)}")
|
||||
# implicit_count = 0
|
||||
|
||||
# 获取短期记忆数量(基于 /short_term 接口返回的问答对数量)
|
||||
short_term_count = 0
|
||||
if end_user_id:
|
||||
try:
|
||||
short_term_service = ShortService(end_user_id)
|
||||
short_term_data = short_term_service.get_short_databasets()
|
||||
# 统计 short_term 数组的长度
|
||||
if short_term_data:
|
||||
short_term_count = len(short_term_data)
|
||||
logger.debug(f"短期记忆数量(问答对数): {short_term_count} (end_user_id={end_user_id})")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取短期记忆数量失败,短期记忆数量设为0: {str(e)}")
|
||||
short_term_count = 0
|
||||
|
||||
# 获取用户的遗忘阈值配置
|
||||
forgetting_threshold = 0.3 # 默认值
|
||||
@@ -1296,17 +1337,16 @@ async def analytics_memory_types(
|
||||
# 使用 MemoryBaseService 的共享方法获取特殊记忆类型的数量
|
||||
episodic_count = await base_service.get_episodic_memory_count(end_user_id)
|
||||
explicit_count = await base_service.get_explicit_memory_count(end_user_id)
|
||||
emotion_count = await base_service.get_emotional_memory_count(end_user_id, statement_count)
|
||||
emotion_count = await base_service.get_emotional_memory_count(end_user_id, perceptual_count)
|
||||
forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold)
|
||||
|
||||
# 按规则计算9种记忆类型的数量(使用英文枚举作为key)
|
||||
# 按规则计算8种记忆类型的数量(使用英文枚举作为key)
|
||||
memory_counts = {
|
||||
"PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆
|
||||
"WORKING_MEMORY": chunk_count + entity_count, # 工作记忆
|
||||
"SHORT_TERM_MEMORY": chunk_count, # 短期记忆
|
||||
"LONG_TERM_MEMORY": entity_count, # 长期记忆
|
||||
"PERCEPTUAL_MEMORY": perceptual_count, # 感知记忆
|
||||
"WORKING_MEMORY": work_count, # 工作记忆(基于会话数量)
|
||||
"SHORT_TERM_MEMORY": short_term_count, # 短期记忆(基于问答对数量)
|
||||
"EXPLICIT_MEMORY": explicit_count, # 显性记忆(情景记忆 + 语义记忆)
|
||||
"IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity)
|
||||
"IMPLICIT_MEMORY": implicit_count, # 隐性记忆(Statement数量的1/3)
|
||||
"EMOTIONAL_MEMORY": emotion_count, # 情绪记忆(使用情绪标签统计)
|
||||
"EPISODIC_MEMORY": episodic_count, # 情景记忆
|
||||
"FORGET_MEMORY": forget_count # 遗忘记忆(激活值低于阈值)
|
||||
@@ -1332,7 +1372,7 @@ async def analytics_graph_data(
|
||||
db: Session,
|
||||
end_user_id: str,
|
||||
node_types: Optional[List[str]] = None,
|
||||
limit: int = 100,
|
||||
limit: int = 130,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
@@ -1416,12 +1456,14 @@ async def analytics_graph_data(
|
||||
elementId(n) as id,
|
||||
labels(n)[0] as label,
|
||||
properties(n) as properties
|
||||
LIMIT $limit
|
||||
"""
|
||||
node_params = {
|
||||
"group_id": end_user_id,
|
||||
# "limit": limit
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
|
||||
# 执行节点查询
|
||||
node_results = await _neo4j_connector.execute_query(node_query, **node_params)
|
||||
|
||||
@@ -1576,10 +1618,15 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
|
||||
for field in allowed_fields:
|
||||
if field in properties:
|
||||
value = properties[field]
|
||||
if str(field) == 'entity_type':
|
||||
value=type_mapping.get(value,'')
|
||||
if str(field)=="emotion_type":
|
||||
value=EmotionType.EMOTION_MAPPING.get(value)
|
||||
if str(field)=="emotion_subject":
|
||||
value=EmotionSubject.SUBJECT_MAPPING.get(value)
|
||||
# 清理 Neo4j 特殊类型
|
||||
filtered_props[field] = _clean_neo4j_value(value)
|
||||
filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0]
|
||||
print(filtered_props)
|
||||
return filtered_props
|
||||
|
||||
|
||||
|
||||
@@ -2,29 +2,28 @@
|
||||
工作流服务层
|
||||
"""
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Any, Annotated, AsyncGenerator
|
||||
|
||||
from deprecated import deprecated
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.db import get_db, get_db_context
|
||||
from app.db import get_db
|
||||
from app.models.conversation_model import Message
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
from app.repositories.conversation_repository import MessageRepository
|
||||
from app.repositories.workflow_repository import (
|
||||
WorkflowConfigRepository,
|
||||
WorkflowExecutionRepository,
|
||||
WorkflowNodeExecutionRepository
|
||||
)
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,6 +36,7 @@ class WorkflowService:
|
||||
self.config_repo = WorkflowConfigRepository(db)
|
||||
self.execution_repo = WorkflowExecutionRepository(db)
|
||||
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
||||
self.message_repo = MessageRepository(db)
|
||||
|
||||
# ==================== 配置管理 ====================
|
||||
|
||||
@@ -418,14 +418,13 @@ class WorkflowService:
|
||||
"""运行工作流
|
||||
|
||||
Args:
|
||||
workspace_id:
|
||||
config:
|
||||
payload:
|
||||
app_id: 应用 ID
|
||||
input_data: 输入数据(包含 message 和 variables)
|
||||
triggered_by: 触发用户 ID
|
||||
conversation_id: 会话 ID(可选)
|
||||
stream: 是否流式返回
|
||||
|
||||
Returns:
|
||||
执行结果(非流式)或生成器(流式)
|
||||
执行结果(非流式)
|
||||
|
||||
Raises:
|
||||
BusinessException: 配置不存在或执行失败时抛出
|
||||
@@ -438,7 +437,8 @@ class WorkflowService:
|
||||
code=BizCode.CONFIG_MISSING,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id}
|
||||
|
||||
# 转换 user_id 为 UUID
|
||||
triggered_by_uuid = None
|
||||
@@ -461,7 +461,7 @@ class WorkflowService:
|
||||
workflow_config_id=config.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=triggered_by_uuid,
|
||||
triggered_by=None,
|
||||
conversation_id=conversation_id_uuid,
|
||||
input_data=input_data
|
||||
)
|
||||
@@ -482,14 +482,6 @@ class WorkflowService:
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
with get_db_context() as db:
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
other_id=payload.user_id,
|
||||
original_user_id=payload.user_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
|
||||
@@ -500,14 +492,17 @@ class WorkflowService:
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
|
||||
result = await execute_workflow(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=end_user_id
|
||||
user_id=payload.user_id
|
||||
)
|
||||
|
||||
# 更新执行结果
|
||||
@@ -517,6 +512,17 @@ class WorkflowService:
|
||||
"completed",
|
||||
output_data=result
|
||||
)
|
||||
final_messages = result.get("messages", [])[init_message_length:]
|
||||
for message in final_messages:
|
||||
message_obj = Message(
|
||||
conversation_id=conversation_id_uuid,
|
||||
role=message["role"],
|
||||
content=message["content"],
|
||||
)
|
||||
self.message_repo.add_message(message_obj)
|
||||
self.db.commit()
|
||||
logger.info(f"Workflow Run Success, "
|
||||
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||
else:
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
@@ -529,6 +535,7 @@ class WorkflowService:
|
||||
"execution_id": execution.execution_id,
|
||||
"status": result.get("status"),
|
||||
"variables": result.get("variables"),
|
||||
"messages": result.get("messages"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||
@@ -559,6 +566,7 @@ class WorkflowService:
|
||||
"""运行工作流(流式)
|
||||
|
||||
Args:
|
||||
workspace_id:
|
||||
app_id: 应用 ID
|
||||
payload: 请求对象(包含 message, variables, conversation_id 等)
|
||||
config: 存储类型(可选)
|
||||
@@ -601,7 +609,7 @@ class WorkflowService:
|
||||
workflow_config_id=config.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=triggered_by_uuid,
|
||||
triggered_by=None,
|
||||
conversation_id=conversation_id_uuid,
|
||||
input_data=input_data
|
||||
)
|
||||
@@ -621,14 +629,6 @@ class WorkflowService:
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
with get_db_context() as db:
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
other_id=payload.user_id,
|
||||
original_user_id=payload.user_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
|
||||
for exec_res in executions:
|
||||
@@ -638,17 +638,46 @@ class WorkflowService:
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
from app.core.workflow.executor import execute_workflow_stream
|
||||
|
||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
||||
async for event in self._run_workflow_stream(
|
||||
async for event in execute_workflow_stream(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=end_user_id
|
||||
user_id=payload.user_id
|
||||
):
|
||||
# 直接转发 executor 的事件(已经是正确的格式)
|
||||
if event.get("event") == "workflow_end":
|
||||
|
||||
status = event.get("data", {}).get("status")
|
||||
if status == "completed":
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=event.get("data")
|
||||
)
|
||||
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
|
||||
for message in final_messages:
|
||||
message_obj = Message(
|
||||
conversation_id=conversation_id_uuid,
|
||||
role=message["role"],
|
||||
content=message["content"],
|
||||
)
|
||||
self.message_repo.add_message(message_obj)
|
||||
self.db.commit()
|
||||
logger.info(f"Workflow Run Success, "
|
||||
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||
elif status == "failed":
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
output_data=event.get("data")
|
||||
)
|
||||
else:
|
||||
logger.error(f"unexpect workflow run status, status: {status}")
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
@@ -667,6 +696,8 @@ class WorkflowService:
|
||||
}
|
||||
}
|
||||
|
||||
@deprecated(reason="This method is deprecated. "
|
||||
"Please use WorkflowService.run / run_stream instead.")
|
||||
async def run_workflow(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
@@ -819,6 +850,7 @@ class WorkflowService:
|
||||
|
||||
return clean_value(event)
|
||||
|
||||
@deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.")
|
||||
async def _run_workflow_stream(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
|
||||
@@ -8,9 +8,11 @@ import uuid
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
from app.db import get_db_read
|
||||
from app.models import AppRelease, WorkflowConfig
|
||||
from app.models.agent_app_config_model import AgentConfig
|
||||
from app.models.multi_agent_model import MultiAgentConfig
|
||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||
|
||||
|
||||
def model_parameters_to_dict(model_parameters: Any) -> Optional[Dict[str, Any]]:
|
||||
@@ -24,18 +26,18 @@ def model_parameters_to_dict(model_parameters: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
if model_parameters is None:
|
||||
return None
|
||||
|
||||
|
||||
if isinstance(model_parameters, dict):
|
||||
return model_parameters
|
||||
|
||||
|
||||
# Pydantic v2
|
||||
if hasattr(model_parameters, 'model_dump'):
|
||||
return model_parameters.model_dump()
|
||||
|
||||
|
||||
# Pydantic v1
|
||||
if hasattr(model_parameters, 'dict'):
|
||||
return model_parameters.dict()
|
||||
|
||||
|
||||
# 其他情况尝试转换
|
||||
try:
|
||||
return dict(model_parameters)
|
||||
@@ -54,17 +56,18 @@ def dict_to_model_parameters(data: Optional[Dict[str, Any]]) -> Optional[Any]:
|
||||
"""
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
|
||||
from app.schemas import ModelParameters
|
||||
|
||||
|
||||
if isinstance(data, ModelParameters):
|
||||
return data
|
||||
|
||||
|
||||
if isinstance(data, dict):
|
||||
return ModelParameters(**data)
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class AgentConfigProxy:
|
||||
"""Proxy class for AgentConfig (legacy compatibility)"""
|
||||
|
||||
@@ -78,8 +81,7 @@ class AgentConfigProxy:
|
||||
self.default_model_config_id = release.default_model_config_id
|
||||
|
||||
|
||||
def agent_config_4_app_release(release: AppRelease ) -> AgentConfig:
|
||||
|
||||
def agent_config_4_app_release(release: AppRelease) -> AgentConfig:
|
||||
config_dict = release.config
|
||||
|
||||
agent_config = AgentConfig(
|
||||
@@ -95,18 +97,17 @@ def agent_config_4_app_release(release: AppRelease ) -> AgentConfig:
|
||||
|
||||
return agent_config
|
||||
|
||||
def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig:
|
||||
|
||||
def multi_agent_config_4_app_release(release: AppRelease) -> MultiAgentConfig:
|
||||
config_dict = release.config
|
||||
|
||||
|
||||
agent_config = MultiAgentConfig(
|
||||
app_id=release.app_id,
|
||||
default_model_config_id=release.default_model_config_id,
|
||||
model_parameters=config_dict.get("model_parameters"),
|
||||
master_agent_id=config_dict.get("master_agent_id"),
|
||||
master_agent_name=config_dict.get("master_agent_name"),
|
||||
orchestration_mode=config_dict.get("orchestration_mode", "conditional"),
|
||||
orchestration_mode=config_dict.get("orchestration_mode", "supervisor"),
|
||||
sub_agents=config_dict.get("sub_agents", []),
|
||||
routing_rules=config_dict.get("routing_rules"),
|
||||
execution_config=config_dict.get("execution_config", {}),
|
||||
@@ -116,24 +117,26 @@ def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig:
|
||||
|
||||
return agent_config
|
||||
|
||||
def workflow_config_4_app_release(release: AppRelease ) -> WorkflowConfig:
|
||||
|
||||
def workflow_config_4_app_release(release: AppRelease) -> WorkflowConfig:
|
||||
config_dict = release.config
|
||||
|
||||
with get_db_read() as db:
|
||||
source_config = WorkflowConfigRepository(db).get_by_app_id(release.app_id)
|
||||
source_config_id = source_config.id
|
||||
|
||||
config = WorkflowConfig(
|
||||
id=release.id,
|
||||
id=source_config_id,
|
||||
app_id=release.app_id,
|
||||
nodes=config_dict.get("nodes", []),
|
||||
edges=config_dict.get("edges", []),
|
||||
variables=config_dict.get("variables", []),
|
||||
execution_config=config_dict.get("execution_config", {}),
|
||||
triggers=config_dict.get("triggers", [])
|
||||
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uuid.UUID] = None):
|
||||
"""Convert dict to MultiAgentConfig model object
|
||||
|
||||
@@ -149,7 +152,7 @@ def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uui
|
||||
... "app_id": "uuid-here",
|
||||
... "master_agent_id": "master-uuid",
|
||||
... "master_agent_name": "Master Agent",
|
||||
... "orchestration_mode": "conditional",
|
||||
... "orchestration_mode": "supervisor",
|
||||
... "sub_agents": [
|
||||
... {"agent_id": "sub1-uuid", "name": "Sub Agent 1", "role": "specialist", "priority": 1},
|
||||
... {"agent_id": "sub2-uuid", "name": "Sub Agent 2", "role": "specialist", "priority": 2}
|
||||
@@ -186,7 +189,7 @@ def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uui
|
||||
app_id=final_app_id,
|
||||
master_agent_id=master_agent_id,
|
||||
master_agent_name=config_dict.get("master_agent_name"),
|
||||
orchestration_mode=config_dict.get("orchestration_mode", "conditional"),
|
||||
orchestration_mode=config_dict.get("orchestration_mode", "supervisor"),
|
||||
sub_agents=config_dict.get("sub_agents", []),
|
||||
routing_rules=config_dict.get("routing_rules"),
|
||||
execution_config=config_dict.get("execution_config", {}),
|
||||
@@ -276,7 +279,8 @@ def agent_config_to_dict(agent_config) -> Dict[str, Any]:
|
||||
"id": str(agent_config.id),
|
||||
"app_id": str(agent_config.app_id),
|
||||
"system_prompt": agent_config.system_prompt,
|
||||
"default_model_config_id": str(agent_config.default_model_config_id) if agent_config.default_model_config_id else None,
|
||||
"default_model_config_id": str(
|
||||
agent_config.default_model_config_id) if agent_config.default_model_config_id else None,
|
||||
"model_parameters": agent_config.model_parameters,
|
||||
"knowledge_retrieval": agent_config.knowledge_retrieval,
|
||||
"memory": agent_config.memory,
|
||||
@@ -338,6 +342,3 @@ def workflow_config_to_dict(workflow_config) -> Dict[str, Any]:
|
||||
"created_at": workflow_config.created_at.isoformat() if workflow_config.created_at else None,
|
||||
"updated_at": workflow_config.updated_at.isoformat() if workflow_config.updated_at else None
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -136,7 +136,8 @@ dependencies = [
|
||||
"markdown-to-json==2.1.1",
|
||||
"valkey==6.0.2",
|
||||
"python-calamine>=0.4.0",
|
||||
"xlrd==2.0.2"
|
||||
"xlrd==2.0.2",
|
||||
"deprecated>=1.3.1",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
BIN
web/src/assets/images/empty/chatEmpty.png
Normal file
BIN
web/src/assets/images/empty/chatEmpty.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 185 KiB |
14
web/src/assets/images/menu/helpCenter.svg
Normal file
14
web/src/assets/images/menu/helpCenter.svg
Normal file
@@ -0,0 +1,14 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>使用帮助备份</title>
|
||||
<g id="v0.2.0" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<g id="首页" transform="translate(-51, -358)" stroke="#5F6266">
|
||||
<g id="使用帮助备份" transform="translate(51, 358)">
|
||||
<g id="编组-35" transform="translate(2, 1.5)">
|
||||
<path d="M6.13163525,1.97938144 L10.3064533,1.97938144 C11.2417733,1.97938144 12,2.70634106 12,3.6030912 L12,10.3762902 C12,11.2730404 11.2417733,12 10.3064533,12 L1.69354673,12 C0.758226699,12 0,11.2730404 0,10.3762902 L0,3.6030912 C0,2.70634106 0.758226699,1.97938144 1.69354673,1.97938144 L2.02448435,1.97938144 L2.02448435,1.97938144" id="路径"></path>
|
||||
<path d="M3.52033177,0.78470905 L6.09032258,1.97938144 L6.09032258,1.97938144 L6.09032258,11.8762887 L2.51918436,10.2162282 C2.10022604,10.0214734 1.83225806,9.6014016 1.83225806,9.13938916 L1.83225806,1.86154804 C1.83225806,1.2057099 2.36391992,0.674048044 3.01975806,0.674048044 C3.19268295,0.674048044 3.36352144,0.711815028 3.52033177,0.78470905 Z" id="矩形" stroke-linejoin="round"></path>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
14
web/src/assets/images/menu/helpCenter_active.svg
Normal file
14
web/src/assets/images/menu/helpCenter_active.svg
Normal file
@@ -0,0 +1,14 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>使用帮助</title>
|
||||
<g id="v0.2.0" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<g id="首页" transform="translate(-24, -358)" stroke="#212332">
|
||||
<g id="使用帮助" transform="translate(24, 358)">
|
||||
<g id="编组-35" transform="translate(2, 1.5)">
|
||||
<path d="M6.13163525,1.97938144 L10.3064533,1.97938144 C11.2417733,1.97938144 12,2.70634106 12,3.6030912 L12,10.3762902 C12,11.2730404 11.2417733,12 10.3064533,12 L1.69354673,12 C0.758226699,12 0,11.2730404 0,10.3762902 L0,3.6030912 C0,2.70634106 0.758226699,1.97938144 1.69354673,1.97938144 L2.02448435,1.97938144 L2.02448435,1.97938144" id="路径"></path>
|
||||
<path d="M3.52033177,0.78470905 L6.09032258,1.97938144 L6.09032258,1.97938144 L6.09032258,11.8762887 L2.51918436,10.2162282 C2.10022604,10.0214734 1.83225806,9.6014016 1.83225806,9.13938916 L1.83225806,1.86154804 C1.83225806,1.2057099 2.36391992,0.674048044 3.01975806,0.674048044 C3.19268295,0.674048044 3.36352144,0.711815028 3.52033177,0.78470905 Z" id="矩形" stroke-linejoin="round"></path>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
@@ -55,7 +55,7 @@ const ChatContent: FC<ChatContentProps> = ({
|
||||
</div>
|
||||
}
|
||||
{/* 消息气泡框 */}
|
||||
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-100 rb:wrap-break-word', contentClassNames, {
|
||||
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-[520px] rb:wrap-break-word', contentClassNames, {
|
||||
// 错误消息样式(内容为null且非助手消息)
|
||||
'rb:border-[rgba(255,93,52,0.30)] rb:bg-[rgba(255,93,52,0.08)] rb:text-[#FF5D34]': errorDesc && item.role === 'assistant' && item.content === null,
|
||||
// 助手消息样式
|
||||
@@ -68,7 +68,7 @@ const ChatContent: FC<ChatContentProps> = ({
|
||||
</div>
|
||||
{/* 底部标签(如时间戳、用户名等) */}
|
||||
{labelPosition === 'bottom' &&
|
||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4 rb:font-regular">
|
||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4 rb:font-regular rb:mt-2">
|
||||
{labelFormat(item)}
|
||||
</div>
|
||||
}
|
||||
|
||||
@@ -71,6 +71,7 @@ export const en = {
|
||||
stepTwoDescription: 'Here you can create and manage spaces to organize models and data for different use cases.Once your spaces are ready, head to User Management to invite members and manage access.👉 Click User Management in the left menu to continue.',
|
||||
stepThree: 'This is User Management',
|
||||
stepThreeDescription: 'Here you can create users, assign roles, and manage access for your team.Once users are set up, the basic configuration is complete and you’re ready to start using the platform 🎉',
|
||||
finishButtonText: 'Get Started',
|
||||
},
|
||||
menu: {
|
||||
home: 'Home',
|
||||
@@ -91,6 +92,7 @@ export const en = {
|
||||
memberManagement: 'Member Management',
|
||||
memorySummary: 'Memory Summary',
|
||||
memoryConversation: 'Memory Validation',
|
||||
helpCenter: 'Help Center',
|
||||
memorySummaryHandlers: 'Memory Summary Handlers',
|
||||
createMemorySummary: 'Create Memory Summary',
|
||||
memoryManagement: 'Memory Management',
|
||||
@@ -183,14 +185,15 @@ export const en = {
|
||||
createNewMemorySummary: 'Create New Memory Entry',
|
||||
|
||||
createNewApplication: 'Create New Application',
|
||||
createNewApplicationDesc: 'Create a new application for this space',
|
||||
createNewApplicationDesc: 'Build an app in just 3 minutes with zero-code drag-and-drop.',
|
||||
|
||||
createNewKnowledge: 'Create New Knowledge',
|
||||
createNewKnowledgeDesc: 'Create a new memory entry',
|
||||
createNewKnowledgeDesc: 'Transform your data into a fully searchable, dedicated knowledge base in seconds.',
|
||||
|
||||
memoryConversation: 'Memory Conversation',
|
||||
memoryConversationDesc: 'Create a new memory conversation',
|
||||
|
||||
memoryConversationDesc: 'The more you use it, the better AI understands you.',
|
||||
helpCenter: 'Help Center',
|
||||
helpCenterDesc: 'One-stop support to answer your questions and get you started fast.',
|
||||
memorySummary: 'View Memory Summary',
|
||||
memorySummaryDesc: 'View Memory Summary Report',
|
||||
|
||||
@@ -618,6 +621,7 @@ export const en = {
|
||||
retrieve:'Retrieve',
|
||||
processing: 'Processing',
|
||||
processingMode: 'Processing Mode',
|
||||
processMsg: 'Processing Message',
|
||||
dataSize: 'Data Size',
|
||||
createUpdateTime: 'Create/Update Time',
|
||||
operation: 'Operation',
|
||||
@@ -1449,6 +1453,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
},
|
||||
memoryConversation: {
|
||||
searchPlaceholder: 'Enter user ID...',
|
||||
chatEmpty:'Is there anything I can help you with?',
|
||||
userID: 'User ID',
|
||||
testMemoryConversation: 'Test Memory Conversation',
|
||||
conversationContent: 'Conversation Content',
|
||||
|
||||
@@ -71,6 +71,7 @@ export const zh = {
|
||||
stepTwoDescription: '你可以在这里创建和管理不同的空间,把模型和数据组织到具体的使用场景中。空间创建完成后,可以去 User Management 邀请成员、分配权限,一起协作使用。👉 点击左侧 User Management 继续。',
|
||||
stepThree: '这里是用户管理页',
|
||||
stepThreeDescription: '你可以在这里创建用户、分配角色,并管理团队成员的访问权限。完成用户设置后,基础配置就准备好了,可以开始实际使用平台的各项功能了 🎉',
|
||||
finishButtonText: '开始使用',
|
||||
},
|
||||
menu: {
|
||||
home: '首页',
|
||||
@@ -782,14 +783,15 @@ export const zh = {
|
||||
createNewMemorySummary: '创建新记忆条目',
|
||||
|
||||
createNewApplication: '创建新应用',
|
||||
createNewApplicationDesc: '创建新空间应用',
|
||||
createNewApplicationDesc: '零代码拖拽3分钟创应用',
|
||||
|
||||
createNewKnowledge: '创建新知识',
|
||||
createNewKnowledgeDesc: '创建新记忆条目',
|
||||
createNewKnowledge: '创建知识库',
|
||||
createNewKnowledgeDesc: '秒变可搜索的专属知识库',
|
||||
|
||||
memoryConversation: '记忆对话',
|
||||
memoryConversationDesc: '记忆对话',
|
||||
|
||||
memoryConversationDesc: '让AI越用越懂你',
|
||||
helpCenter: '帮助中心',
|
||||
helpCenterDesc: '一站式解决疑问快速上手',
|
||||
memorySummary: '查看记忆摘要',
|
||||
memorySummaryDesc: '查看记忆摘要报告',
|
||||
|
||||
@@ -1524,6 +1526,7 @@ export const zh = {
|
||||
deduplication_desc: '去重消歧完成,最终{{count}}个唯一实体'
|
||||
},
|
||||
memoryConversation: {
|
||||
chatEmpty:'有什么我可以帮您的吗?',
|
||||
searchPlaceholder: '输入用户ID...',
|
||||
userID: '用户ID',
|
||||
testMemoryConversation: '测试记忆对话',
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
/*
|
||||
* @Description:
|
||||
* @Version: 0.0.1
|
||||
* @Author: yujiangping
|
||||
* @Date: 2026-01-05 17:22:23
|
||||
* @LastEditors: yujiangping
|
||||
* @LastEditTime: 2026-01-15 21:02:43
|
||||
*/
|
||||
import { create } from 'zustand'
|
||||
import enUS from 'antd/locale/en_US';
|
||||
import zhCN from 'antd/locale/zh_CN';
|
||||
@@ -12,6 +20,28 @@ import { timezoneToAntdLocaleMap } from '@/utils/timezones';
|
||||
dayjs.extend(utc);
|
||||
dayjs.extend(timezone);
|
||||
|
||||
// 自定义中文 locale,修改 Tour 组件的按钮文字
|
||||
const customZhCN: Locale = {
|
||||
...zhCN,
|
||||
Tour: {
|
||||
...zhCN.Tour,
|
||||
Next: '下一步',
|
||||
Previous: '上一步',
|
||||
Finish: '立即体验',
|
||||
},
|
||||
};
|
||||
|
||||
// 自定义英文 locale,修改 Tour 组件的按钮文字
|
||||
const customEnUS: Locale = {
|
||||
...enUS,
|
||||
Tour: {
|
||||
...enUS.Tour,
|
||||
Next: 'Next',
|
||||
Previous: 'Previous',
|
||||
Finish: 'Try it now',
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
interface I18nState {
|
||||
language: string;
|
||||
@@ -23,7 +53,7 @@ interface I18nState {
|
||||
|
||||
const initialTimeZone = localStorage.getItem('timeZone') || 'Asia/Shanghai'
|
||||
const initialLanguage = localStorage.getItem('language') || 'en'
|
||||
const initialLocale = initialLanguage === 'en' ? enUS : zhCN
|
||||
const initialLocale = initialLanguage === 'en' ? customEnUS : customZhCN
|
||||
i18n.changeLanguage(initialLanguage)
|
||||
|
||||
export const useI18n = create<I18nState>((set, get) => ({
|
||||
@@ -32,7 +62,7 @@ export const useI18n = create<I18nState>((set, get) => ({
|
||||
timeZone: initialTimeZone,
|
||||
changeLanguage: (language: string) => {
|
||||
i18n.changeLanguage(language)
|
||||
const localeName = timezoneToAntdLocaleMap[language] || enUS;
|
||||
const localeName = language === 'en' ? customEnUS : customZhCN;
|
||||
set({ language: language, locale: localeName })
|
||||
},
|
||||
changeTimeZone: (timeZone: string) => {
|
||||
|
||||
@@ -11,6 +11,7 @@ import Empty from '@/components/Empty'
|
||||
import { formatDateTime } from '@/utils/format';
|
||||
import { randomString } from '@/utils/common'
|
||||
import BgImg from '@/assets/images/conversation/bg.png'
|
||||
import ChatEmpty from '@/assets/images/empty/chatEmpty.png'
|
||||
import Chat from '@/components/Chat'
|
||||
import type { ChatItem } from '@/components/Chat/types'
|
||||
import ButtonCheckbox from '@/components/ButtonCheckbox'
|
||||
@@ -259,9 +260,10 @@ const Conversation: FC = () => {
|
||||
</div>
|
||||
|
||||
<div className="rb:relative rb:h-screen rb:px-4 rb:flex-[1_1_auto]">
|
||||
<div className='rb:w-[760px] rb:h-screen rb:mx-auto rb:pt-10'>
|
||||
<Chat
|
||||
empty={<Empty url={AnalysisEmptyIcon} className="rb:h-full" subTitle={t('memoryConversation.emptyDesc')} />}
|
||||
contentClassName="rb:h-[calc(100%-152px)]"
|
||||
empty={<Empty url={ChatEmpty} className="rb:h-full" size={[320,180]} title={t('memoryConversation.chatEmpty')} subTitle={t('memoryConversation.emptyDesc')} />}
|
||||
contentClassName="rb:h-[calc(100%-152px)] "
|
||||
data={chatList}
|
||||
streamLoading={streamLoading}
|
||||
loading={loading}
|
||||
@@ -290,6 +292,7 @@ const Conversation: FC = () => {
|
||||
</Flex>
|
||||
</Form>
|
||||
</Chat>
|
||||
</div>
|
||||
</div>
|
||||
</Flex>
|
||||
)
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
/*
|
||||
* @Description:
|
||||
* @Version: 0.0.1
|
||||
* @Author: yujiangping
|
||||
* @Date: 2026-01-05 17:22:23
|
||||
* @LastEditors: yujiangping
|
||||
* @LastEditTime: 2026-01-15 14:55:51
|
||||
*/
|
||||
import { type FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
@@ -5,33 +13,49 @@ import Card from './Card';
|
||||
import applicationIcon from '@/assets/images/menu/application_active.svg';
|
||||
import knowledgeIcon from '@/assets/images/menu/knowledge_active.svg';
|
||||
import memoryConversationIcon from '@/assets/images/menu/memoryConversation_active.svg';
|
||||
import helpCenterIcon from '@/assets/images/menu/helpCenter_active.svg'
|
||||
import arrowTopRight from '@/assets/images/home/arrow_top_right.svg';
|
||||
|
||||
const quickOperations = [
|
||||
{ key: 'createNewApplication', url: '/application' },
|
||||
{ key: 'createNewKnowledge', url: '/knowledge-base' },
|
||||
{ key: 'memoryConversation', url: '/memory-conversation' },
|
||||
{ key: 'helpCenter', url: '' },
|
||||
]
|
||||
|
||||
const quickOperationIcons: {[key: string]: string | undefined} = {
|
||||
createNewApplication: applicationIcon,
|
||||
createNewKnowledge: knowledgeIcon,
|
||||
memoryConversation: memoryConversationIcon,
|
||||
helpCenter: helpCenterIcon
|
||||
}
|
||||
const QuickOperation:FC = () => {
|
||||
const { t } = useTranslation()
|
||||
const { t, i18n } = useTranslation()
|
||||
const navigate = useNavigate();
|
||||
|
||||
const handleJump = (url: string | null) => {
|
||||
if (url) {
|
||||
navigate(url)
|
||||
}else{
|
||||
const currentLang = i18n.language;
|
||||
const lang = currentLang === 'zh' ? 'zh' : 'en';
|
||||
const helpUrl = `https://docs.redbearai.com/s/${lang}-memorybear`;
|
||||
|
||||
// 创建隐藏的 a 标签来避免弹窗拦截
|
||||
const link = document.createElement('a');
|
||||
link.href = helpUrl;
|
||||
link.target = '_blank';
|
||||
link.rel = 'noopener noreferrer';
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
}
|
||||
}
|
||||
return (
|
||||
<Card
|
||||
title={t('dashboard.quickOperation')}
|
||||
>
|
||||
<div className="rb:grid rb:grid-cols-3 rb:gap-[16px]">
|
||||
<div className="rb:grid rb:grid-cols-4 rb:gap-[16px]">
|
||||
{quickOperations.map(item => (
|
||||
<div key={item.key} className="rb:rounded-[8px] rb:p-[20px_16px] rb:border-1 rb:border-[#DFE4ED] rb:cursor-pointer rb:hover:border-[#155EEF]" onClick={() => handleJump(item.url)}>
|
||||
<div className="rb:flex rb:justify-between">
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
/*
|
||||
* @Description:
|
||||
* @Version: 0.0.1
|
||||
* @Author: yujiangping
|
||||
* @Date: 2026-01-13 11:44:06
|
||||
* @LastEditors: yujiangping
|
||||
* @LastEditTime: 2026-01-15 20:59:57
|
||||
*/
|
||||
import React, { useState, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
|
||||
@@ -47,7 +47,7 @@ const QuickActions: FC<QuickActionsProps> = ({ onNavigate }) => {
|
||||
key: 'space-management',
|
||||
icon: spaceIcon,
|
||||
title: t('quickActions.spaceManagement'),
|
||||
onClick: () => onNavigate?.('/spce')
|
||||
onClick: () => onNavigate?.('/space')
|
||||
},
|
||||
// {
|
||||
// key: 'workflow-orchestration',
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import { useEffect, useState, useRef, useCallback, type FC } from 'react';
|
||||
import { useNavigate, useParams, useLocation } from 'react-router-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Switch, Button, Dropdown, Space, Modal, message, Radio } from 'antd';
|
||||
import { Switch, Button, Dropdown, Space, Modal, message, Radio, Tooltip } from 'antd';
|
||||
import type { MenuProps } from 'antd';
|
||||
import SearchInput from '@/components/SearchInput'
|
||||
import Table, { type TableRef } from '@/components/Table'
|
||||
@@ -564,6 +564,37 @@ const Private: FC = () => {
|
||||
</span>
|
||||
);
|
||||
}
|
||||
},{
|
||||
title: t('knowledgeBase.processMsg'),
|
||||
dataIndex: 'progress_msg',
|
||||
key: 'progress_msg',
|
||||
width: 320,
|
||||
render: (value: string) => {
|
||||
if (!value) return '-';
|
||||
|
||||
// 解析日志格式,将 \n 转换为换行
|
||||
const formattedText = value.replace(/\\n/g, '\n');
|
||||
|
||||
return (
|
||||
<Tooltip title={<pre style={{ margin: 0, whiteSpace: 'pre-wrap' }}>{formattedText}</pre>} placement="topLeft">
|
||||
<div
|
||||
style={{
|
||||
maxWidth: '320px',
|
||||
overflow: 'hidden',
|
||||
textOverflow: 'ellipsis',
|
||||
display: '-webkit-box',
|
||||
WebkitLineClamp: 2,
|
||||
WebkitBoxOrient: 'vertical',
|
||||
lineHeight: '1.5',
|
||||
whiteSpace: 'pre-wrap',
|
||||
wordBreak: 'break-word'
|
||||
}}
|
||||
>
|
||||
{formattedText}
|
||||
</div>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: t('knowledgeBase.processingMode'),
|
||||
|
||||
@@ -292,7 +292,7 @@ const KnowledgeGraph: FC<KnowledgeGraphProps> = ({ data, loading = false }) => {
|
||||
if (params.dataType === 'node') {
|
||||
const node = params.data as KnowledgeNode
|
||||
return `
|
||||
<div>
|
||||
<div class="rb:max-w-[560px]">
|
||||
<div><strong>${node.entity_name}</strong></div>
|
||||
<div>类型: ${node.entity_type}</div>
|
||||
<div>重要度: ${(node.pagerank * 100).toFixed(2)}%</div>
|
||||
@@ -301,10 +301,10 @@ const KnowledgeGraph: FC<KnowledgeGraphProps> = ({ data, loading = false }) => {
|
||||
} else if (params.dataType === 'edge') {
|
||||
const edge = params.data as KnowledgeEdge
|
||||
return `
|
||||
<div>
|
||||
<div class="rb:max-w-[560px]">
|
||||
<div><strong>关系</strong></div>
|
||||
<div>权重: ${edge.weight}</div>
|
||||
<div>${edge.description}</div>
|
||||
<div class="rb:break-words rb:whitespace-pre-wrap">${edge.description}</div>
|
||||
</div>
|
||||
`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user