Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
This commit is contained in:
@@ -728,9 +728,23 @@ async def draft_run_compare(
|
|||||||
from app.core.exceptions import ResourceNotFoundException
|
from app.core.exceptions import ResourceNotFoundException
|
||||||
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
|
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
|
||||||
|
|
||||||
|
# 获取 agent_cfg.model_parameters,如果是 ModelParameters 对象则转为字典
|
||||||
|
agent_model_params = agent_cfg.model_parameters
|
||||||
|
if hasattr(agent_model_params, 'model_dump'):
|
||||||
|
agent_model_params = agent_model_params.model_dump()
|
||||||
|
elif not isinstance(agent_model_params, dict):
|
||||||
|
agent_model_params = {}
|
||||||
|
|
||||||
|
# 获取 model_item.model_parameters,如果是 ModelParameters 对象则转为字典
|
||||||
|
item_model_params = model_item.model_parameters
|
||||||
|
if hasattr(item_model_params, 'model_dump'):
|
||||||
|
item_model_params = item_model_params.model_dump()
|
||||||
|
elif not isinstance(item_model_params, dict):
|
||||||
|
item_model_params = {}
|
||||||
|
|
||||||
merged_parameters = {
|
merged_parameters = {
|
||||||
**(agent_cfg.model_parameters or {}),
|
**(agent_model_params or {}),
|
||||||
**(model_item.model_parameters or {})
|
**(item_model_params or {})
|
||||||
}
|
}
|
||||||
|
|
||||||
model_configs.append({
|
model_configs.append({
|
||||||
|
|||||||
@@ -108,16 +108,23 @@ async def get_prompt_opt(
|
|||||||
service = PromptOptimizerService(db)
|
service = PromptOptimizerService(db)
|
||||||
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
async for chunk in service.optimize_prompt(
|
yield "event:start\ndata: {}\n\n"
|
||||||
tenant_id=current_user.tenant_id,
|
try:
|
||||||
model_id=data.model_id,
|
async for chunk in service.optimize_prompt(
|
||||||
session_id=session_id,
|
tenant_id=current_user.tenant_id,
|
||||||
user_id=current_user.id,
|
model_id=data.model_id,
|
||||||
current_prompt=data.current_prompt,
|
session_id=session_id,
|
||||||
user_require=data.message
|
user_id=current_user.id,
|
||||||
):
|
current_prompt=data.current_prompt,
|
||||||
# chunk 是 prompt 的增量内容
|
user_require=data.message
|
||||||
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n"
|
):
|
||||||
|
# chunk 是 prompt 的增量内容
|
||||||
|
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
yield f"event:error\ndata: {json.dumps(
|
||||||
|
{"error": str(e)}
|
||||||
|
)}\n\n"
|
||||||
|
yield "event:end\ndata: {}\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_generator(),
|
event_generator(),
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi import APIRouter, Depends, Query, Request
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
@@ -18,7 +19,7 @@ from app.services.conversation_service import ConversationService
|
|||||||
from app.services.release_share_service import ReleaseShareService
|
from app.services.release_share_service import ReleaseShareService
|
||||||
from app.services.shared_chat_service import SharedChatService
|
from app.services.shared_chat_service import SharedChatService
|
||||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||||
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release
|
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||||
|
|
||||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -288,7 +289,7 @@ async def chat(
|
|||||||
password = None # Token 认证不需要密码
|
password = None # Token 认证不需要密码
|
||||||
# end_user_id = user_id
|
# end_user_id = user_id
|
||||||
other_id = user_id
|
other_id = user_id
|
||||||
|
|
||||||
# 提前验证和准备(在流式响应开始前完成)
|
# 提前验证和准备(在流式响应开始前完成)
|
||||||
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
|
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
|
||||||
from app.models.app_model import AppType
|
from app.models.app_model import AppType
|
||||||
@@ -364,6 +365,9 @@ async def chat(
|
|||||||
config = release.config or {}
|
config = release.config or {}
|
||||||
if not config.get("sub_agents"):
|
if not config.get("sub_agents"):
|
||||||
raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
elif app_type == AppType.WORKFLOW:
|
||||||
|
# Multi-Agent 类型:验证多 Agent 配置
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|
||||||
@@ -392,6 +396,8 @@ async def chat(
|
|||||||
|
|
||||||
if app_type == AppType.AGENT:
|
if app_type == AppType.AGENT:
|
||||||
# 流式返回
|
# 流式返回
|
||||||
|
agent_config = agent_config_4_app_release(release)
|
||||||
|
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
# async def event_generator():
|
# async def event_generator():
|
||||||
# async for event in service.chat_stream(
|
# async for event in service.chat_stream(
|
||||||
@@ -424,7 +430,7 @@ async def chat(
|
|||||||
user_id= str(new_end_user.id), # 转换为字符串
|
user_id= str(new_end_user.id), # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
web_search=payload.web_search,
|
web_search=payload.web_search,
|
||||||
config=payload.agent_config,
|
config=agent_config,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
@@ -467,6 +473,7 @@ async def chat(
|
|||||||
)
|
)
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
|
# config = workflow_config_4_app_release(release)
|
||||||
config = multi_agent_config_4_app_release(release)
|
config = multi_agent_config_4_app_release(release)
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
@@ -551,8 +558,71 @@ async def chat(
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
# return success(data=conversation_schema.ChatResponse(**result))
|
# return success(data=conversation_schema.ChatResponse(**result))
|
||||||
|
elif app_type == AppType.WORKFLOW:
|
||||||
|
|
||||||
|
config = workflow_config_4_app_release(release)
|
||||||
|
if payload.stream:
|
||||||
|
async def event_generator():
|
||||||
|
async for event in app_chat_service.workflow_chat_stream(
|
||||||
|
|
||||||
|
message=payload.message,
|
||||||
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
user_id=new_end_user.id, # 转换为字符串
|
||||||
|
variables=payload.variables,
|
||||||
|
config=config,
|
||||||
|
web_search=payload.web_search,
|
||||||
|
memory=payload.memory,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
app_id=release.app_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
):
|
||||||
|
event_type = event.get("event", "message")
|
||||||
|
event_data = event.get("data", {})
|
||||||
|
|
||||||
|
# 转换为标准 SSE 格式(字符串)
|
||||||
|
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||||
|
yield sse_message
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 多 Agent 非流式返回
|
||||||
|
result = await app_chat_service.workflow_chat(
|
||||||
|
|
||||||
|
message=payload.message,
|
||||||
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
user_id=new_end_user.id, # 转换为字符串
|
||||||
|
variables=payload.variables,
|
||||||
|
config=config,
|
||||||
|
web_search=payload.web_search,
|
||||||
|
memory=payload.memory,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
app_id=release.app_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"工作流试运行返回结果",
|
||||||
|
extra={
|
||||||
|
"result_type": str(type(result)),
|
||||||
|
"has_response": "response" in result if isinstance(result, dict) else False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return success(
|
||||||
|
data=result,
|
||||||
|
msg="工作流任务执行成功"
|
||||||
|
)
|
||||||
|
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
pass
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""App 服务接口 - 基于 API Key 认证"""
|
"""App 服务接口 - 基于 API Key 认证"""
|
||||||
|
import json
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request, Body
|
from fastapi import APIRouter, Depends, Request, Body
|
||||||
@@ -21,7 +22,7 @@ from app.schemas.api_key_schema import ApiKeyAuth
|
|||||||
from app.services import workspace_service
|
from app.services import workspace_service
|
||||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||||
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release
|
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||||
from app.services.app_service import get_app_service, AppService
|
from app.services.app_service import get_app_service, AppService
|
||||||
|
|
||||||
router = APIRouter(prefix="/app", tags=["V1 - App API"])
|
router = APIRouter(prefix="/app", tags=["V1 - App API"])
|
||||||
@@ -226,22 +227,29 @@ async def chat(
|
|||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.WORKFLOW:
|
elif app_type == AppType.WORKFLOW:
|
||||||
# 多 Agent 流式返回
|
# 多 Agent 流式返回
|
||||||
config = dict_to_workflow_config(app.current_release.config,app.id)
|
config = workflow_config_4_app_release(app.current_release)
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
async for event in app_chat_service.workflow_chat_stream(
|
async for event in app_chat_service.workflow_chat_stream(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=end_user_id, # 转换为字符串
|
user_id=new_end_user.id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
config=config,
|
config=config,
|
||||||
web_search=web_search,
|
web_search=payload.web_search,
|
||||||
memory=memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
app_id=app.app_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
):
|
):
|
||||||
yield event
|
event_type = event.get("event", "message")
|
||||||
|
event_data = event.get("data", {})
|
||||||
|
|
||||||
|
# 转换为标准 SSE 格式(字符串)
|
||||||
|
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||||
|
yield sse_message
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_generator(),
|
event_generator(),
|
||||||
@@ -253,21 +261,32 @@ async def chat(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 非流式返回
|
# 多 Agent 非流式返回
|
||||||
result = await app_chat_service.workflow_chat(
|
result = await app_chat_service.workflow_chat(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=end_user_id, # 转换为字符串
|
user_id=new_end_user.id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
config=config,
|
config=config,
|
||||||
web_search=web_search,
|
web_search=payload.web_search,
|
||||||
memory=memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
app_id=app.app_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"工作流试运行返回结果",
|
||||||
|
extra={
|
||||||
|
"result_type": str(type(result)),
|
||||||
|
"has_response": "response" in result if isinstance(result, dict) else False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return success(
|
||||||
|
data=result,
|
||||||
|
msg="工作流任务执行成功"
|
||||||
)
|
)
|
||||||
|
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
|
||||||
else:
|
else:
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ from app.db import get_db
|
|||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success, fail
|
from app.core.response_utils import success, fail
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.api_key_utils import timestamp_to_datetime
|
||||||
from app.services.user_memory_service import (
|
from app.services.user_memory_service import (
|
||||||
UserMemoryService,
|
UserMemoryService,
|
||||||
analytics_node_statistics,
|
|
||||||
analytics_memory_types,
|
analytics_memory_types,
|
||||||
analytics_graph_data,
|
analytics_graph_data,
|
||||||
)
|
)
|
||||||
@@ -41,24 +41,27 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||||
async def get_memory_insight_report_api(
|
async def get_memory_insight_report_api(
|
||||||
end_user_id: str, # 使用 end_user_id
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""获取缓存的记忆洞察报告"""
|
"""
|
||||||
api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}")
|
获取缓存的记忆洞察报告
|
||||||
|
|
||||||
|
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
|
||||||
|
如需生成新的洞察报告,请使用专门的生成接口。
|
||||||
|
"""
|
||||||
|
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||||
try:
|
try:
|
||||||
# 调用服务层获取缓存数据
|
# 调用服务层获取缓存数据
|
||||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
||||||
|
|
||||||
if result["is_cached"]:
|
if result["is_cached"]:
|
||||||
# 缓存存在,返回缓存数据
|
|
||||||
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="查询成功")
|
||||||
else:
|
else:
|
||||||
# 缓存不存在,返回提示消息
|
|
||||||
api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}")
|
api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}")
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="数据尚未生成")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}")
|
api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e))
|
||||||
@@ -66,24 +69,27 @@ async def get_memory_insight_report_api(
|
|||||||
|
|
||||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||||
async def get_user_summary_api(
|
async def get_user_summary_api(
|
||||||
end_user_id: str, # 使用 end_user_id
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""获取缓存的用户摘要"""
|
"""
|
||||||
api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, user={current_user.username}")
|
获取缓存的用户摘要
|
||||||
|
|
||||||
|
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
|
||||||
|
如需生成新的用户摘要,请使用专门的生成接口。
|
||||||
|
"""
|
||||||
|
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||||
try:
|
try:
|
||||||
# 调用服务层获取缓存数据
|
# 调用服务层获取缓存数据
|
||||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
|
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
|
||||||
|
|
||||||
if result["is_cached"]:
|
if result["is_cached"]:
|
||||||
# 缓存存在,返回缓存数据
|
|
||||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="查询成功")
|
||||||
else:
|
else:
|
||||||
# 缓存不存在,返回提示消息
|
|
||||||
api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}")
|
api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}")
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="数据尚未生成")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}")
|
api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e))
|
||||||
@@ -351,7 +357,7 @@ async def update_end_user_profile(
|
|||||||
if 'hire_date' in update_data:
|
if 'hire_date' in update_data:
|
||||||
hire_date_timestamp = update_data['hire_date']
|
hire_date_timestamp = update_data['hire_date']
|
||||||
if hire_date_timestamp is not None:
|
if hire_date_timestamp is not None:
|
||||||
update_data['hire_date'] = UserMemoryService.timestamp_to_datetime(hire_date_timestamp)
|
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
|
||||||
# 如果是 None,保持 None(允许清空)
|
# 如果是 None,保持 None(允许清空)
|
||||||
|
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
|
|||||||
@@ -5,19 +5,16 @@ This module provides analytics and insights for the memory system.
|
|||||||
|
|
||||||
Available functions:
|
Available functions:
|
||||||
- get_hot_memory_tags: Get hot memory tags by frequency
|
- get_hot_memory_tags: Get hot memory tags by frequency
|
||||||
- MemoryInsight: Generate memory insight reports
|
|
||||||
- get_recent_activity_stats: Get recent activity statistics
|
- get_recent_activity_stats: Get recent activity statistics
|
||||||
- generate_user_summary: Generate user summary
|
|
||||||
|
Note: MemoryInsight and generate_user_summary have been moved to
|
||||||
|
app.services.user_memory_service for better architecture.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
|
||||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_hot_memory_tags",
|
"get_hot_memory_tags",
|
||||||
"MemoryInsight",
|
|
||||||
"get_recent_activity_stats",
|
"get_recent_activity_stats",
|
||||||
"generate_user_summary",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,327 +0,0 @@
|
|||||||
"""
|
|
||||||
This module provides the MemoryInsight class for analyzing user memory data.
|
|
||||||
|
|
||||||
MemoryInsight 是一个工具类,提供基础的数据获取和分析功能:
|
|
||||||
- get_domain_distribution(): 获取记忆领域分布
|
|
||||||
- get_active_periods(): 获取活跃时段
|
|
||||||
- get_social_connections(): 获取社交关联
|
|
||||||
|
|
||||||
业务逻辑(如生成洞察报告)应该在服务层(user_memory_service.py)中实现。
|
|
||||||
|
|
||||||
This script can be executed directly to test the memory insight generation for a test user.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from collections import Counter
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
# To run this script directly, we need to add the src directory to the Python path
|
|
||||||
# to resolve the inconsistent imports in other modules.
|
|
||||||
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|
||||||
if src_path not in sys.path:
|
|
||||||
sys.path.insert(0, src_path)
|
|
||||||
|
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|
||||||
from app.db import get_db_context
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
#TODO: Fix this
|
|
||||||
|
|
||||||
# Default values (previously from definitions.py)
|
|
||||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
|
||||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
|
||||||
|
|
||||||
# 定义用于LLM结构化输出的Pydantic模型
|
|
||||||
class TagClassification(BaseModel):
|
|
||||||
"""
|
|
||||||
Represents the classification of a tag into a specific domain.
|
|
||||||
"""
|
|
||||||
|
|
||||||
domain: str = Field(
|
|
||||||
...,
|
|
||||||
description="The domain the tag belongs to, chosen from the predefined list.",
|
|
||||||
examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"],
|
|
||||||
)
|
|
||||||
|
|
||||||
class InsightReport(BaseModel):
|
|
||||||
"""
|
|
||||||
Represents the final insight report generated by the LLM.
|
|
||||||
"""
|
|
||||||
|
|
||||||
report: str = Field(
|
|
||||||
...,
|
|
||||||
description="A comprehensive insight report in Chinese, summarizing the user's memory patterns.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryInsight:
|
|
||||||
"""
|
|
||||||
Provides insights into user memories by analyzing various aspects of their data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, user_id: str):
|
|
||||||
self.user_id = user_id
|
|
||||||
self.neo4j_connector = Neo4jConnector()
|
|
||||||
|
|
||||||
# Get config_id using get_end_user_connected_config
|
|
||||||
with get_db_context() as db:
|
|
||||||
try:
|
|
||||||
from app.services.memory_agent_service import (
|
|
||||||
get_end_user_connected_config,
|
|
||||||
)
|
|
||||||
connected_config = get_end_user_connected_config(user_id, db)
|
|
||||||
config_id = connected_config.get("memory_config_id")
|
|
||||||
|
|
||||||
if config_id:
|
|
||||||
# Use the config_id to get the proper LLM client
|
|
||||||
config_service = MemoryConfigService(db)
|
|
||||||
memory_config = config_service.load_memory_config(config_id)
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
|
||||||
else:
|
|
||||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
|
||||||
# Fallback to default LLM if no config found
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
|
||||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
|
||||||
# Fallback to default LLM
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""关闭数据库连接。"""
|
|
||||||
await self.neo4j_connector.close()
|
|
||||||
|
|
||||||
async def get_domain_distribution(self) -> dict[str, float]:
|
|
||||||
"""
|
|
||||||
Calculates the distribution of memory domains based on hot tags.
|
|
||||||
"""
|
|
||||||
hot_tags = await get_hot_memory_tags(self.user_id)
|
|
||||||
if not hot_tags:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
domain_counts = Counter()
|
|
||||||
for tag, _ in hot_tags:
|
|
||||||
prompt = f"""请将以下标签归类到最合适的领域中。
|
|
||||||
|
|
||||||
可选领域及其关键词:
|
|
||||||
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
|
|
||||||
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
|
|
||||||
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
|
|
||||||
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
|
|
||||||
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
|
|
||||||
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
|
|
||||||
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
|
|
||||||
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
|
|
||||||
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
|
|
||||||
- 其他:确实无法归入以上任何类别的内容
|
|
||||||
|
|
||||||
标签: {tag}
|
|
||||||
|
|
||||||
分析步骤:
|
|
||||||
1. 仔细理解标签的核心含义和使用场景
|
|
||||||
2. 对比各个领域的关键词,找到最匹配的领域
|
|
||||||
3. 特别注意:
|
|
||||||
- 历史、科学、文化等知识性内容应归类为"学习"
|
|
||||||
- 学校、课程、考试等正式教育场景应归类为"教育"
|
|
||||||
- 只有在标签完全不属于上述9个具体领域时,才选择"其他"
|
|
||||||
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
|
|
||||||
|
|
||||||
请直接返回最合适的领域名称。"""
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景,优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
|
|
||||||
{"role": "user", "content": prompt}
|
|
||||||
]
|
|
||||||
# 直接调用并等待结果
|
|
||||||
classification = await self.llm_client.response_structured(
|
|
||||||
messages=messages,
|
|
||||||
response_model=TagClassification,
|
|
||||||
)
|
|
||||||
if classification and hasattr(classification, 'domain') and classification.domain:
|
|
||||||
domain_counts[classification.domain] += 1
|
|
||||||
|
|
||||||
total_tags = sum(domain_counts.values())
|
|
||||||
if total_tags == 0:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
domain_distribution = {
|
|
||||||
domain: count / total_tags for domain, count in domain_counts.items()
|
|
||||||
}
|
|
||||||
return dict(
|
|
||||||
sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_active_periods(self) -> list[int]:
|
|
||||||
"""
|
|
||||||
Identifies the top 2 most active months for the user.
|
|
||||||
Only returns months if there is valid and diverse time data.
|
|
||||||
|
|
||||||
This method checks if the time data represents real user memory timestamps
|
|
||||||
rather than auto-generated system timestamps by verifying:
|
|
||||||
1. Time data exists and is parseable
|
|
||||||
2. Time data is distributed across multiple months (not concentrated in 1-2 months)
|
|
||||||
"""
|
|
||||||
query = f"""
|
|
||||||
MATCH (d:Dialogue)
|
|
||||||
WHERE d.group_id = '{self.user_id}' AND d.created_at IS NOT NULL AND d.created_at <> ''
|
|
||||||
RETURN d.created_at AS creation_time
|
|
||||||
"""
|
|
||||||
records = await self.neo4j_connector.execute_query(query)
|
|
||||||
|
|
||||||
if not records:
|
|
||||||
return []
|
|
||||||
|
|
||||||
month_counts = Counter()
|
|
||||||
valid_dates_count = 0
|
|
||||||
for record in records:
|
|
||||||
creation_time_str = record.get("creation_time")
|
|
||||||
if not creation_time_str:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
# 尝试解析时间字符串
|
|
||||||
dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
|
|
||||||
month_counts[dt_object.month] += 1
|
|
||||||
valid_dates_count += 1
|
|
||||||
except (ValueError, TypeError, AttributeError):
|
|
||||||
# 如果解析失败,跳过这条记录
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 如果没有有效的时间数据,返回空列表
|
|
||||||
if not month_counts or valid_dates_count == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 检查时间分布是否过于集中(可能是批量导入的数据)
|
|
||||||
# 如果超过80%的数据集中在1-2个月,认为这是系统时间戳而非真实时间
|
|
||||||
unique_months = len(month_counts)
|
|
||||||
if unique_months <= 2:
|
|
||||||
# 只有1-2个月有数据,很可能是批量导入
|
|
||||||
most_common_count = month_counts.most_common(1)[0][1]
|
|
||||||
if most_common_count / valid_dates_count > 0.8:
|
|
||||||
# 超过80%集中在一个月,认为是系统时间戳
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 如果时间分布较为分散(3个月以上),认为是真实时间数据
|
|
||||||
if unique_months >= 3:
|
|
||||||
most_common_months = month_counts.most_common(2)
|
|
||||||
return [month for month, _ in most_common_months]
|
|
||||||
|
|
||||||
# 2个月的情况,检查是否分布均匀
|
|
||||||
if unique_months == 2:
|
|
||||||
counts = list(month_counts.values())
|
|
||||||
# 如果两个月的数据量相差不大(比例在0.3-3之间),认为是真实数据
|
|
||||||
ratio = min(counts) / max(counts)
|
|
||||||
if ratio > 0.3:
|
|
||||||
most_common_months = month_counts.most_common(2)
|
|
||||||
return [month for month, _ in most_common_months]
|
|
||||||
|
|
||||||
# 其他情况返回空列表
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def get_social_connections(self) -> dict | None:
|
|
||||||
"""
|
|
||||||
Finds the user with whom the most memories are shared.
|
|
||||||
使用 Chunk-Statement 的 CONTAINS 关系,因为系统中不创建 Dialogue-Statement 的 MENTIONS 关系。
|
|
||||||
"""
|
|
||||||
# 通过 Chunk 和 Statement 的 CONTAINS 关系来查找共同记忆
|
|
||||||
query = f"""
|
|
||||||
MATCH (c1:Chunk {{group_id: '{self.user_id}'}})
|
|
||||||
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
|
|
||||||
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
|
|
||||||
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL
|
|
||||||
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
|
|
||||||
WHERE common_statements > 0
|
|
||||||
RETURN other_user_id, common_statements
|
|
||||||
ORDER BY common_statements DESC
|
|
||||||
LIMIT 1
|
|
||||||
"""
|
|
||||||
records = await self.neo4j_connector.execute_query(query)
|
|
||||||
if not records or not records[0].get("other_user_id"):
|
|
||||||
return None
|
|
||||||
|
|
||||||
most_connected_user = records[0]["other_user_id"]
|
|
||||||
common_memories_count = records[0]["common_statements"]
|
|
||||||
|
|
||||||
# 使用 Chunk 的时间范围
|
|
||||||
time_range_query = f"""
|
|
||||||
MATCH (c:Chunk)
|
|
||||||
WHERE c.group_id IN ['{self.user_id}', '{most_connected_user}']
|
|
||||||
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
|
|
||||||
"""
|
|
||||||
time_records = await self.neo4j_connector.execute_query(time_range_query)
|
|
||||||
start_year, end_year = "N/A", "N/A"
|
|
||||||
if time_records and time_records[0]["start_time"]:
|
|
||||||
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
|
|
||||||
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
|
|
||||||
|
|
||||||
return {
|
|
||||||
"user_id": most_connected_user,
|
|
||||||
"common_memories_count": common_memories_count,
|
|
||||||
"time_range": f"{start_year}-{end_year}",
|
|
||||||
}
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""
|
|
||||||
Closes the database connection.
|
|
||||||
"""
|
|
||||||
await self.neo4j_connector.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""
|
|
||||||
Initializes and runs the memory insight analysis for a test user.
|
|
||||||
"""
|
|
||||||
# 默认从环境变量读取
|
|
||||||
test_user_id = DEFAULT_GROUP_ID
|
|
||||||
print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 使用服务层函数生成报告
|
|
||||||
from app.services.user_memory_service import analytics_memory_insight_report
|
|
||||||
|
|
||||||
result = await analytics_memory_insight_report(end_user_id=test_user_id)
|
|
||||||
report = result.get("report", "")
|
|
||||||
|
|
||||||
print("--- 记忆洞察报告 ---")
|
|
||||||
print(report)
|
|
||||||
print("---------------------")
|
|
||||||
|
|
||||||
# 将结果写入统一的 User-Dashboard.json,使用全局配置路径
|
|
||||||
try:
|
|
||||||
from app.core.config import settings
|
|
||||||
settings.ensure_memory_output_dir()
|
|
||||||
output_dir = settings.MEMORY_OUTPUT_DIR
|
|
||||||
try:
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
|
|
||||||
existing = {}
|
|
||||||
if os.path.exists(dashboard_path):
|
|
||||||
with open(dashboard_path, "r", encoding="utf-8") as rf:
|
|
||||||
existing = json.load(rf)
|
|
||||||
existing["memory_insight"] = {
|
|
||||||
"group_id": test_user_id,
|
|
||||||
"report": report
|
|
||||||
}
|
|
||||||
with open(dashboard_path, "w", encoding="utf-8") as wf:
|
|
||||||
json.dump(existing, wf, ensure_ascii=False, indent=2)
|
|
||||||
print(f"已写入 {dashboard_path} -> memory_insight")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"写入 User-Dashboard.json 失败: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"生成报告时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# This setup allows running the async main function
|
|
||||||
if sys.platform.startswith('win') and sys.version_info >= (3, 8):
|
|
||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,157 +0,0 @@
|
|||||||
"""
|
|
||||||
Generate a concise "关于我" style user summary using data from Neo4j
|
|
||||||
and the existing LLM configuration (mirrors hot_memory_tags.py setup).
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python -m analytics.user_summary --user_id <group_id>
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
# Ensure absolute imports work whether executed directly or via module
|
|
||||||
try:
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
|
||||||
src_path = os.path.join(project_root, 'src')
|
|
||||||
if src_path not in sys.path:
|
|
||||||
sys.path.insert(0, src_path)
|
|
||||||
if project_root not in sys.path:
|
|
||||||
sys.path.insert(0, project_root)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|
||||||
from app.db import get_db_context
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
|
|
||||||
#TODO: Fix this
|
|
||||||
|
|
||||||
# Default values (previously from definitions.py)
|
|
||||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
|
||||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StatementRecord:
|
|
||||||
statement: str
|
|
||||||
created_at: str | None
|
|
||||||
|
|
||||||
|
|
||||||
class UserSummary:
|
|
||||||
"""Builds a textual user summary for a given user/group id."""
|
|
||||||
|
|
||||||
def __init__(self, user_id: str):
|
|
||||||
self.user_id = user_id
|
|
||||||
self.connector = Neo4jConnector()
|
|
||||||
|
|
||||||
# Get config_id using get_end_user_connected_config
|
|
||||||
with get_db_context() as db:
|
|
||||||
try:
|
|
||||||
from app.services.memory_agent_service import (
|
|
||||||
get_end_user_connected_config,
|
|
||||||
)
|
|
||||||
connected_config = get_end_user_connected_config(user_id, db)
|
|
||||||
config_id = connected_config.get("memory_config_id")
|
|
||||||
|
|
||||||
if config_id:
|
|
||||||
# Use the config_id to get the proper LLM client
|
|
||||||
config_service = MemoryConfigService(db)
|
|
||||||
memory_config = config_service.load_memory_config(config_id)
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
self.llm = factory.get_llm_client(memory_config.llm_model_id)
|
|
||||||
else:
|
|
||||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
|
||||||
# Fallback to default LLM if no config found
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
|
||||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
|
||||||
# Fallback to default LLM
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
await self.connector.close()
|
|
||||||
|
|
||||||
async def _get_recent_statements(self, limit: int = 80) -> List[StatementRecord]: # TODO Used by user_memory_service
|
|
||||||
"""Fetch recent statements authored by the user/group for context."""
|
|
||||||
query = (
|
|
||||||
"MATCH (s:Statement) "
|
|
||||||
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
|
|
||||||
"RETURN s.statement AS statement, s.created_at AS created_at "
|
|
||||||
"ORDER BY created_at DESC LIMIT $limit"
|
|
||||||
)
|
|
||||||
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
|
|
||||||
records: List[StatementRecord] = []
|
|
||||||
for r in rows:
|
|
||||||
try:
|
|
||||||
records.append(StatementRecord(statement=r.get("statement", ""), created_at=r.get("created_at")))
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
return records
|
|
||||||
|
|
||||||
async def _get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
|
|
||||||
"""Reuse hot tag logic to get meaningful entities and their frequencies."""
|
|
||||||
# get_hot_memory_tags internally filters out non-meaningful nouns with LLM
|
|
||||||
return await get_hot_memory_tags(self.user_id, limit=limit) # TODO Used by user_memory_service
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_user_summary(user_id: str | None = None) -> str: # TODO useless
|
|
||||||
"""
|
|
||||||
生成用户摘要的便捷函数
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: 可选的用户ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
用户摘要字符串
|
|
||||||
"""
|
|
||||||
# 导入服务层函数
|
|
||||||
from app.services.user_memory_service import analytics_user_summary
|
|
||||||
|
|
||||||
# 调用服务层函数
|
|
||||||
result = await analytics_user_summary(user_id)
|
|
||||||
return result.get("summary", "")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("开始生成用户摘要…")
|
|
||||||
try:
|
|
||||||
# 直接使用 runtime.json 中的 group_id
|
|
||||||
summary = asyncio.run(generate_user_summary())
|
|
||||||
print("\n— 用户摘要 —\n")
|
|
||||||
print(summary)
|
|
||||||
|
|
||||||
# 将结果写入统一的 User-Dashboard.json
|
|
||||||
try:
|
|
||||||
from app.core.config import settings
|
|
||||||
settings.ensure_memory_output_dir()
|
|
||||||
output_dir = settings.MEMORY_OUTPUT_DIR
|
|
||||||
try:
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
|
|
||||||
existing = {}
|
|
||||||
if os.path.exists(dashboard_path):
|
|
||||||
with open(dashboard_path, "r", encoding="utf-8") as rf:
|
|
||||||
existing = json.load(rf)
|
|
||||||
existing["user_summary"] = {
|
|
||||||
"group_id": DEFAULT_GROUP_ID,
|
|
||||||
"summary": summary
|
|
||||||
}
|
|
||||||
with open(dashboard_path, "w", encoding="utf-8") as wf:
|
|
||||||
json.dump(existing, wf, ensure_ascii=False, indent=2)
|
|
||||||
print(f"已写入 {dashboard_path} -> user_summary")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"写入 User-Dashboard.json 失败: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"生成摘要失败: {e}")
|
|
||||||
print("请检查: 1) Neo4j 是否可用;2) config.json 与 .env 的 LLM/Neo4j 配置是否正确;3) 数据是否包含该用户的内容。")
|
|
||||||
@@ -85,33 +85,21 @@ Example Output:
|
|||||||
===End of Example===
|
===End of Example===
|
||||||
|
|
||||||
|
|
||||||
===Reflection Process===
|
===Internal Quality Checks (DO NOT OUTPUT)===
|
||||||
|
|
||||||
After generating the profile, perform the following self-review steps:
|
Before generating your final output, internally verify:
|
||||||
|
1. All content is grounded in provided data (no fabrication)
|
||||||
|
2. Format follows the specified structure with correct headers
|
||||||
|
3. Tone is objective, third-person, and neutral
|
||||||
|
4. All four sections are complete and within character limits
|
||||||
|
|
||||||
**Step 1: Data Grounding Check**
|
**IMPORTANT: These checks are for your internal use only. DO NOT include them in your output.**
|
||||||
- Verify all statements are supported by the provided entities and statements
|
|
||||||
- Ensure no fabricated or speculated information is included
|
|
||||||
- Confirm all claims can be traced back to the input data
|
|
||||||
|
|
||||||
**Step 2: Format Compliance**
|
|
||||||
- Verify each section follows the specified format with section headers
|
|
||||||
- Check character count limits for each section
|
|
||||||
- Ensure proper use of section markers (【】)
|
|
||||||
|
|
||||||
**Step 3: Tone and Style Review**
|
|
||||||
- Confirm objective third-person perspective is maintained
|
|
||||||
- Check for excessive adjectives or empty phrases
|
|
||||||
- Verify neutral and restrained tone throughout
|
|
||||||
|
|
||||||
**Step 4: Completeness Check**
|
|
||||||
- Ensure all four sections are present and complete
|
|
||||||
- Verify each section addresses its specific focus area
|
|
||||||
- Confirm the one-sentence summary effectively captures the user's essence
|
|
||||||
|
|
||||||
|
|
||||||
===Output Requirements===
|
===Output Requirements===
|
||||||
|
|
||||||
|
**CRITICAL: Your response must ONLY contain the four sections below. Do not include any reflection, self-review, or meta-commentary.**
|
||||||
|
|
||||||
**LANGUAGE REQUIREMENT:**
|
**LANGUAGE REQUIREMENT:**
|
||||||
- The output language should ALWAYS be Chinese (Simplified)
|
- The output language should ALWAYS be Chinese (Simplified)
|
||||||
- All section content must be in Chinese
|
- All section content must be in Chinese
|
||||||
@@ -122,3 +110,5 @@ After generating the profile, perform the following self-review steps:
|
|||||||
- Content follows immediately after the header
|
- Content follows immediately after the header
|
||||||
- Sections are separated by blank lines
|
- Sections are separated by blank lines
|
||||||
- Strictly adhere to character limits for each section
|
- Strictly adhere to character limits for each section
|
||||||
|
- **DO NOT include any text after the 【一句话总结】 section**
|
||||||
|
- **DO NOT output reflection steps, self-review, or verification notes**
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class WorkflowState(TypedDict):
|
|||||||
|
|
||||||
# Set of loop node IDs, used for assigning values in loop nodes
|
# Set of loop node IDs, used for assigning values in loop nodes
|
||||||
cycle_nodes: list
|
cycle_nodes: list
|
||||||
looping: bool
|
looping: Annotated[bool, lambda x, y: x and y]
|
||||||
|
|
||||||
# Input variables (passed from configured variables)
|
# Input variables (passed from configured variables)
|
||||||
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
||||||
|
|||||||
@@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode):
|
|||||||
retries -= 1
|
retries -= 1
|
||||||
if retries > 0:
|
if retries > 0:
|
||||||
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
|
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
|
||||||
|
elif self.typed_config.error_handle.method == HttpErrorHandle.NONE:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"HTTP request node exception: {e}")
|
||||||
else:
|
else:
|
||||||
match self.typed_config.error_handle.method:
|
match self.typed_config.error_handle.method:
|
||||||
case HttpErrorHandle.NONE:
|
|
||||||
logger.warning(
|
|
||||||
f"Node {self.node_id}: HTTP request failed, returning error response"
|
|
||||||
)
|
|
||||||
return HttpRequestNodeOutput(
|
|
||||||
body="",
|
|
||||||
status_code=resp.status_code,
|
|
||||||
headers=resp.headers,
|
|
||||||
).model_dump()
|
|
||||||
case HttpErrorHandle.DEFAULT:
|
case HttpErrorHandle.DEFAULT:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||||
@@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode):
|
|||||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||||
)
|
)
|
||||||
return "ERROR"
|
return "ERROR"
|
||||||
|
raise RuntimeError("http request failed")
|
||||||
|
|||||||
@@ -203,15 +203,16 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
score_threshold=kb_config.similarity_threshold)
|
score_threshold=kb_config.similarity_threshold)
|
||||||
# Deduplicate hybrid retrieval results
|
# Deduplicate hy brid retrieval results
|
||||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||||
vector_service.reranker = self.get_reranker_model()
|
vector_service.reranker = self.get_reranker_model()
|
||||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||||
case _:
|
case _:
|
||||||
raise RuntimeError("Unknown retrieval type")
|
raise RuntimeError("Unknown retrieval type")
|
||||||
vector_service.reranker = self.get_reranker_model()
|
vector_service.reranker = self.get_reranker_model()
|
||||||
|
# TODO:其他重排序方式支持
|
||||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||||
)
|
)
|
||||||
return [chunk.model_dump() for chunk in final_rs]
|
return [chunk.page_content for chunk in final_rs]
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""LLM 节点配置"""
|
"""LLM 节点配置"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||||
@@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti
|
|||||||
|
|
||||||
class MessageConfig(BaseModel):
|
class MessageConfig(BaseModel):
|
||||||
"""消息配置"""
|
"""消息配置"""
|
||||||
|
|
||||||
role: str = Field(
|
role: str = Field(
|
||||||
...,
|
...,
|
||||||
description="消息角色:system, user, assistant"
|
description="消息角色:system, user, assistant"
|
||||||
)
|
)
|
||||||
|
|
||||||
content: str = Field(
|
content: str = Field(
|
||||||
...,
|
...,
|
||||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("role")
|
@field_validator("role")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_role(cls, v: str) -> str:
|
def validate_role(cls, v: str) -> str:
|
||||||
@@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
1. 简单模式:使用 prompt 字段
|
1. 简单模式:使用 prompt 字段
|
||||||
2. 消息模式:使用 messages 字段(推荐)
|
2. 消息模式:使用 messages 字段(推荐)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_id: str = Field(
|
model_id: str = Field(
|
||||||
...,
|
...,
|
||||||
description="模型配置 ID"
|
description="模型配置 ID"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
context: Any = Field(
|
||||||
|
default="",
|
||||||
|
description="上下文"
|
||||||
|
)
|
||||||
|
|
||||||
# 简单模式
|
# 简单模式
|
||||||
prompt: str | None = Field(
|
prompt: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="提示词模板(简单模式),支持变量引用"
|
description="提示词模板(简单模式),支持变量引用"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 消息模式(推荐)
|
# 消息模式(推荐)
|
||||||
messages: list[MessageConfig] | None = Field(
|
messages: list[MessageConfig] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="消息列表(消息模式),支持多轮对话"
|
description="消息列表(消息模式),支持多轮对话"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 模型参数
|
# 模型参数
|
||||||
temperature: float | None = Field(
|
temperature: float | None = Field(
|
||||||
default=0.7,
|
default=0.7,
|
||||||
@@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
le=2.0,
|
le=2.0,
|
||||||
description="温度参数,控制输出的随机性"
|
description="温度参数,控制输出的随机性"
|
||||||
)
|
)
|
||||||
|
|
||||||
max_tokens: int | None = Field(
|
max_tokens: int | None = Field(
|
||||||
default=1000,
|
default=1000,
|
||||||
ge=1,
|
ge=1,
|
||||||
le=32000,
|
le=32000,
|
||||||
description="最大生成 token 数"
|
description="最大生成 token 数"
|
||||||
)
|
)
|
||||||
|
|
||||||
top_p: float | None = Field(
|
top_p: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
ge=0.0,
|
ge=0.0,
|
||||||
le=1.0,
|
le=1.0,
|
||||||
description="Top-p 采样参数"
|
description="Top-p 采样参数"
|
||||||
)
|
)
|
||||||
|
|
||||||
frequency_penalty: float | None = Field(
|
frequency_penalty: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
ge=-2.0,
|
ge=-2.0,
|
||||||
le=2.0,
|
le=2.0,
|
||||||
description="频率惩罚"
|
description="频率惩罚"
|
||||||
)
|
)
|
||||||
|
|
||||||
presence_penalty: float | None = Field(
|
presence_penalty: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
ge=-2.0,
|
ge=-2.0,
|
||||||
le=2.0,
|
le=2.0,
|
||||||
description="存在惩罚"
|
description="存在惩罚"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 输出变量定义
|
# 输出变量定义
|
||||||
output_variables: list[VariableDefinition] = Field(
|
output_variables: list[VariableDefinition] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
@@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
],
|
],
|
||||||
description="输出变量定义(自动生成,通常不需要修改)"
|
description="输出变量定义(自动生成,通常不需要修改)"
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("messages", "prompt")
|
@field_validator("messages", "prompt")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_input_mode(cls, v, info):
|
def validate_input_mode(cls, v, info):
|
||||||
"""验证输入模式:prompt 和 messages 至少有一个"""
|
"""验证输入模式:prompt 和 messages 至少有一个"""
|
||||||
# 这个验证在 model_validator 中更合适
|
# 这个验证在 model_validator 中更合适
|
||||||
return v
|
return v
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_schema_extra = {
|
json_schema_extra = {
|
||||||
"examples": [
|
"examples": [
|
||||||
|
|||||||
@@ -5,15 +5,17 @@ LLM 节点实现
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
|
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models import ModelType
|
from app.models import ModelType
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
|
|
||||||
@@ -63,8 +65,15 @@ class LLMNode(BaseNode):
|
|||||||
- user/human: 用户消息(HumanMessage)
|
- user/human: 用户消息(HumanMessage)
|
||||||
- ai/assistant: AI 消息(AIMessage)
|
- ai/assistant: AI 消息(AIMessage)
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = LLMNodeConfig(**self.config)
|
||||||
|
|
||||||
|
def _render_context(self, message,state):
|
||||||
|
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||||
|
return re.sub(r"{{context}}", context, message)
|
||||||
|
|
||||||
|
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
|
||||||
"""准备 LLM 实例(公共逻辑)
|
"""准备 LLM 实例(公共逻辑)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -76,15 +85,16 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
# 1. 处理消息格式(优先使用 messages)
|
# 1. 处理消息格式(优先使用 messages)
|
||||||
messages_config = self.config.get("messages")
|
messages_config = self.config.get("messages")
|
||||||
|
|
||||||
if messages_config:
|
if messages_config:
|
||||||
# 使用 LangChain 消息格式
|
# 使用 LangChain 消息格式
|
||||||
messages = []
|
messages = []
|
||||||
for msg_config in messages_config:
|
for msg_config in messages_config:
|
||||||
role = msg_config.get("role", "user").lower()
|
role = msg_config.get("role", "user").lower()
|
||||||
content_template = msg_config.get("content", "")
|
content_template = msg_config.get("content", "")
|
||||||
|
content_template = self._render_context(content_template, state)
|
||||||
content = self._render_template(content_template, state)
|
content = self._render_template(content_template, state)
|
||||||
|
|
||||||
# 根据角色创建对应的消息对象
|
# 根据角色创建对应的消息对象
|
||||||
if role == "system":
|
if role == "system":
|
||||||
messages.append(SystemMessage(content=content))
|
messages.append(SystemMessage(content=content))
|
||||||
@@ -95,7 +105,7 @@ class LLMNode(BaseNode):
|
|||||||
else:
|
else:
|
||||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
messages.append(HumanMessage(content=content))
|
messages.append(HumanMessage(content=content))
|
||||||
|
|
||||||
prompt_or_messages = messages
|
prompt_or_messages = messages
|
||||||
else:
|
else:
|
||||||
# 使用简单的 prompt 格式(向后兼容)
|
# 使用简单的 prompt 格式(向后兼容)
|
||||||
@@ -106,17 +116,17 @@ class LLMNode(BaseNode):
|
|||||||
model_id = self.config.get("model_id")
|
model_id = self.config.get("model_id")
|
||||||
if not model_id:
|
if not model_id:
|
||||||
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
||||||
|
|
||||||
# 3. 在 with 块内完成所有数据库操作和数据提取
|
# 3. 在 with 块内完成所有数据库操作和数据提取
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||||
|
|
||||||
if not config:
|
if not config:
|
||||||
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
if not config.api_keys or len(config.api_keys) == 0:
|
if not config.api_keys or len(config.api_keys) == 0:
|
||||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
# 在 Session 关闭前提取所有需要的数据
|
# 在 Session 关闭前提取所有需要的数据
|
||||||
api_config = config.api_keys[0]
|
api_config = config.api_keys[0]
|
||||||
model_name = api_config.model_name
|
model_name = api_config.model_name
|
||||||
@@ -124,26 +134,26 @@ class LLMNode(BaseNode):
|
|||||||
api_key = api_config.api_key
|
api_key = api_config.api_key
|
||||||
api_base = api_config.api_base
|
api_base = api_config.api_base
|
||||||
model_type = config.type
|
model_type = config.type
|
||||||
|
|
||||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||||
extra_params = {"streaming": stream} if stream else {}
|
extra_params = {"streaming": stream} if stream else {}
|
||||||
|
|
||||||
llm = RedBearLLM(
|
llm = RedBearLLM(
|
||||||
RedBearModelConfig(
|
RedBearModelConfig(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
extra_params=extra_params
|
extra_params=extra_params
|
||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||||
|
|
||||||
return llm, prompt_or_messages
|
return llm, prompt_or_messages
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||||
"""非流式执行 LLM 调用
|
"""非流式执行 LLM 调用
|
||||||
|
|
||||||
@@ -153,10 +163,10 @@ class LLMNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
LLM 响应消息
|
LLM 响应消息
|
||||||
"""
|
"""
|
||||||
llm, prompt_or_messages = self._prepare_llm(state,True)
|
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||||
|
|
||||||
# 调用 LLM(支持字符串或消息列表)
|
# 调用 LLM(支持字符串或消息列表)
|
||||||
response = await llm.ainvoke(prompt_or_messages)
|
response = await llm.ainvoke(prompt_or_messages)
|
||||||
# 提取内容
|
# 提取内容
|
||||||
@@ -164,16 +174,16 @@ class LLMNode(BaseNode):
|
|||||||
content = response.content
|
content = response.content
|
||||||
else:
|
else:
|
||||||
content = str(response)
|
content = str(response)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||||
|
|
||||||
# 返回 AIMessage(包含响应元数据)
|
# 返回 AIMessage(包含响应元数据)
|
||||||
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
"""提取输入数据(用于记录)"""
|
"""提取输入数据(用于记录)"""
|
||||||
_, prompt_or_messages = self._prepare_llm(state)
|
_, prompt_or_messages = self._prepare_llm(state)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -186,13 +196,13 @@ class LLMNode(BaseNode):
|
|||||||
"max_tokens": self.config.get("max_tokens")
|
"max_tokens": self.config.get("max_tokens")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> str:
|
def _extract_output(self, business_result: Any) -> str:
|
||||||
"""从 AIMessage 中提取文本内容"""
|
"""从 AIMessage 中提取文本内容"""
|
||||||
if isinstance(business_result, AIMessage):
|
if isinstance(business_result, AIMessage):
|
||||||
return business_result.content
|
return business_result.content
|
||||||
return str(business_result)
|
return str(business_result)
|
||||||
|
|
||||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
"""从 AIMessage 中提取 token 使用情况"""
|
"""从 AIMessage 中提取 token 使用情况"""
|
||||||
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
|
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
|
||||||
@@ -204,7 +214,7 @@ class LLMNode(BaseNode):
|
|||||||
"total_tokens": usage.get('total_tokens', 0)
|
"total_tokens": usage.get('total_tokens', 0)
|
||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState):
|
async def execute_stream(self, state: WorkflowState):
|
||||||
"""流式执行 LLM 调用
|
"""流式执行 LLM 调用
|
||||||
|
|
||||||
@@ -215,26 +225,26 @@ class LLMNode(BaseNode):
|
|||||||
文本片段(chunk)或完成标记
|
文本片段(chunk)或完成标记
|
||||||
"""
|
"""
|
||||||
from langgraph.config import get_stream_writer
|
from langgraph.config import get_stream_writer
|
||||||
|
|
||||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||||
|
|
||||||
# 检查是否有注入的 End 节点前缀配置
|
# 检查是否有注入的 End 节点前缀配置
|
||||||
writer = get_stream_writer()
|
writer = get_stream_writer()
|
||||||
end_prefix = getattr(self, '_end_node_prefix', None)
|
end_prefix = getattr(self, '_end_node_prefix', None)
|
||||||
|
|
||||||
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
|
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
|
||||||
if end_prefix:
|
if end_prefix:
|
||||||
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
|
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
|
||||||
|
|
||||||
if end_prefix:
|
if end_prefix:
|
||||||
# 渲染前缀(可能包含其他变量)
|
# 渲染前缀(可能包含其他变量)
|
||||||
try:
|
try:
|
||||||
rendered_prefix = self._render_template(end_prefix, state)
|
rendered_prefix = self._render_template(end_prefix, state)
|
||||||
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
||||||
|
|
||||||
# 提前发送 End 节点的前缀(使用 "message" 类型)
|
# 提前发送 End 节点的前缀(使用 "message" 类型)
|
||||||
writer({
|
writer({
|
||||||
"type": "message", # End 相关的内容都是 message 类型
|
"type": "message", # End 相关的内容都是 message 类型
|
||||||
@@ -246,12 +256,12 @@ class LLMNode(BaseNode):
|
|||||||
})
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
|
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
|
||||||
|
|
||||||
# 累积完整响应
|
# 累积完整响应
|
||||||
full_response = ""
|
full_response = ""
|
||||||
last_chunk = None
|
last_chunk = None
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
|
||||||
# 调用 LLM(流式,支持字符串或消息列表)
|
# 调用 LLM(流式,支持字符串或消息列表)
|
||||||
async for chunk in llm.astream(prompt_or_messages):
|
async for chunk in llm.astream(prompt_or_messages):
|
||||||
# 提取内容
|
# 提取内容
|
||||||
@@ -259,18 +269,18 @@ class LLMNode(BaseNode):
|
|||||||
content = chunk.content
|
content = chunk.content
|
||||||
else:
|
else:
|
||||||
content = str(chunk)
|
content = str(chunk)
|
||||||
|
|
||||||
# 只有当内容不为空时才处理
|
# 只有当内容不为空时才处理
|
||||||
if content:
|
if content:
|
||||||
full_response += content
|
full_response += content
|
||||||
last_chunk = chunk
|
last_chunk = chunk
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
|
|
||||||
# 流式返回每个文本片段
|
# 流式返回每个文本片段
|
||||||
yield content
|
yield content
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||||
|
|
||||||
# 构建完整的 AIMessage(包含元数据)
|
# 构建完整的 AIMessage(包含元数据)
|
||||||
if isinstance(last_chunk, AIMessage):
|
if isinstance(last_chunk, AIMessage):
|
||||||
final_message = AIMessage(
|
final_message = AIMessage(
|
||||||
@@ -279,6 +289,6 @@ class LLMNode(BaseNode):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
final_message = AIMessage(content=full_response)
|
final_message = AIMessage(content=full_response)
|
||||||
|
|
||||||
# yield 完成标记
|
# yield 完成标记
|
||||||
yield {"__final__": True, "result": final_message}
|
yield {"__final__": True, "result": final_message}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
|
|||||||
|
|
||||||
return await MemoryAgentService().read_memory(
|
return await MemoryAgentService().read_memory(
|
||||||
group_id=end_user_id,
|
group_id=end_user_id,
|
||||||
message=self.typed_config.message,
|
message=self._render_template(self.typed_config.message, state),
|
||||||
config_id=self.typed_config.config_id,
|
config_id=self.typed_config.config_id,
|
||||||
search_switch=self.typed_config.search_switch,
|
search_switch=self.typed_config.search_switch,
|
||||||
history=[],
|
history=[],
|
||||||
@@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode):
|
|||||||
|
|
||||||
return await MemoryAgentService().write_memory(
|
return await MemoryAgentService().write_memory(
|
||||||
group_id=end_user_id,
|
group_id=end_user_id,
|
||||||
message=self.typed_config.message,
|
message=self._render_template(self.typed_config.message, state),
|
||||||
config_id=self.typed_config.config_id,
|
config_id=self.typed_config.config_id,
|
||||||
db=db,
|
db=db,
|
||||||
storage_type="neo4j",
|
storage_type="neo4j",
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
category_map[category_name] = case_tag
|
category_map[category_name] = case_tag
|
||||||
return category_map
|
return category_map
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> str:
|
async def execute(self, state: WorkflowState) -> dict:
|
||||||
"""执行问题分类"""
|
"""执行问题分类"""
|
||||||
question = self.typed_config.input_variable
|
question = self.typed_config.input_variable
|
||||||
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
||||||
@@ -79,7 +79,15 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})"
|
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})"
|
||||||
)
|
)
|
||||||
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
||||||
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
|
if category_count > 0:
|
||||||
|
return {
|
||||||
|
"class_name": category_names[0],
|
||||||
|
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"class_name": "unknown",
|
||||||
|
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
llm = self._get_llm_instance()
|
llm = self._get_llm_instance()
|
||||||
@@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
log_supplement = supplement_prompt if supplement_prompt else "无"
|
log_supplement = supplement_prompt if supplement_prompt else "无"
|
||||||
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
||||||
|
|
||||||
return f"CASE{category_names.index(category) + 1}"
|
return {
|
||||||
|
"class_name": category,
|
||||||
|
"output": f"CASE{category_names.index(category) + 1}",
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"节点 {self.node_id} 分类执行异常:{str(e)}",
|
f"节点 {self.node_id} 分类执行异常:{str(e)}",
|
||||||
@@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
)
|
)
|
||||||
# 异常时返回默认分支,保证工作流容错性
|
# 异常时返回默认分支,保证工作流容错性
|
||||||
if category_count > 0:
|
if category_count > 0:
|
||||||
return DEFAULT_EMPTY_QUESTION_CASE
|
return {
|
||||||
return "unknown"
|
"class_name": category_names[0],
|
||||||
|
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"class_name": "unknown",
|
||||||
|
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig):
|
|||||||
"""工具节点配置"""
|
"""工具节点配置"""
|
||||||
|
|
||||||
tool_id: str = Field(..., description="工具ID")
|
tool_id: str = Field(..., description="工具ID")
|
||||||
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
@@ -9,6 +9,8 @@ from app.db import get_db_read
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||||
|
|
||||||
|
|
||||||
class ToolNode(BaseNode):
|
class ToolNode(BaseNode):
|
||||||
"""工具节点"""
|
"""工具节点"""
|
||||||
@@ -25,25 +27,33 @@ class ToolNode(BaseNode):
|
|||||||
|
|
||||||
# 如果没有租户ID,尝试从工作流ID获取
|
# 如果没有租户ID,尝试从工作流ID获取
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
workflow_id = self.get_variable("sys.workflow_id", state)
|
workspace_id = self.get_variable("sys.workspace_id", state)
|
||||||
if workflow_id:
|
if workspace_id:
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
|
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
|
||||||
|
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
|
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||||
# logger.error(f"节点 {self.node_id} 缺少租户ID")
|
return {
|
||||||
# return {"error": "缺少租户ID"}
|
"success": False,
|
||||||
|
"data": "缺少租户ID"
|
||||||
|
}
|
||||||
|
|
||||||
# 渲染工具参数
|
# 渲染工具参数
|
||||||
rendered_parameters = {}
|
rendered_parameters = {}
|
||||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||||
rendered_value = self._render_template(param_template, state)
|
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
|
||||||
|
try:
|
||||||
|
rendered_value = self._render_template(param_template, state)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||||
|
else:
|
||||||
|
# 非模板参数(数字/布尔/普通字符串)直接保留原值
|
||||||
|
rendered_value = param_template
|
||||||
rendered_parameters[param_name] = rendered_value
|
rendered_parameters[param_name] = rendered_value
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
||||||
print(self.typed_config.tool_id)
|
|
||||||
|
|
||||||
# 执行工具
|
# 执行工具
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
@@ -54,7 +64,7 @@ class ToolNode(BaseNode):
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user_id
|
user_id=user_id
|
||||||
)
|
)
|
||||||
print(result)
|
|
||||||
if result.success:
|
if result.success:
|
||||||
logger.info(f"节点 {self.node_id} 工具执行成功")
|
logger.info(f"节点 {self.node_id} 工具执行成功")
|
||||||
return {
|
return {
|
||||||
@@ -66,7 +76,7 @@ class ToolNode(BaseNode):
|
|||||||
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": result.error,
|
"data": result.error,
|
||||||
"error_code": result.error_code,
|
"error_code": result.error_code,
|
||||||
"execution_time": result.execution_time
|
"execution_time": result.execution_time
|
||||||
}
|
}
|
||||||
@@ -87,10 +87,11 @@ class WorkflowValidator:
|
|||||||
return graphs
|
return graphs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]:
|
||||||
"""验证工作流配置
|
"""验证工作流配置
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
publish: 发布验证标识
|
||||||
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
|
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -114,7 +115,7 @@ class WorkflowValidator:
|
|||||||
|
|
||||||
graphs = cls.get_subgraph(workflow_config)
|
graphs = cls.get_subgraph(workflow_config)
|
||||||
logger.info(graphs)
|
logger.info(graphs)
|
||||||
for graph in graphs:
|
for index, graph in enumerate(graphs):
|
||||||
nodes = graph.get("nodes", [])
|
nodes = graph.get("nodes", [])
|
||||||
edges = graph.get("edges", [])
|
edges = graph.get("edges", [])
|
||||||
variables = graph.get("variables", [])
|
variables = graph.get("variables", [])
|
||||||
@@ -125,10 +126,11 @@ class WorkflowValidator:
|
|||||||
elif len(start_nodes) > 1:
|
elif len(start_nodes) > 1:
|
||||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||||
|
|
||||||
# 2. 验证 end 节点(至少一个)
|
if index == len(graphs) - 1:
|
||||||
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
# 2. 验证 主图end 节点(至少一个)
|
||||||
if len(end_nodes) == 0:
|
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
||||||
errors.append("工作流必须至少有一个 end 节点")
|
if len(end_nodes) == 0:
|
||||||
|
errors.append("工作流必须至少有一个 end 节点")
|
||||||
|
|
||||||
# 3. 验证节点 ID 唯一性
|
# 3. 验证节点 ID 唯一性
|
||||||
node_ids = [n.get("id") for n in nodes]
|
node_ids = [n.get("id") for n in nodes]
|
||||||
@@ -159,15 +161,17 @@ class WorkflowValidator:
|
|||||||
elif target not in node_id_set:
|
elif target not in node_id_set:
|
||||||
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||||
|
|
||||||
# 6. 验证所有节点可达(从 start 节点出发)
|
if publish:
|
||||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
# 仅在发布时验证所有节点可达
|
||||||
reachable = WorkflowValidator._get_reachable_nodes(
|
# 6. 验证所有节点可达(从 start 节点出发)
|
||||||
start_nodes[0]["id"],
|
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||||
edges
|
reachable = WorkflowValidator._get_reachable_nodes(
|
||||||
)
|
start_nodes[0]["id"],
|
||||||
unreachable = node_id_set - reachable
|
edges
|
||||||
if unreachable:
|
)
|
||||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
unreachable = node_id_set - reachable
|
||||||
|
if unreachable:
|
||||||
|
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||||
|
|
||||||
# 7. 检测循环依赖(非 loop 节点)
|
# 7. 检测循环依赖(非 loop 节点)
|
||||||
if not errors: # 只有在前面验证通过时才检查循环
|
if not errors: # 只有在前面验证通过时才检查循环
|
||||||
@@ -288,7 +292,7 @@ class WorkflowValidator:
|
|||||||
(is_valid, errors): 是否有效和错误列表
|
(is_valid, errors): 是否有效和错误列表
|
||||||
"""
|
"""
|
||||||
# 先执行基础验证
|
# 先执行基础验证
|
||||||
is_valid, errors = WorkflowValidator.validate(workflow_config)
|
is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True)
|
||||||
|
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
return False, errors
|
return False, errors
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
|
|||||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from app.db import Base
|
from app.db import Base
|
||||||
|
from app.models.multi_agent_model import PydanticType
|
||||||
|
from app.schemas import ModelParameters
|
||||||
|
|
||||||
|
|
||||||
class AgentConfig(Base):
|
class AgentConfig(Base):
|
||||||
@@ -17,14 +19,17 @@ class AgentConfig(Base):
|
|||||||
# Agent 行为配置
|
# Agent 行为配置
|
||||||
system_prompt = Column(Text, nullable=True, comment="系统提示词")
|
system_prompt = Column(Text, nullable=True, comment="系统提示词")
|
||||||
default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True, comment="默认模型配置ID")
|
default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True, comment="默认模型配置ID")
|
||||||
|
|
||||||
# 结构化配置(直接存储 JSON)
|
# 结构化配置(直接存储 JSON)
|
||||||
model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)")
|
# model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)")
|
||||||
|
model_parameters = Column(PydanticType(ModelParameters), nullable=True,
|
||||||
|
comment="模型参数配置(temperature、max_tokens等)")
|
||||||
|
|
||||||
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
|
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
|
||||||
memory = Column(JSON, nullable=True, comment="记忆配置")
|
memory = Column(JSON, nullable=True, comment="记忆配置")
|
||||||
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
||||||
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
|
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
|
||||||
|
|
||||||
# 多 Agent 相关字段
|
# 多 Agent 相关字段
|
||||||
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
|
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
|
||||||
agent_domain = Column(String(50), comment="专业领域: customer_service|technical_support|sales 等")
|
agent_domain = Column(String(50), comment="专业领域: customer_service|technical_support|sales 等")
|
||||||
@@ -41,4 +46,4 @@ class AgentConfig(Base):
|
|||||||
parent_agent = relationship("AgentConfig", remote_side=[id], backref="sub_agents")
|
parent_agent = relationship("AgentConfig", remote_side=[id], backref="sub_agents")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<AgentConfig(id={self.id}, app_id={self.app_id})>"
|
return f"<AgentConfig(id={self.id}, app_id={self.app_id})>"
|
||||||
|
|||||||
@@ -38,6 +38,33 @@ class ToolRepository:
|
|||||||
|
|
||||||
return result[0] if result else None
|
return result[0] if result else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]:
|
||||||
|
"""
|
||||||
|
根据空间ID获取tenant_id
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
workspace_id: 空间ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tenant_id或None
|
||||||
|
"""
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
|
||||||
|
tenant_id = db.query(Workspace.tenant_id).filter(
|
||||||
|
Workspace.id == workspace_id
|
||||||
|
).scalar()
|
||||||
|
|
||||||
|
if tenant_id is not None and not isinstance(tenant_id, uuid.UUID):
|
||||||
|
# 兼容数据库中字段类型不匹配的情况(比如存储为字符串)
|
||||||
|
try:
|
||||||
|
tenant_id = uuid.UUID(tenant_id)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return tenant_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_by_tenant(
|
def find_by_tenant(
|
||||||
db: Session,
|
db: Session,
|
||||||
|
|||||||
@@ -86,7 +86,12 @@ class AgentConfigConverter:
|
|||||||
# 1. 解析模型参数配置
|
# 1. 解析模型参数配置
|
||||||
if model_parameters:
|
if model_parameters:
|
||||||
from app.schemas.app_schema import ModelParameters
|
from app.schemas.app_schema import ModelParameters
|
||||||
result["model_parameters"] = ModelParameters(**model_parameters)
|
if isinstance(model_parameters, ModelParameters):
|
||||||
|
result["model_parameters"] = model_parameters
|
||||||
|
elif isinstance(model_parameters, dict):
|
||||||
|
result["model_parameters"] = ModelParameters(**model_parameters)
|
||||||
|
else:
|
||||||
|
result["model_parameters"] = ModelParameters()
|
||||||
|
|
||||||
# 2. 解析知识库检索配置
|
# 2. 解析知识库检索配置
|
||||||
if knowledge_retrieval:
|
if knowledge_retrieval:
|
||||||
|
|||||||
@@ -9,15 +9,18 @@ from fastapi import Depends
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.agent.langchain_agent import LangChainAgent
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.db import get_db
|
from app.db import get_db, get_db_context
|
||||||
from app.models import MultiAgentConfig, AgentConfig
|
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
||||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||||
from app.services.draft_run_service import create_web_search_tool
|
from app.services.draft_run_service import create_web_search_tool
|
||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||||
|
from app.services.workflow_service import WorkflowService
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -184,7 +187,7 @@ class AppChatService:
|
|||||||
model_config_id = config.default_model_config_id
|
model_config_id = config.default_model_config_id
|
||||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
||||||
# 处理系统提示词(支持变量替换)
|
# 处理系统提示词(支持变量替换)
|
||||||
system_prompt = config.get("system_prompt", "")
|
system_prompt = config.system_prompt
|
||||||
if variables:
|
if variables:
|
||||||
system_prompt_rendered = render_prompt_message(
|
system_prompt_rendered = render_prompt_message(
|
||||||
system_prompt,
|
system_prompt,
|
||||||
@@ -197,7 +200,7 @@ class AppChatService:
|
|||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
knowledge_retrieval = config.get("knowledge_retrieval")
|
knowledge_retrieval = config.knowledge_retrieval
|
||||||
if knowledge_retrieval:
|
if knowledge_retrieval:
|
||||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
||||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||||
@@ -208,13 +211,13 @@ class AppChatService:
|
|||||||
# 添加长期记忆工具
|
# 添加长期记忆工具
|
||||||
memory_flag = False
|
memory_flag = False
|
||||||
if memory:
|
if memory:
|
||||||
memory_config = config.get("memory", {})
|
memory_config = config.memory
|
||||||
if memory_config.get("enabled") and user_id:
|
if memory_config.get("enabled") and user_id:
|
||||||
memory_flag = True
|
memory_flag = True
|
||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||||
tools.append(memory_tool)
|
tools.append(memory_tool)
|
||||||
|
|
||||||
web_tools = config.get("tools")
|
web_tools = config.tools
|
||||||
web_search_choice = web_tools.get("web_search", {})
|
web_search_choice = web_tools.get("web_search", {})
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
web_search_enable = web_search_choice.get("enabled", False)
|
||||||
if web_search == True:
|
if web_search == True:
|
||||||
@@ -230,7 +233,7 @@ class AppChatService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.get("model_parameters", {})
|
model_parameters = config.model_parameters
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
# 创建 LangChain Agent
|
||||||
agent = LangChainAgent(
|
agent = LangChainAgent(
|
||||||
@@ -479,7 +482,9 @@ class AppChatService:
|
|||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
config: AgentConfig,
|
config: WorkflowConfig,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
web_search: bool = False,
|
web_search: bool = False,
|
||||||
@@ -488,281 +493,159 @@ class AppChatService:
|
|||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""聊天(非流式)"""
|
"""聊天(非流式)"""
|
||||||
|
workflow_service = WorkflowService(self.db)
|
||||||
|
|
||||||
start_time = time.time()
|
input_data = {"message":message, "variables": variables,
|
||||||
config_id = None
|
"conversation_id": str(conversation_id)}
|
||||||
|
inconfig = workflow_service.get_workflow_config(app_id)
|
||||||
|
|
||||||
if variables is None:
|
# 2. 创建执行记录
|
||||||
variables = {}
|
execution = workflow_service.create_execution(
|
||||||
|
workflow_config_id=inconfig.id,
|
||||||
|
app_id=app_id,
|
||||||
|
trigger_type="manual",
|
||||||
|
triggered_by=None,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
input_data=input_data
|
||||||
|
)
|
||||||
|
|
||||||
# 获取模型配置ID
|
# 3. 构建工作流配置字典
|
||||||
model_config_id = config.default_model_config_id
|
workflow_config_dict = {
|
||||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
"nodes": config.nodes,
|
||||||
# 处理系统提示词(支持变量替换)
|
"edges": config.edges,
|
||||||
system_prompt = config.get("system_prompt", "")
|
"variables": config.variables,
|
||||||
if variables:
|
"execution_config": config.execution_config
|
||||||
system_prompt_rendered = render_prompt_message(
|
}
|
||||||
system_prompt,
|
|
||||||
PromptMessageRole.USER,
|
# 4. 获取工作空间 ID(从 app 获取)
|
||||||
variables
|
|
||||||
|
# 5. 执行工作流
|
||||||
|
from app.core.workflow.executor import execute_workflow
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 更新状态为运行中
|
||||||
|
workflow_service.update_execution_status(execution.execution_id, "running")
|
||||||
|
|
||||||
|
result = await execute_workflow(
|
||||||
|
workflow_config=workflow_config_dict,
|
||||||
|
input_data=input_data,
|
||||||
|
execution_id=execution.execution_id,
|
||||||
|
workspace_id=str(workspace_id),
|
||||||
|
user_id=user_id
|
||||||
)
|
)
|
||||||
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
|
||||||
|
|
||||||
# 准备工具列表
|
# 更新执行结果
|
||||||
tools = []
|
if result.get("status") == "completed":
|
||||||
|
workflow_service.update_execution_status(
|
||||||
# 添加知识库检索工具
|
execution.execution_id,
|
||||||
knowledge_retrieval = config.get("knowledge_retrieval")
|
"completed",
|
||||||
if knowledge_retrieval:
|
output_data=result.get("node_outputs", {})
|
||||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
)
|
||||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
else:
|
||||||
if kb_ids:
|
workflow_service.update_execution_status(
|
||||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
|
execution.execution_id,
|
||||||
tools.append(kb_tool)
|
"failed",
|
||||||
|
error_message=result.get("error")
|
||||||
# 添加长期记忆工具
|
|
||||||
memory_flag = False
|
|
||||||
if memory == True:
|
|
||||||
memory_config = config.get("memory", {})
|
|
||||||
if memory_config.get("enabled") and user_id:
|
|
||||||
memory_flag = True
|
|
||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
|
||||||
tools.append(memory_tool)
|
|
||||||
|
|
||||||
web_tools = config.get("tools")
|
|
||||||
web_search_choice = web_tools.get("web_search", {})
|
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
|
||||||
if web_search == True:
|
|
||||||
if web_search_enable == True:
|
|
||||||
search_tool = create_web_search_tool({})
|
|
||||||
tools.append(search_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加网络搜索工具",
|
|
||||||
extra={
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取模型参数
|
# 返回增强的响应结构
|
||||||
model_parameters = config.get("model_parameters", {})
|
return {
|
||||||
|
"execution_id": execution.execution_id,
|
||||||
|
"status": result.get("status"),
|
||||||
|
"output": result.get("output"), # 最终输出(字符串)
|
||||||
|
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||||
|
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||||
|
"error_message": result.get("error"),
|
||||||
|
"elapsed_time": result.get("elapsed_time"),
|
||||||
|
"token_usage": result.get("token_usage")
|
||||||
|
}
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
except Exception as e:
|
||||||
agent = LangChainAgent(
|
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||||
model_name=api_key_obj.model_name,
|
workflow_service.update_execution_status(
|
||||||
api_key=api_key_obj.api_key,
|
execution.execution_id,
|
||||||
provider=api_key_obj.provider,
|
"failed",
|
||||||
api_base=api_key_obj.api_base,
|
error_message=str(e)
|
||||||
temperature=model_parameters.get("temperature", 0.7),
|
)
|
||||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
raise BusinessException(
|
||||||
system_prompt=system_prompt,
|
code=BizCode.INTERNAL_ERROR,
|
||||||
tools=tools,
|
message=f"工作流执行失败: {str(e)}"
|
||||||
|
|
||||||
)
|
|
||||||
|
|
||||||
# 加载历史消息
|
|
||||||
history = []
|
|
||||||
memory_config = {"enabled": True, 'max_history': 10}
|
|
||||||
if memory_config.get("enabled"):
|
|
||||||
messages = self.conversation_service.get_messages(
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
limit=memory_config.get("max_history", 10)
|
|
||||||
)
|
)
|
||||||
history = [
|
|
||||||
{"role": msg.role, "content": msg.content}
|
|
||||||
for msg in messages
|
|
||||||
]
|
|
||||||
|
|
||||||
# 调用 Agent
|
|
||||||
result = await agent.chat(
|
|
||||||
message=message,
|
|
||||||
history=history,
|
|
||||||
context=None,
|
|
||||||
end_user_id=user_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
config_id=config_id,
|
|
||||||
memory_flag=memory_flag
|
|
||||||
)
|
|
||||||
|
|
||||||
# 保存消息
|
|
||||||
self.conversation_service.save_conversation_messages(
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
user_message=message,
|
|
||||||
assistant_message=result["content"]
|
|
||||||
)
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
|
||||||
|
|
||||||
return {
|
|
||||||
"conversation_id": conversation_id,
|
|
||||||
"message": result["content"],
|
|
||||||
"usage": result.get("usage", {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
}),
|
|
||||||
"elapsed_time": elapsed_time
|
|
||||||
}
|
|
||||||
|
|
||||||
async def workflow_chat_stream(
|
async def workflow_chat_stream(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
config: AgentConfig,
|
config: WorkflowConfig,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
web_search: bool = False,
|
web_search: bool = False,
|
||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
|
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""聊天(流式)"""
|
"""聊天(流式)"""
|
||||||
|
workflow_service = WorkflowService(self.db)
|
||||||
|
input_data = {"message": message, "variables": variables,
|
||||||
|
"conversation_id": str(conversation_id)}
|
||||||
|
inconfig = workflow_service.get_workflow_config(app_id)
|
||||||
|
# 2. 创建执行记录
|
||||||
|
execution = workflow_service.create_execution(
|
||||||
|
workflow_config_id=inconfig.id,
|
||||||
|
app_id=app_id,
|
||||||
|
trigger_type="manual",
|
||||||
|
triggered_by=None,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
input_data=input_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 构建工作流配置字典
|
||||||
|
workflow_config_dict = {
|
||||||
|
"nodes": config.nodes,
|
||||||
|
"edges": config.edges,
|
||||||
|
"variables": config.variables,
|
||||||
|
"execution_config": config.execution_config
|
||||||
|
}
|
||||||
|
|
||||||
|
# 4. 获取工作空间 ID(从 app 获取)
|
||||||
|
|
||||||
|
# 5. 流式执行工作流
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
# 更新状态为运行中
|
||||||
config_id = None
|
workflow_service.update_execution_status(execution.execution_id, "running")
|
||||||
|
|
||||||
if variables is None:
|
|
||||||
variables = {}
|
|
||||||
|
|
||||||
# 获取模型配置ID
|
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
||||||
model_config_id = config.default_model_config_id
|
async for event in workflow_service._run_workflow_stream(
|
||||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
workflow_config=workflow_config_dict,
|
||||||
# 处理系统提示词(支持变量替换)
|
input_data=input_data,
|
||||||
system_prompt = config.get("system_prompt", "")
|
execution_id=execution.execution_id,
|
||||||
if variables:
|
workspace_id=str(workspace_id),
|
||||||
system_prompt_rendered = render_prompt_message(
|
user_id=user_id
|
||||||
system_prompt,
|
|
||||||
PromptMessageRole.USER,
|
|
||||||
variables
|
|
||||||
)
|
|
||||||
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
|
||||||
|
|
||||||
# 准备工具列表
|
|
||||||
tools = []
|
|
||||||
|
|
||||||
# 添加知识库检索工具
|
|
||||||
knowledge_retrieval = config.get("knowledge_retrieval")
|
|
||||||
if knowledge_retrieval:
|
|
||||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
|
||||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
|
||||||
if kb_ids:
|
|
||||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
|
|
||||||
tools.append(kb_tool)
|
|
||||||
|
|
||||||
# 添加长期记忆工具
|
|
||||||
memory_flag = False
|
|
||||||
if memory:
|
|
||||||
memory_config = config.get("memory", {})
|
|
||||||
if memory_config.get("enabled") and user_id:
|
|
||||||
memory_flag = True
|
|
||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
|
||||||
tools.append(memory_tool)
|
|
||||||
|
|
||||||
web_tools = config.get("tools")
|
|
||||||
web_search_choice = web_tools.get("web_search", {})
|
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
|
||||||
if web_search == True:
|
|
||||||
if web_search_enable == True:
|
|
||||||
search_tool = create_web_search_tool({})
|
|
||||||
tools.append(search_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"已添加网络搜索工具",
|
|
||||||
extra={
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取模型参数
|
|
||||||
model_parameters = config.get("model_parameters", {})
|
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
|
||||||
agent = LangChainAgent(
|
|
||||||
model_name=api_key_obj.model_name,
|
|
||||||
api_key=api_key_obj.api_key,
|
|
||||||
provider=api_key_obj.provider,
|
|
||||||
api_base=api_key_obj.api_base,
|
|
||||||
temperature=model_parameters.get("temperature", 0.7),
|
|
||||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tools=tools,
|
|
||||||
streaming=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 加载历史消息
|
|
||||||
history = []
|
|
||||||
memory_config = {"enabled": True, 'max_history': 10}
|
|
||||||
if memory_config.get("enabled"):
|
|
||||||
messages = self.conversation_service.get_messages(
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
limit=memory_config.get("max_history", 10)
|
|
||||||
)
|
|
||||||
history = [
|
|
||||||
{"role": msg.role, "content": msg.content}
|
|
||||||
for msg in messages
|
|
||||||
]
|
|
||||||
|
|
||||||
# 发送开始事件
|
|
||||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
|
||||||
|
|
||||||
# 流式调用 Agent
|
|
||||||
full_content = ""
|
|
||||||
async for chunk in agent.chat_stream(
|
|
||||||
message=message,
|
|
||||||
history=history,
|
|
||||||
context=None,
|
|
||||||
end_user_id=user_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
config_id=config_id,
|
|
||||||
memory_flag=memory_flag
|
|
||||||
):
|
):
|
||||||
full_content += chunk
|
# 直接转发 executor 的事件(已经是正确的格式)
|
||||||
# 发送消息块事件
|
yield event
|
||||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
|
||||||
|
|
||||||
# 保存消息
|
|
||||||
self.conversation_service.add_message(
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
role="user",
|
|
||||||
content=message
|
|
||||||
)
|
|
||||||
|
|
||||||
self.conversation_service.add_message(
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
role="assistant",
|
|
||||||
content=full_content,
|
|
||||||
meta_data={
|
|
||||||
"model": api_key_obj.model_name,
|
|
||||||
"usage": {}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 发送结束事件
|
|
||||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
|
||||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"流式聊天完成",
|
|
||||||
extra={
|
|
||||||
"conversation_id": str(conversation_id),
|
|
||||||
"elapsed_time": elapsed_time,
|
|
||||||
"message_length": len(full_content)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
except (GeneratorExit, asyncio.CancelledError):
|
|
||||||
# 生成器被关闭或任务被取消,正常退出
|
|
||||||
logger.debug("流式聊天被中断")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
|
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||||
|
workflow_service.update_execution_status(
|
||||||
|
execution.execution_id,
|
||||||
|
"failed",
|
||||||
|
error_message=str(e)
|
||||||
|
)
|
||||||
# 发送错误事件
|
# 发送错误事件
|
||||||
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
yield {
|
||||||
|
"event": "error",
|
||||||
|
"data": {
|
||||||
|
"execution_id": execution.execution_id,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# ==================== 依赖注入函数 ====================
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|
||||||
def get_app_chat_service(
|
def get_app_chat_service(
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from app.core.exceptions import (
|
|||||||
BusinessException,
|
BusinessException,
|
||||||
)
|
)
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.workflow.validator import WorkflowValidator
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.models import App, AgentConfig, AppRelease, MultiAgentConfig, WorkflowConfig
|
from app.models import App, AgentConfig, AppRelease, MultiAgentConfig, WorkflowConfig
|
||||||
from app.models.app_model import AppStatus, AppType
|
from app.models.app_model import AppStatus, AppType
|
||||||
@@ -31,6 +32,7 @@ from app.schemas.workflow_schema import WorkflowConfigUpdate
|
|||||||
from app.services.agent_config_converter import AgentConfigConverter
|
from app.services.agent_config_converter import AgentConfigConverter
|
||||||
from app.models import AppShare, Workspace
|
from app.models import AppShare, Workspace
|
||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
|
from app.services.workflow_service import WorkflowService
|
||||||
|
|
||||||
# 获取业务日志器
|
# 获取业务日志器
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -1225,6 +1227,26 @@ class AppService:
|
|||||||
"orchestration_mode": multi_agent_cfg.orchestration_mode
|
"orchestration_mode": multi_agent_cfg.orchestration_mode
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
elif app.type == AppType.WORKFLOW:
|
||||||
|
service = WorkflowService(self.db)
|
||||||
|
workflow_cfg = service.get_workflow_config(app_id)
|
||||||
|
if not workflow_cfg:
|
||||||
|
raise BusinessException("应用缺少有效配置,无法发布", BizCode.CONFIG_MISSING)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"nodes": workflow_cfg.nodes,
|
||||||
|
"edges": workflow_cfg.edges,
|
||||||
|
"variables": workflow_cfg.variables,
|
||||||
|
"execution_config": workflow_cfg.execution_config,
|
||||||
|
"triggers": workflow_cfg.triggers
|
||||||
|
}
|
||||||
|
|
||||||
|
is_valid, errors = WorkflowValidator.validate_for_publish(config)
|
||||||
|
if not is_valid:
|
||||||
|
raise BusinessException("应用缺少有效配置,无法发布", BizCode.CONFIG_MISSING)
|
||||||
|
logger.info(
|
||||||
|
"应用发布配置准备完成"
|
||||||
|
)
|
||||||
|
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
version = self._get_next_version(app_id)
|
version = self._get_next_version(app_id)
|
||||||
|
|||||||
@@ -1293,6 +1293,7 @@ class MultiAgentOrchestrator:
|
|||||||
conversation_id: 会话 ID
|
conversation_id: 会话 ID
|
||||||
user_id: 用户 ID
|
user_id: 用户 ID
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
执行结果
|
执行结果
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -231,9 +231,9 @@ class PromptOptimizerService:
|
|||||||
if m:
|
if m:
|
||||||
prompt_index = m.start()
|
prompt_index = m.start()
|
||||||
prompt_finished = True
|
prompt_finished = True
|
||||||
yield {"type": "delta", "content": buffer[idx:prompt_index]}
|
yield {"content": buffer[idx:prompt_index]}
|
||||||
else:
|
else:
|
||||||
yield {"type": "delta", "content": cache[idx:]}
|
yield {"content": cache[idx:]}
|
||||||
if len(cache) != 0:
|
if len(cache) != 0:
|
||||||
idx = len(cache)
|
idx = len(cache)
|
||||||
|
|
||||||
@@ -249,8 +249,8 @@ class PromptOptimizerService:
|
|||||||
role=RoleType.ASSISTANT,
|
role=RoleType.ASSISTANT,
|
||||||
content=desc
|
content=desc
|
||||||
)
|
)
|
||||||
|
variables = self.parser_prompt_variables(optim_result.get("prompt"))
|
||||||
yield {"type": "done", "desc": optim_result.get("desc")}
|
yield {"desc": optim_result.get("desc"), "variables": variables}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parser_prompt_variables(prompt: str):
|
def parser_prompt_variables(prompt: str):
|
||||||
|
|||||||
@@ -344,14 +344,16 @@ class ToolService:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if operation_param:
|
if operation_param:
|
||||||
# 有多个操作
|
# 有多个操作,为每个操作生成具体参数
|
||||||
methods = []
|
methods = []
|
||||||
for operation in operation_param.enum:
|
for operation in operation_param.enum:
|
||||||
|
# 获取该操作的具体参数
|
||||||
|
operation_params = self._get_operation_specific_params(tool_instance, operation)
|
||||||
methods.append({
|
methods.append({
|
||||||
"method_id": f"{config.name}_{operation}",
|
"method_id": f"{config.name}_{operation}",
|
||||||
"name": operation,
|
"name": operation,
|
||||||
"description": f"{config.description} - {operation}",
|
"description": f"{config.description} - {operation}",
|
||||||
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
"parameters": operation_params
|
||||||
})
|
})
|
||||||
return methods
|
return methods
|
||||||
else:
|
else:
|
||||||
@@ -362,6 +364,243 @@ class ToolService:
|
|||||||
"description": config.description,
|
"description": config.description,
|
||||||
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
||||||
}]
|
}]
|
||||||
|
|
||||||
|
def _get_operation_specific_params(self, tool_instance: BaseTool, operation: str) -> List[Dict[str, Any]]:
|
||||||
|
"""获取特定操作的参数列表"""
|
||||||
|
# 对于datetime_tool,根据操作类型返回相关参数
|
||||||
|
if hasattr(tool_instance, 'name') and tool_instance.name == 'datetime_tool':
|
||||||
|
return self._get_datetime_tool_params(operation)
|
||||||
|
# 对于json_tool,根据操作类型返回相关参数
|
||||||
|
elif hasattr(tool_instance, 'name') and tool_instance.name == 'json_tool':
|
||||||
|
return self._get_json_tool_params(operation)
|
||||||
|
|
||||||
|
# 其他工具的默认处理:返回除operation外的所有参数
|
||||||
|
return [{
|
||||||
|
"name": param.name,
|
||||||
|
"type": param.type.value,
|
||||||
|
"description": param.description,
|
||||||
|
"required": param.required,
|
||||||
|
"default": param.default,
|
||||||
|
"enum": param.enum,
|
||||||
|
"minimum": param.minimum,
|
||||||
|
"maximum": param.maximum,
|
||||||
|
"pattern": param.pattern
|
||||||
|
} for param in tool_instance.parameters if param.name != "operation"]
|
||||||
|
|
||||||
|
def _get_datetime_tool_params(self, operation: str) -> List[Dict[str, Any]]:
|
||||||
|
"""获取datetime_tool特定操作的参数"""
|
||||||
|
if operation == "now":
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "to_timezone",
|
||||||
|
"type": "string",
|
||||||
|
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||||
|
"required": False,
|
||||||
|
"default": "Asia/Shanghai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "output_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif operation == "format":
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "input_value",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入值(时间字符串或时间戳)",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "input_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "output_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif operation == "convert_timezone":
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "input_value",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入值(时间字符串或时间戳)",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "input_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "output_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "from_timezone",
|
||||||
|
"type": "string",
|
||||||
|
"description": "源时区(如:UTC, Asia/Shanghai)",
|
||||||
|
"required": False,
|
||||||
|
"default": "Asia/Shanghai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "to_timezone",
|
||||||
|
"type": "string",
|
||||||
|
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||||
|
"required": False,
|
||||||
|
"default": "Asia/Shanghai"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif operation == "timestamp_to_datetime":
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "input_value",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入值(时间字符串或时间戳)",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "output_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "to_timezone",
|
||||||
|
"type": "string",
|
||||||
|
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||||
|
"required": False,
|
||||||
|
"default": "Asia/Shanghai"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# 默认返回所有参数(除了operation)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "input_value",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入值(时间字符串或时间戳)",
|
||||||
|
"required": False
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "input_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "output_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "from_timezone",
|
||||||
|
"type": "string",
|
||||||
|
"description": "源时区(如:UTC, Asia/Shanghai)",
|
||||||
|
"required": False,
|
||||||
|
"default": "Asia/Shanghai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "to_timezone",
|
||||||
|
"type": "string",
|
||||||
|
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||||
|
"required": False,
|
||||||
|
"default": "Asia/Shanghai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "calculation",
|
||||||
|
"type": "string",
|
||||||
|
"description": "时间计算表达式(如:+1d, -2h, +30m)",
|
||||||
|
"required": False
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_json_tool_params(self, operation: str) -> List[Dict[str, Any]]:
|
||||||
|
"""获取json_tool特定操作的参数"""
|
||||||
|
base_params = [
|
||||||
|
{
|
||||||
|
"name": "input_data",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入数据(JSON字符串、YAML字符串或XML字符串)",
|
||||||
|
"required": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
if operation == "insert":
|
||||||
|
return base_params + [
|
||||||
|
{
|
||||||
|
"name": "json_path",
|
||||||
|
"type": "string",
|
||||||
|
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "new_value",
|
||||||
|
"type": "string",
|
||||||
|
"description": "新值(用于insert操作)",
|
||||||
|
"required": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif operation == "replace":
|
||||||
|
return base_params + [
|
||||||
|
{
|
||||||
|
"name": "json_path",
|
||||||
|
"type": "string",
|
||||||
|
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "old_text",
|
||||||
|
"type": "string",
|
||||||
|
"description": "要替换的原文本(用于replace操作)",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "new_text",
|
||||||
|
"type": "string",
|
||||||
|
"description": "替换后的新文本(用于replace操作)",
|
||||||
|
"required": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif operation == "delete":
|
||||||
|
return base_params + [
|
||||||
|
{
|
||||||
|
"name": "json_path",
|
||||||
|
"type": "string",
|
||||||
|
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||||
|
"required": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif operation == "parse":
|
||||||
|
return base_params + [
|
||||||
|
{
|
||||||
|
"name": "json_path",
|
||||||
|
"type": "string",
|
||||||
|
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||||
|
"required": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
return base_params
|
||||||
|
|
||||||
async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
||||||
"""获取自定义工具的方法"""
|
"""获取自定义工具的方法"""
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ User Memory Service
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -22,7 +21,269 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
# Neo4j connector instan
|
# Neo4j connector instance for analytics functions
|
||||||
|
_neo4j_connector = Neo4jConnector()
|
||||||
|
|
||||||
|
# Default LLM ID for fallback
|
||||||
|
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Internal Helper Classes
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TagClassification(BaseModel):
|
||||||
|
"""Represents the classification of a tag into a specific domain."""
|
||||||
|
domain: str = Field(
|
||||||
|
...,
|
||||||
|
description="The domain the tag belongs to, chosen from the predefined list.",
|
||||||
|
examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_llm_client_for_user(user_id: str):
|
||||||
|
"""
|
||||||
|
Get LLM client for a specific user based on their config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID to get config for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LLM client instance
|
||||||
|
"""
|
||||||
|
with get_db_context() as db:
|
||||||
|
try:
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
connected_config = get_end_user_connected_config(user_id, db)
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
|
||||||
|
if config_id:
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
memory_config = config_service.load_memory_config(config_id)
|
||||||
|
factory = MemoryClientFactory(db)
|
||||||
|
return factory.get_llm_client(memory_config.llm_model_id)
|
||||||
|
else:
|
||||||
|
factory = MemoryClientFactory(db)
|
||||||
|
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||||
|
factory = MemoryClientFactory(db)
|
||||||
|
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryInsightHelper:
|
||||||
|
"""
|
||||||
|
Internal helper class for memory insight analysis.
|
||||||
|
Provides basic data retrieval and analysis functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, user_id: str):
|
||||||
|
self.user_id = user_id
|
||||||
|
self.neo4j_connector = Neo4jConnector()
|
||||||
|
self.llm_client = _get_llm_client_for_user(user_id)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close database connection."""
|
||||||
|
await self.neo4j_connector.close()
|
||||||
|
|
||||||
|
async def get_domain_distribution(self) -> dict[str, float]:
|
||||||
|
"""Calculate the distribution of memory domains based on hot tags."""
|
||||||
|
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||||
|
|
||||||
|
hot_tags = await get_hot_memory_tags(self.user_id)
|
||||||
|
if not hot_tags:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
domain_counts = Counter()
|
||||||
|
for tag, _ in hot_tags:
|
||||||
|
prompt = f"""请将以下标签归类到最合适的领域中。
|
||||||
|
|
||||||
|
可选领域及其关键词:
|
||||||
|
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
|
||||||
|
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
|
||||||
|
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
|
||||||
|
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
|
||||||
|
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
|
||||||
|
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
|
||||||
|
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
|
||||||
|
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
|
||||||
|
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
|
||||||
|
- 其他:确实无法归入以上任何类别的内容
|
||||||
|
|
||||||
|
标签: {tag}
|
||||||
|
|
||||||
|
分析步骤:
|
||||||
|
1. 仔细理解标签的核心含义和使用场景
|
||||||
|
2. 对比各个领域的关键词,找到最匹配的领域
|
||||||
|
3. 特别注意:
|
||||||
|
- 历史、科学、文化等知识性内容应归类为"学习"
|
||||||
|
- 学校、课程、考试等正式教育场景应归类为"教育"
|
||||||
|
- 只有在标签完全不属于上述9个具体领域时,才选择"其他"
|
||||||
|
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
|
||||||
|
|
||||||
|
请直接返回最合适的领域名称。"""
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景,优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
]
|
||||||
|
classification = await self.llm_client.response_structured(
|
||||||
|
messages=messages,
|
||||||
|
response_model=TagClassification,
|
||||||
|
)
|
||||||
|
if classification and hasattr(classification, 'domain') and classification.domain:
|
||||||
|
domain_counts[classification.domain] += 1
|
||||||
|
|
||||||
|
total_tags = sum(domain_counts.values())
|
||||||
|
if total_tags == 0:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
domain_distribution = {
|
||||||
|
domain: count / total_tags for domain, count in domain_counts.items()
|
||||||
|
}
|
||||||
|
return dict(sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True))
|
||||||
|
|
||||||
|
async def get_active_periods(self) -> list[int]:
|
||||||
|
"""
|
||||||
|
Identify the top 2 most active months for the user.
|
||||||
|
Only returns months if there is valid and diverse time data.
|
||||||
|
"""
|
||||||
|
query = """
|
||||||
|
MATCH (d:Dialogue)
|
||||||
|
WHERE d.group_id = $group_id AND d.created_at IS NOT NULL AND d.created_at <> ''
|
||||||
|
RETURN d.created_at AS creation_time
|
||||||
|
"""
|
||||||
|
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
return []
|
||||||
|
|
||||||
|
month_counts = Counter()
|
||||||
|
valid_dates_count = 0
|
||||||
|
for record in records:
|
||||||
|
creation_time_str = record.get("creation_time")
|
||||||
|
if not creation_time_str:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
|
||||||
|
month_counts[dt_object.month] += 1
|
||||||
|
valid_dates_count += 1
|
||||||
|
except (ValueError, TypeError, AttributeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not month_counts or valid_dates_count == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Check if time distribution is too concentrated (likely batch imported data)
|
||||||
|
unique_months = len(month_counts)
|
||||||
|
if unique_months <= 2:
|
||||||
|
most_common_count = month_counts.most_common(1)[0][1]
|
||||||
|
if most_common_count / valid_dates_count > 0.8:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if unique_months >= 3:
|
||||||
|
most_common_months = month_counts.most_common(2)
|
||||||
|
return [month for month, _ in most_common_months]
|
||||||
|
|
||||||
|
if unique_months == 2:
|
||||||
|
counts = list(month_counts.values())
|
||||||
|
ratio = min(counts) / max(counts)
|
||||||
|
if ratio > 0.3:
|
||||||
|
most_common_months = month_counts.most_common(2)
|
||||||
|
return [month for month, _ in most_common_months]
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_social_connections(self) -> dict | None:
|
||||||
|
"""Find the user with whom the most memories are shared."""
|
||||||
|
query = """
|
||||||
|
MATCH (c1:Chunk {group_id: $group_id})
|
||||||
|
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
|
||||||
|
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
|
||||||
|
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL
|
||||||
|
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
|
||||||
|
WHERE common_statements > 0
|
||||||
|
RETURN other_user_id, common_statements
|
||||||
|
ORDER BY common_statements DESC
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
|
||||||
|
if not records or not records[0].get("other_user_id"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
most_connected_user = records[0]["other_user_id"]
|
||||||
|
common_memories_count = records[0]["common_statements"]
|
||||||
|
|
||||||
|
time_range_query = """
|
||||||
|
MATCH (c:Chunk)
|
||||||
|
WHERE c.group_id IN [$user_id, $other_user_id]
|
||||||
|
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
|
||||||
|
"""
|
||||||
|
time_records = await self.neo4j_connector.execute_query(
|
||||||
|
time_range_query,
|
||||||
|
user_id=self.user_id,
|
||||||
|
other_user_id=most_connected_user
|
||||||
|
)
|
||||||
|
start_year, end_year = "N/A", "N/A"
|
||||||
|
if time_records and time_records[0]["start_time"]:
|
||||||
|
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
|
||||||
|
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user_id": most_connected_user,
|
||||||
|
"common_memories_count": common_memories_count,
|
||||||
|
"time_range": f"{start_year}-{end_year}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class UserSummaryHelper:
|
||||||
|
"""
|
||||||
|
Internal helper class for user summary generation.
|
||||||
|
Provides data retrieval functionality for user summary analysis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, user_id: str):
|
||||||
|
self.user_id = user_id
|
||||||
|
self.connector = Neo4jConnector()
|
||||||
|
self.llm = _get_llm_client_for_user(user_id)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close database connection."""
|
||||||
|
await self.connector.close()
|
||||||
|
|
||||||
|
async def get_recent_statements(self, limit: int = 80) -> List[Dict[str, Any]]:
|
||||||
|
"""Fetch recent statements authored by the user/group for context."""
|
||||||
|
query = (
|
||||||
|
"MATCH (s:Statement) "
|
||||||
|
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
|
||||||
|
"RETURN s.statement AS statement, s.created_at AS created_at "
|
||||||
|
"ORDER BY created_at DESC LIMIT $limit"
|
||||||
|
)
|
||||||
|
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
|
||||||
|
records = []
|
||||||
|
for r in rows:
|
||||||
|
try:
|
||||||
|
records.append({
|
||||||
|
"statement": r.get("statement", ""),
|
||||||
|
"created_at": r.get("created_at")
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return records
|
||||||
|
|
||||||
|
async def get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
|
||||||
|
"""Get meaningful entities and their frequencies using hot tag logic."""
|
||||||
|
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||||
|
return await get_hot_memory_tags(self.user_id, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Service Class
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Service Class
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
class UserMemoryService:
|
class UserMemoryService:
|
||||||
@@ -601,7 +862,7 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None) ->
|
|||||||
生成记忆洞察报告(四个维度)
|
生成记忆洞察报告(四个维度)
|
||||||
|
|
||||||
这个函数包含完整的业务逻辑:
|
这个函数包含完整的业务逻辑:
|
||||||
1. 使用 MemoryInsight 工具类获取基础数据(领域分布、活跃时段、社交关联)
|
1. 使用 MemoryInsightHelper 工具类获取基础数据(领域分布、活跃时段、社交关联)
|
||||||
2. 使用 Jinja2 模板渲染提示词
|
2. 使用 Jinja2 模板渲染提示词
|
||||||
3. 调用 LLM 生成四个维度的自然语言报告
|
3. 调用 LLM 生成四个维度的自然语言报告
|
||||||
4. 解析并返回四个部分
|
4. 解析并返回四个部分
|
||||||
@@ -620,7 +881,7 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None) ->
|
|||||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
|
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
|
||||||
import re
|
import re
|
||||||
|
|
||||||
insight = MemoryInsight(end_user_id)
|
insight = MemoryInsightHelper(end_user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 并行获取三个维度的数据
|
# 1. 并行获取三个维度的数据
|
||||||
@@ -722,7 +983,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
|
|||||||
生成用户摘要(包含四个部分)
|
生成用户摘要(包含四个部分)
|
||||||
|
|
||||||
这个函数包含完整的业务逻辑:
|
这个函数包含完整的业务逻辑:
|
||||||
1. 使用 UserSummary 工具类获取基础数据(实体、语句)
|
1. 使用 UserSummaryHelper 工具类获取基础数据(实体、语句)
|
||||||
2. 使用 prompt_utils 渲染提示词
|
2. 使用 prompt_utils 渲染提示词
|
||||||
3. 调用 LLM 生成四部分内容:基本介绍、性格特点、核心价值观、一句话总结
|
3. 调用 LLM 生成四部分内容:基本介绍、性格特点、核心价值观、一句话总结
|
||||||
|
|
||||||
@@ -737,20 +998,19 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
|
|||||||
"one_sentence": str
|
"one_sentence": str
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
from app.core.memory.analytics.user_summary import UserSummary
|
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
||||||
import re
|
import re
|
||||||
|
|
||||||
# 创建 UserSummary 实例
|
# 创建 UserSummaryHelper 实例
|
||||||
user_summary_tool = UserSummary(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123"))
|
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123"))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1) 收集上下文数据
|
# 1) 收集上下文数据
|
||||||
entities = await user_summary_tool._get_top_entities(limit=40)
|
entities = await user_summary_tool.get_top_entities(limit=40)
|
||||||
statements = await user_summary_tool._get_recent_statements(limit=100)
|
statements = await user_summary_tool.get_recent_statements(limit=100)
|
||||||
|
|
||||||
entity_lines = [f"{name} ({freq})" for name, freq in entities][:20]
|
entity_lines = [f"{name} ({freq})" for name, freq in entities][:20]
|
||||||
statement_samples = [s.statement.strip() for s in statements if (s.statement or '').strip()][:20]
|
statement_samples = [s["statement"].strip() for s in statements if s.get("statement", "").strip()][:20]
|
||||||
|
|
||||||
# 2) 使用 prompt_utils 渲染提示词
|
# 2) 使用 prompt_utils 渲染提示词
|
||||||
user_prompt = await render_user_summary_prompt(
|
user_prompt = await render_user_summary_prompt(
|
||||||
@@ -794,6 +1054,28 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
|
|||||||
core_values = core_values_match.group(1).strip() if core_values_match else ""
|
core_values = core_values_match.group(1).strip() if core_values_match else ""
|
||||||
one_sentence = one_sentence_match.group(1).strip() if one_sentence_match else ""
|
one_sentence = one_sentence_match.group(1).strip() if one_sentence_match else ""
|
||||||
|
|
||||||
|
# 6) 清理可能包含的反思内容(防御性编程)
|
||||||
|
# 如果 LLM 仍然输出了反思内容,在这里过滤掉
|
||||||
|
def clean_reflection_content(text: str) -> str:
|
||||||
|
"""移除可能包含的反思内容"""
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
# 移除 "---" 之后的所有内容(通常是反思部分的开始)
|
||||||
|
if '---' in text:
|
||||||
|
text = text.split('---')[0].strip()
|
||||||
|
# 移除 "**Step" 开头的内容
|
||||||
|
if '**Step' in text:
|
||||||
|
text = text.split('**Step')[0].strip()
|
||||||
|
# 移除 "Self-Review" 相关内容
|
||||||
|
if 'Self-Review' in text or 'self-review' in text:
|
||||||
|
text = re.sub(r'[\-\*]*\s*Self-Review.*$', '', text, flags=re.IGNORECASE | re.DOTALL).strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
user_summary = clean_reflection_content(user_summary)
|
||||||
|
personality = clean_reflection_content(personality)
|
||||||
|
core_values = clean_reflection_content(core_values)
|
||||||
|
one_sentence = clean_reflection_content(one_sentence)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"user_summary": user_summary,
|
"user_summary": user_summary,
|
||||||
"personality": personality,
|
"personality": personality,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from app.core.workflow.validator import validate_workflow_config
|
|||||||
from app.db import get_db, get_db_context
|
from app.db import get_db, get_db_context
|
||||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
from app.services.multi_agent_service import convert_uuids_to_str
|
||||||
from app.repositories.workflow_repository import (
|
from app.repositories.workflow_repository import (
|
||||||
WorkflowConfigRepository,
|
WorkflowConfigRepository,
|
||||||
WorkflowExecutionRepository,
|
WorkflowExecutionRepository,
|
||||||
@@ -364,7 +365,7 @@ class WorkflowService:
|
|||||||
|
|
||||||
execution.status = status
|
execution.status = status
|
||||||
if output_data is not None:
|
if output_data is not None:
|
||||||
execution.output_data = output_data
|
execution.output_data = convert_uuids_to_str(output_data)
|
||||||
if error_message is not None:
|
if error_message is not None:
|
||||||
execution.error_message = error_message
|
execution.error_message = error_message
|
||||||
if error_node_id is not None:
|
if error_node_id is not None:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import uuid
|
|||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.models import AppRelease
|
from app.models import AppRelease, WorkflowConfig
|
||||||
from app.models.agent_app_config_model import AgentConfig
|
from app.models.agent_app_config_model import AgentConfig
|
||||||
from app.models.multi_agent_model import MultiAgentConfig
|
from app.models.multi_agent_model import MultiAgentConfig
|
||||||
|
|
||||||
@@ -28,7 +28,7 @@ class AgentConfigProxy:
|
|||||||
def agent_config_4_app_release(release: AppRelease ) -> AgentConfig:
|
def agent_config_4_app_release(release: AppRelease ) -> AgentConfig:
|
||||||
|
|
||||||
config_dict = release.config
|
config_dict = release.config
|
||||||
|
|
||||||
agent_config = AgentConfig(
|
agent_config = AgentConfig(
|
||||||
app_id=release.app_id,
|
app_id=release.app_id,
|
||||||
system_prompt=config_dict.get("system_prompt"),
|
system_prompt=config_dict.get("system_prompt"),
|
||||||
@@ -45,10 +45,10 @@ def agent_config_4_app_release(release: AppRelease ) -> AgentConfig:
|
|||||||
def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig:
|
def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig:
|
||||||
|
|
||||||
config_dict = release.config
|
config_dict = release.config
|
||||||
|
|
||||||
|
|
||||||
agent_config = MultiAgentConfig(
|
agent_config = MultiAgentConfig(
|
||||||
app_id=release.app_id,
|
app_id=release.app_id,
|
||||||
default_model_config_id=release.default_model_config_id,
|
default_model_config_id=release.default_model_config_id,
|
||||||
model_parameters=config_dict.get("model_parameters"),
|
model_parameters=config_dict.get("model_parameters"),
|
||||||
master_agent_id=config_dict.get("master_agent_id"),
|
master_agent_id=config_dict.get("master_agent_id"),
|
||||||
@@ -58,11 +58,29 @@ def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig:
|
|||||||
routing_rules=config_dict.get("routing_rules"),
|
routing_rules=config_dict.get("routing_rules"),
|
||||||
execution_config=config_dict.get("execution_config", {}),
|
execution_config=config_dict.get("execution_config", {}),
|
||||||
aggregation_strategy=config_dict.get("aggregation_strategy", "merge"),
|
aggregation_strategy=config_dict.get("aggregation_strategy", "merge"),
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return agent_config
|
return agent_config
|
||||||
|
|
||||||
|
def workflow_config_4_app_release(release: AppRelease ) -> WorkflowConfig:
|
||||||
|
|
||||||
|
config_dict = release.config
|
||||||
|
|
||||||
|
|
||||||
|
config = WorkflowConfig(
|
||||||
|
id=release.id,
|
||||||
|
app_id=release.app_id,
|
||||||
|
nodes=config_dict.get("nodes", []),
|
||||||
|
edges=config_dict.get("edges", []),
|
||||||
|
variables=config_dict.get("variables", []),
|
||||||
|
execution_config=config_dict.get("execution_config", {}),
|
||||||
|
triggers=config_dict.get("triggers", [])
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uuid.UUID] = None):
|
def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uuid.UUID] = None):
|
||||||
"""Convert dict to MultiAgentConfig model object
|
"""Convert dict to MultiAgentConfig model object
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import { request } from '@/utils/request'
|
import { request } from '@/utils/request'
|
||||||
import type { AiPromptForm } from '@/views/ApplicationConfig/types'
|
import type { AiPromptForm } from '@/views/ApplicationConfig/types'
|
||||||
|
import { handleSSE, type SSEMessage } from '@/utils/stream'
|
||||||
|
|
||||||
export const createPromptSessions = () => {
|
export const createPromptSessions = () => {
|
||||||
return request.post(`/prompt/sessions`)
|
return request.post(`/prompt/sessions`)
|
||||||
@@ -7,6 +8,6 @@ export const createPromptSessions = () => {
|
|||||||
export const getPrompt = (session_id: string) => {
|
export const getPrompt = (session_id: string) => {
|
||||||
return request.get(`/prompt/sessions/${session_id}`)
|
return request.get(`/prompt/sessions/${session_id}`)
|
||||||
}
|
}
|
||||||
export const updatePromptMessages = (session_id: string, data: AiPromptForm) => {
|
export const updatePromptMessages = (session_id: string, data: AiPromptForm, onMessage?: (data: SSEMessage[]) => void) => {
|
||||||
return request.post(`/prompt/sessions/${session_id}/messages`, data)
|
return handleSSE(`/prompt/sessions/${session_id}/messages`, data, onMessage)
|
||||||
}
|
}
|
||||||
BIN
web/src/assets/images/workflow/memory-read.png
Normal file
BIN
web/src/assets/images/workflow/memory-read.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 936 B |
BIN
web/src/assets/images/workflow/memory-write.png
Normal file
BIN
web/src/assets/images/workflow/memory-write.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 568 B |
@@ -1224,6 +1224,8 @@ export const en = {
|
|||||||
key_findings: 'Key Findings',
|
key_findings: 'Key Findings',
|
||||||
behavior_pattern: 'Behavior Pattern',
|
behavior_pattern: 'Behavior Pattern',
|
||||||
growth_trajectory: 'Growth Trajectory',
|
growth_trajectory: 'Growth Trajectory',
|
||||||
|
personality: 'Personality Traits',
|
||||||
|
core_values: 'Core Values',
|
||||||
},
|
},
|
||||||
space: {
|
space: {
|
||||||
createSpace: 'Create Space',
|
createSpace: 'Create Space',
|
||||||
@@ -1799,12 +1801,20 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
|||||||
"not_contains": 'Does Not Contain',
|
"not_contains": 'Does Not Contain',
|
||||||
"startwith": 'Starts With',
|
"startwith": 'Starts With',
|
||||||
"endwith": 'Ends With',
|
"endwith": 'Ends With',
|
||||||
"eq": '==',
|
"eq": 'Equals',
|
||||||
"ne": '!=',
|
"ne": 'Not Equals',
|
||||||
"lt": '<',
|
num: {
|
||||||
"le": '<=',
|
"eq": '=',
|
||||||
"gt": '>',
|
"ne": '≠',
|
||||||
"ge": '>=',
|
"lt": '<',
|
||||||
|
"le": '≤',
|
||||||
|
"gt": '>',
|
||||||
|
"ge": '≥',
|
||||||
|
},
|
||||||
|
boolean: {
|
||||||
|
"eq": 'Is',
|
||||||
|
"ne": 'Is Not',
|
||||||
|
},
|
||||||
else_desc: 'Used to define the logic that should be executed when the if condition is not met.'
|
else_desc: 'Used to define the logic that should be executed when the if condition is not met.'
|
||||||
},
|
},
|
||||||
'http-request': {
|
'http-request': {
|
||||||
@@ -1845,12 +1855,17 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
|||||||
loop: {
|
loop: {
|
||||||
cycle_vars: 'Loop Variables',
|
cycle_vars: 'Loop Variables',
|
||||||
condition: 'Loop Termination Condition',
|
condition: 'Loop Termination Condition',
|
||||||
|
max_loop: 'Maximum Loop Count',
|
||||||
},
|
},
|
||||||
assigner: {
|
assigner: {
|
||||||
assignments: 'Variables',
|
assignments: 'Variables',
|
||||||
cover: 'Overwrite',
|
cover: 'Override',
|
||||||
assign: 'Set',
|
assign: 'Set',
|
||||||
clear: 'Clear'
|
clear: 'Clear',
|
||||||
|
add: '+=',
|
||||||
|
subtract: '-=',
|
||||||
|
multiply: '*=',
|
||||||
|
divide: '/=',
|
||||||
},
|
},
|
||||||
iteration: {
|
iteration: {
|
||||||
input: 'Input Variable',
|
input: 'Input Variable',
|
||||||
|
|||||||
@@ -1305,6 +1305,8 @@ export const zh = {
|
|||||||
key_findings: '关键发现',
|
key_findings: '关键发现',
|
||||||
behavior_pattern: '行为模式',
|
behavior_pattern: '行为模式',
|
||||||
growth_trajectory: '成长轨迹',
|
growth_trajectory: '成长轨迹',
|
||||||
|
personality: '性格特点',
|
||||||
|
core_values: '核心价值观',
|
||||||
},
|
},
|
||||||
space: {
|
space: {
|
||||||
createSpace: '创建空间',
|
createSpace: '创建空间',
|
||||||
@@ -1899,12 +1901,20 @@ export const zh = {
|
|||||||
"not_contains": '不包含',
|
"not_contains": '不包含',
|
||||||
"startwith": '开始是',
|
"startwith": '开始是',
|
||||||
"endwith": '结束是',
|
"endwith": '结束是',
|
||||||
"eq": '==',
|
"eq": '是',
|
||||||
"ne": '!=',
|
"ne": '不是',
|
||||||
"lt": '<',
|
num: {
|
||||||
"le": '<=',
|
"eq": '=',
|
||||||
"gt": '>',
|
"ne": '≠',
|
||||||
"ge": '>=',
|
"lt": '<',
|
||||||
|
"le": '≤',
|
||||||
|
"gt": '>',
|
||||||
|
"ge": '≥',
|
||||||
|
},
|
||||||
|
boolean: {
|
||||||
|
"eq": '是',
|
||||||
|
"ne": '不是',
|
||||||
|
},
|
||||||
else_desc: '用于定义当 if 条件不满足时应执行的逻辑。'
|
else_desc: '用于定义当 if 条件不满足时应执行的逻辑。'
|
||||||
},
|
},
|
||||||
'http-request': {
|
'http-request': {
|
||||||
@@ -1945,12 +1955,17 @@ export const zh = {
|
|||||||
loop: {
|
loop: {
|
||||||
cycle_vars: '循环变量',
|
cycle_vars: '循环变量',
|
||||||
condition: '循环终止条件',
|
condition: '循环终止条件',
|
||||||
|
max_loop: '最大循环次数',
|
||||||
},
|
},
|
||||||
assigner: {
|
assigner: {
|
||||||
assignments: '变量',
|
assignments: '变量',
|
||||||
cover: '覆盖',
|
cover: '覆盖',
|
||||||
assign: '设置',
|
assign: '设置',
|
||||||
clear: '清空'
|
clear: '清空',
|
||||||
|
add: '+=',
|
||||||
|
subtract: '-=',
|
||||||
|
multiply: '*=',
|
||||||
|
divide: '/=',
|
||||||
},
|
},
|
||||||
iteration: {
|
iteration: {
|
||||||
input: '输入变量',
|
input: '输入变量',
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ import ConversationEmptyIcon from '@/assets/images/conversation/conversationEmpt
|
|||||||
import type { ChatItem } from '@/components/Chat/types'
|
import type { ChatItem } from '@/components/Chat/types'
|
||||||
import CustomSelect from '@/components/CustomSelect'
|
import CustomSelect from '@/components/CustomSelect'
|
||||||
import AiPromptVariableModal from './AiPromptVariableModal'
|
import AiPromptVariableModal from './AiPromptVariableModal'
|
||||||
|
import { type SSEMessage } from '@/utils/stream'
|
||||||
|
import Editor from './Editor'
|
||||||
|
|
||||||
interface AiPromptModalProps {
|
interface AiPromptModalProps {
|
||||||
refresh: (value: string) => void;
|
refresh: (value: string) => void;
|
||||||
@@ -35,7 +37,8 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
|
|||||||
const [variables, setVariables] = useState<string[]>([])
|
const [variables, setVariables] = useState<string[]>([])
|
||||||
const [promptSession, setPromptSession] = useState<string | null>(null)
|
const [promptSession, setPromptSession] = useState<string | null>(null)
|
||||||
const aiPromptVariableModalRef = useRef<AiPromptVariableModalRef>(null)
|
const aiPromptVariableModalRef = useRef<AiPromptVariableModalRef>(null)
|
||||||
const currentPromptRef = useRef<any>(null)
|
const editorRef = useRef<any>(null)
|
||||||
|
const currentPromptValueRef = useRef<string>('')
|
||||||
|
|
||||||
const values = Form.useWatch([], form)
|
const values = Form.useWatch([], form)
|
||||||
|
|
||||||
@@ -78,16 +81,45 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
|
|||||||
setChatList(prev => {
|
setChatList(prev => {
|
||||||
return [...prev, { role: 'user', content: messageContent}]
|
return [...prev, { role: 'user', content: messageContent}]
|
||||||
})
|
})
|
||||||
form.setFieldsValue({ message: undefined })
|
form.setFieldsValue({ message: undefined, current_prompt: undefined })
|
||||||
updatePromptMessages(promptSession, values)
|
|
||||||
.then(res => {
|
const handleStreamMessage = (data: SSEMessage[]) => {
|
||||||
const response = res as { prompt: string; desc: string; variables: string[] }
|
data.map(item => {
|
||||||
form.setFieldsValue({ current_prompt: response.prompt })
|
const { content, desc, variables } = item.data as { content: string; desc: string; variables: string[] };
|
||||||
setChatList(prev => {
|
|
||||||
return [...prev, { role: 'assistant', content: response.desc }]
|
switch (item.event) {
|
||||||
})
|
case 'start':
|
||||||
setVariables(response.variables)
|
currentPromptValueRef.current = ''
|
||||||
|
break;
|
||||||
|
case 'message':
|
||||||
|
if (content) {
|
||||||
|
currentPromptValueRef.current += content;
|
||||||
|
form.setFieldsValue({ current_prompt: currentPromptValueRef.current })
|
||||||
|
}
|
||||||
|
if (desc) {
|
||||||
|
setChatList(prev => {
|
||||||
|
return [...prev, { role: 'assistant', content: desc }]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if (variables) {
|
||||||
|
setVariables(variables)
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 'end':
|
||||||
|
setLoading(false)
|
||||||
|
break
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
};
|
||||||
|
updatePromptMessages(promptSession, values, handleStreamMessage)
|
||||||
|
// .then(res => {
|
||||||
|
// const response = res as { prompt: string; desc: string; variables: string[] }
|
||||||
|
// form.setFieldsValue({ current_prompt: response.prompt })
|
||||||
|
// setChatList(prev => {
|
||||||
|
// return [...prev, { role: 'assistant', content: response.desc }]
|
||||||
|
// })
|
||||||
|
// setVariables(response.variables)
|
||||||
|
// })
|
||||||
.finally(() => {
|
.finally(() => {
|
||||||
setLoading(false)
|
setLoading(false)
|
||||||
})
|
})
|
||||||
@@ -101,18 +133,8 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
|
|||||||
aiPromptVariableModalRef.current?.handleOpen()
|
aiPromptVariableModalRef.current?.handleOpen()
|
||||||
}
|
}
|
||||||
const handleVariableApply = (value: string) => {
|
const handleVariableApply = (value: string) => {
|
||||||
const textArea = currentPromptRef.current?.resizableTextArea?.textArea
|
if (editorRef.current?.insertText) {
|
||||||
if (textArea) {
|
editorRef.current.insertText(value)
|
||||||
const cursorPosition = textArea.selectionStart
|
|
||||||
const currentValue = values.current_prompt || ''
|
|
||||||
const newValue = currentValue.slice(0, cursorPosition) + value + currentValue.slice(cursorPosition)
|
|
||||||
form.setFieldValue('current_prompt', newValue)
|
|
||||||
|
|
||||||
// 设置新的光标位置
|
|
||||||
setTimeout(() => {
|
|
||||||
textArea.focus()
|
|
||||||
textArea.setSelectionRange(cursorPosition + value.length, cursorPosition + value.length)
|
|
||||||
}, 0)
|
|
||||||
} else {
|
} else {
|
||||||
form.setFieldValue('current_prompt', (values.current_prompt || '') + value)
|
form.setFieldValue('current_prompt', (values.current_prompt || '') + value)
|
||||||
}
|
}
|
||||||
@@ -191,7 +213,11 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
|
|||||||
</Col>
|
</Col>
|
||||||
</Row>
|
</Row>
|
||||||
<Form.Item name="current_prompt">
|
<Form.Item name="current_prompt">
|
||||||
<Input.TextArea ref={currentPromptRef} className="rb:bg-[#FBFDFF]! rb:h-100.5!" />
|
<Editor
|
||||||
|
ref={editorRef}
|
||||||
|
className="rb:h-100.5 "
|
||||||
|
onChange={(value) => form.setFieldValue('current_prompt', value)}
|
||||||
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<div className="rb:grid rb:grid-cols-2 rb:gap-4 rb:mt-6">
|
<div className="rb:grid rb:grid-cols-2 rb:gap-4 rb:mt-6">
|
||||||
<Button block disabled={!values?.current_prompt} onClick={handleCopy}>{t('common.copy')}</Button>
|
<Button block disabled={!values?.current_prompt} onClick={handleCopy}>{t('common.copy')}</Button>
|
||||||
|
|||||||
91
web/src/views/ApplicationConfig/components/Editor/index.tsx
Normal file
91
web/src/views/ApplicationConfig/components/Editor/index.tsx
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import {forwardRef, useImperativeHandle } from 'react';
|
||||||
|
import clsx from 'clsx';
|
||||||
|
import { LexicalComposer } from '@lexical/react/LexicalComposer';
|
||||||
|
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin';
|
||||||
|
import { ContentEditable } from '@lexical/react/LexicalContentEditable';
|
||||||
|
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary';
|
||||||
|
import { $getSelection } from 'lexical';
|
||||||
|
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||||
|
import InitialValuePlugin from './plugin/InitialValuePlugin'
|
||||||
|
import LineBreakPlugin from './plugin/LineBreakPlugin';
|
||||||
|
import InsertTextPlugin from './plugin/InsertTextPlugin';
|
||||||
|
|
||||||
|
export interface EditorRef {
|
||||||
|
insertText: (text: string) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface LexicalEditorProps {
|
||||||
|
className?: string;
|
||||||
|
placeholder?: string;
|
||||||
|
value?: string;
|
||||||
|
onChange?: (value: string) => void;
|
||||||
|
height?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
const theme = {
|
||||||
|
paragraph: 'editor-paragraph',
|
||||||
|
text: {
|
||||||
|
bold: 'editor-text-bold',
|
||||||
|
italic: 'editor-text-italic',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const EditorContent = forwardRef<EditorRef, LexicalEditorProps>(({
|
||||||
|
className = '',
|
||||||
|
value,
|
||||||
|
placeholder = "请输入内容...",
|
||||||
|
onChange,
|
||||||
|
}, ref) => {
|
||||||
|
const [editor] = useLexicalComposerContext();
|
||||||
|
|
||||||
|
useImperativeHandle(ref, () => ({
|
||||||
|
insertText: (text: string) => {
|
||||||
|
editor.update(() => {
|
||||||
|
const selection = $getSelection();
|
||||||
|
if (selection) {
|
||||||
|
selection.insertText(text);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}), [editor]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div style={{ position: 'relative' }}>
|
||||||
|
<RichTextPlugin
|
||||||
|
contentEditable={
|
||||||
|
<ContentEditable
|
||||||
|
className={clsx("rb:outline-none rb:resize-none rb:text-[14px] rb:leading-5 rb:px-4 rb:py-5 rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:overflow-auto", className)}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
placeholder={
|
||||||
|
<div className="rb:absolute rb:px-4 rb:py-5 rb:text-[14px] rb:text-[#5B6167] rb:leading-5 rb:pointer-none">
|
||||||
|
{placeholder}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
ErrorBoundary={LexicalErrorBoundary}
|
||||||
|
/>
|
||||||
|
<LineBreakPlugin onChange={onChange} />
|
||||||
|
<InitialValuePlugin value={value} />
|
||||||
|
<InsertTextPlugin />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
const Editor = forwardRef<EditorRef, LexicalEditorProps>((props, ref) => {
|
||||||
|
const initialConfig = {
|
||||||
|
namespace: 'Editor',
|
||||||
|
theme,
|
||||||
|
nodes: [],
|
||||||
|
onError: (error: Error) => {
|
||||||
|
console.error(error);
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<LexicalComposer initialConfig={initialConfig}>
|
||||||
|
<EditorContent {...props} ref={ref} />
|
||||||
|
</LexicalComposer>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
export default Editor;
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
import { type FC, useEffect } from 'react';
|
||||||
|
import { $getRoot, $createParagraphNode, $createTextNode } from 'lexical';
|
||||||
|
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||||
|
|
||||||
|
// 设置初始值的插件
|
||||||
|
const InitialValuePlugin: FC<{ value?: string }> = ({ value }) => {
|
||||||
|
const [editor] = useLexicalComposerContext();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (value) {
|
||||||
|
editor.update(() => {
|
||||||
|
const root = $getRoot();
|
||||||
|
root.clear();
|
||||||
|
const paragraph = $createParagraphNode();
|
||||||
|
const textNode = $createTextNode(value);
|
||||||
|
paragraph.append(textNode);
|
||||||
|
root.append(paragraph);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [editor, value]);
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default InitialValuePlugin
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
import { forwardRef, useImperativeHandle } from 'react';
|
||||||
|
import { $getSelection } from 'lexical';
|
||||||
|
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||||
|
import type { EditorRef } from '../index'
|
||||||
|
|
||||||
|
// 插入文本的插件
|
||||||
|
const InsertTextPlugin = forwardRef<EditorRef>((_, ref) => {
|
||||||
|
const [editor] = useLexicalComposerContext();
|
||||||
|
|
||||||
|
useImperativeHandle(ref, () => ({
|
||||||
|
insertText: (text: string) => {
|
||||||
|
editor.update(() => {
|
||||||
|
const selection = $getSelection();
|
||||||
|
if (selection) {
|
||||||
|
selection.insertText(text);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}), [editor]);
|
||||||
|
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
|
||||||
|
export default InsertTextPlugin;
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
import { type FC, useEffect } from 'react';
|
||||||
|
import { $getRoot } from 'lexical';
|
||||||
|
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||||
|
|
||||||
|
// 处理换行的插件
|
||||||
|
const LineBreakPlugin: FC<{ onChange?: (value: string) => void }> = ({ onChange }) => {
|
||||||
|
const [editor] = useLexicalComposerContext();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
return editor.registerUpdateListener(({ editorState }) => {
|
||||||
|
editorState.read(() => {
|
||||||
|
const root = $getRoot();
|
||||||
|
const textContent = root.getTextContent();
|
||||||
|
// 将\n转换为实际换行
|
||||||
|
const processedContent = textContent.replace(/\\n/g, '\n');
|
||||||
|
onChange?.(processedContent);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}, [editor, onChange]);
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default LineBreakPlugin;
|
||||||
@@ -4,10 +4,9 @@ import {
|
|||||||
Col,
|
Col,
|
||||||
Tag,
|
Tag,
|
||||||
List,
|
List,
|
||||||
Space
|
Flex
|
||||||
} from 'antd';
|
} from 'antd';
|
||||||
import { EyeOutlined } from '@ant-design/icons';
|
import { EyeOutlined } from '@ant-design/icons';
|
||||||
import clsx from 'clsx'
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import dayjs, { type Dayjs } from 'dayjs'
|
import dayjs, { type Dayjs } from 'dayjs'
|
||||||
|
|
||||||
@@ -103,9 +102,9 @@ const Inner: React.FC<{ getStatusTag: (status: string) => ReactNode }> = ({ getS
|
|||||||
<div className="rb:h-full rb:flex rb:flex-col rb:justify-between">
|
<div className="rb:h-full rb:flex rb:flex-col rb:justify-between">
|
||||||
<div className="rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167]">
|
<div className="rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167]">
|
||||||
{t(`tool.${item.config_data.tool_class}_features`)} <br />
|
{t(`tool.${item.config_data.tool_class}_features`)} <br />
|
||||||
<Space size={4} className="rb:mt-2">
|
<Flex gap={4} wrap className="rb:mt-2 rb:w-full">
|
||||||
{InnerConfigData[item.config_data.tool_class].features.map(vo => <Tag key={vo} color="default">{ t(`tool.${vo}`) }</Tag>) }
|
{InnerConfigData[item.config_data.tool_class].features.map(vo => <Tag key={vo} color="default">{ t(`tool.${vo}`) }</Tag>) }
|
||||||
</Space>
|
</Flex>
|
||||||
|
|
||||||
{item.config_data.tool_class === 'DateTimeTool'
|
{item.config_data.tool_class === 'DateTimeTool'
|
||||||
? <div className="rb:mt-3 rb:bg-[#F0F3F8] rb:px-3 rb:py-2.5 rb:rounded-md">
|
? <div className="rb:mt-3 rb:bg-[#F0F3F8] rb:px-3 rb:py-2.5 rb:rounded-md">
|
||||||
|
|||||||
@@ -5,16 +5,25 @@ import { Skeleton } from 'antd';
|
|||||||
|
|
||||||
import RbCard from '@/components/RbCard/Card'
|
import RbCard from '@/components/RbCard/Card'
|
||||||
import Empty from '@/components/Empty';
|
import Empty from '@/components/Empty';
|
||||||
|
import RbAlert from '@/components/RbAlert';
|
||||||
import {
|
import {
|
||||||
getUserSummary,
|
getUserSummary,
|
||||||
} from '@/api/memory'
|
} from '@/api/memory'
|
||||||
import type { AboutMeRef } from '../types'
|
import type { AboutMeRef } from '../types'
|
||||||
|
|
||||||
|
|
||||||
|
interface Data {
|
||||||
|
user_summary: string;
|
||||||
|
personality: string;
|
||||||
|
core_values: string;
|
||||||
|
one_sentence: string;
|
||||||
|
[key: string]: string;
|
||||||
|
}
|
||||||
const AboutMe = forwardRef<AboutMeRef>((_props, ref) => {
|
const AboutMe = forwardRef<AboutMeRef>((_props, ref) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const { id } = useParams()
|
const { id } = useParams()
|
||||||
const [loading, setLoading] = useState<boolean>(false)
|
const [loading, setLoading] = useState<boolean>(false)
|
||||||
const [data, setData] = useState<string | null>(null)
|
const [data, setData] = useState<Data>({} as Data)
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!id) return
|
if (!id) return
|
||||||
@@ -27,7 +36,7 @@ const AboutMe = forwardRef<AboutMeRef>((_props, ref) => {
|
|||||||
setLoading(true)
|
setLoading(true)
|
||||||
getUserSummary(id)
|
getUserSummary(id)
|
||||||
.then((res) => {
|
.then((res) => {
|
||||||
setData((res as { summary?: string }).summary || null)
|
setData((res as Data) || null)
|
||||||
})
|
})
|
||||||
.finally(() => {
|
.finally(() => {
|
||||||
setLoading(false)
|
setLoading(false)
|
||||||
@@ -44,10 +53,29 @@ const AboutMe = forwardRef<AboutMeRef>((_props, ref) => {
|
|||||||
>
|
>
|
||||||
{loading
|
{loading
|
||||||
? <Skeleton className="rb:mt-4" />
|
? <Skeleton className="rb:mt-4" />
|
||||||
: data
|
: Object.keys(data).filter(key => data[key] !== null).length > 0
|
||||||
? <div className="rb:font-regular rb:leading-5 rb:text-[#5B6167]">
|
? <>
|
||||||
{data || '-'}
|
{data.user_summary &&
|
||||||
</div>
|
<div className="rb:font-regular rb:leading-5 rb:text-[#5B6167]">
|
||||||
|
{data.user_summary}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
{data.personality && <>
|
||||||
|
<div className="rb:pt-4 rb:font-medium rb:leading-5 rb:mb-2">{t('userMemory.personality')}</div>
|
||||||
|
<div className="rb:font-regular rb:leading-5 rb:text-[#5B6167]">
|
||||||
|
{data.personality}
|
||||||
|
</div>
|
||||||
|
</>}
|
||||||
|
{data.core_values && <>
|
||||||
|
<div className="rb:pt-4 rb:font-medium rb:leading-5 rb:mb-2">{t('userMemory.core_values')}</div>
|
||||||
|
<div className="rb:font-regular rb:leading-5 rb:text-[#5B6167]">
|
||||||
|
{data.core_values}
|
||||||
|
</div>
|
||||||
|
</>}
|
||||||
|
{data.one_sentence &&
|
||||||
|
<RbAlert className="rb:mt-4">{data.one_sentence}</RbAlert>
|
||||||
|
}
|
||||||
|
</>
|
||||||
: <Empty size={88} className="rb:mt-12 rb:mb-20.25" />
|
: <Empty size={88} className="rb:mt-12 rb:mb-20.25" />
|
||||||
}
|
}
|
||||||
</RbCard>
|
</RbCard>
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ const ChatVariableModal = forwardRef<ChatVariableModalRef, ChatVariableModalProp
|
|||||||
const [form] = Form.useForm<ChatVariable>();
|
const [form] = Form.useForm<ChatVariable>();
|
||||||
const [loading, setLoading] = useState(false)
|
const [loading, setLoading] = useState(false)
|
||||||
const [editIndex, setEditIndex] = useState<number | undefined>(undefined)
|
const [editIndex, setEditIndex] = useState<number | undefined>(undefined)
|
||||||
const typeValue = Form.useWatch('type', form);
|
|
||||||
|
|
||||||
// 封装取消方法,添加关闭弹窗逻辑
|
// 封装取消方法,添加关闭弹窗逻辑
|
||||||
const handleClose = () => {
|
const handleClose = () => {
|
||||||
|
|||||||
@@ -14,18 +14,23 @@ const CharacterCountPlugin = ({ setCount, onChange }: { setCount: (count: number
|
|||||||
let serializedContent = '';
|
let serializedContent = '';
|
||||||
|
|
||||||
// Traverse all nodes and serialize properly
|
// Traverse all nodes and serialize properly
|
||||||
|
const paragraphs: string[] = [];
|
||||||
root.getChildren().forEach(child => {
|
root.getChildren().forEach(child => {
|
||||||
if ($isParagraphNode(child)) {
|
if ($isParagraphNode(child)) {
|
||||||
|
let paragraphContent = '';
|
||||||
child.getChildren().forEach(node => {
|
child.getChildren().forEach(node => {
|
||||||
if ($isVariableNode(node)) {
|
if ($isVariableNode(node)) {
|
||||||
serializedContent += node.getTextContent();
|
paragraphContent += node.getTextContent();
|
||||||
} else {
|
} else {
|
||||||
serializedContent += node.getTextContent();
|
paragraphContent += node.getTextContent();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
paragraphs.push(paragraphContent);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
serializedContent = paragraphs.join('\n');
|
||||||
|
|
||||||
setCount(serializedContent.length);
|
setCount(serializedContent.length);
|
||||||
onChange?.(serializedContent);
|
onChange?.(serializedContent);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
|||||||
parts.forEach(part => {
|
parts.forEach(part => {
|
||||||
const match = part.match(/^\{\{([^.]+)\.([^}]+)\}\}$/);
|
const match = part.match(/^\{\{([^.]+)\.([^}]+)\}\}$/);
|
||||||
const contextMatch = part.match(/^\{\{context\}\}$/);
|
const contextMatch = part.match(/^\{\{context\}\}$/);
|
||||||
|
const conversationMatch = part.match(/^\{\{conv\.([^}]+)\}\}$/);
|
||||||
|
|
||||||
// 匹配{{context}}格式
|
// 匹配{{context}}格式
|
||||||
if (contextMatch) {
|
if (contextMatch) {
|
||||||
@@ -38,6 +39,20 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 匹配{{conv.xx}}格式
|
||||||
|
if (conversationMatch) {
|
||||||
|
const [_, variableName] = conversationMatch;
|
||||||
|
const conversationSuggestion = options.find(s =>
|
||||||
|
s.group === 'CONVERSATION' && s.label === variableName
|
||||||
|
);
|
||||||
|
if (conversationSuggestion) {
|
||||||
|
paragraph.append($createVariableNode(conversationSuggestion));
|
||||||
|
} else {
|
||||||
|
paragraph.append($createTextNode(part));
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 匹配普通变量{{nodeId.label}}格式
|
// 匹配普通变量{{nodeId.label}}格式
|
||||||
if (match) {
|
if (match) {
|
||||||
const [_, nodeId, label] = match;
|
const [_, nodeId, label] = match;
|
||||||
|
|||||||
@@ -13,13 +13,15 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
|||||||
const handleNodeSelect = (selectedNodeType: any) => {
|
const handleNodeSelect = (selectedNodeType: any) => {
|
||||||
const parentBBox = node.getBBox();
|
const parentBBox = node.getBBox();
|
||||||
const cycleId = data.cycle;
|
const cycleId = data.cycle;
|
||||||
|
|
||||||
|
const id = `${selectedNodeType.type.replace(/-/g, '_') }_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
||||||
const newNode = graph.addNode({
|
const newNode = graph.addNode({
|
||||||
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
|
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
|
||||||
x: parentBBox.x,
|
x: parentBBox.x,
|
||||||
y: parentBBox.y,
|
y: parentBBox.y,
|
||||||
|
id,
|
||||||
data: {
|
data: {
|
||||||
id: `${selectedNodeType.type}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
|
id,
|
||||||
type: selectedNodeType.type,
|
type: selectedNodeType.type,
|
||||||
icon: selectedNodeType.icon,
|
icon: selectedNodeType.icon,
|
||||||
name: t(`workflow.${selectedNodeType.type}`),
|
name: t(`workflow.${selectedNodeType.type}`),
|
||||||
|
|||||||
@@ -75,12 +75,15 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
|||||||
const parentBBox = node.getBBox();
|
const parentBBox = node.getBBox();
|
||||||
const centerX = parentBBox.x + 24; // 默认节点宽度的一半
|
const centerX = parentBBox.x + 24; // 默认节点宽度的一半
|
||||||
const centerY = parentBBox.y + 50; // 默认节点高度的一半
|
const centerY = parentBBox.y + 50; // 默认节点高度的一半
|
||||||
|
|
||||||
|
const cycleStartNodeId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
||||||
const cycleStartNode = graph.addNode({
|
const cycleStartNode = graph.addNode({
|
||||||
...graphNodeLibrary.cycleStart,
|
...graphNodeLibrary.cycleStart,
|
||||||
x: centerX,
|
x: centerX,
|
||||||
y: centerY,
|
y: centerY,
|
||||||
|
id: cycleStartNodeId,
|
||||||
data: {
|
data: {
|
||||||
|
id: cycleStartNodeId,
|
||||||
type: 'cycle-start',
|
type: 'cycle-start',
|
||||||
parentId: node.id,
|
parentId: node.id,
|
||||||
isDefault: true, // 标记为默认节点,不可删除
|
isDefault: true, // 标记为默认节点,不可删除
|
||||||
|
|||||||
@@ -43,12 +43,14 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
|||||||
const newY = sourceBBox.y;
|
const newY = sourceBBox.y;
|
||||||
|
|
||||||
// 创建新节点
|
// 创建新节点
|
||||||
|
const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
||||||
const newNode = graph.addNode({
|
const newNode = graph.addNode({
|
||||||
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
|
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
|
||||||
x: newX,
|
x: newX,
|
||||||
y: newY,
|
y: newY,
|
||||||
|
id,
|
||||||
data: {
|
data: {
|
||||||
id: `${selectedNodeType.type}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
|
id,
|
||||||
type: selectedNodeType.type,
|
type: selectedNodeType.type,
|
||||||
icon: selectedNodeType.icon,
|
icon: selectedNodeType.icon,
|
||||||
name: t(`workflow.${selectedNodeType.type}`),
|
name: t(`workflow.${selectedNodeType.type}`),
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { type FC } from 'react'
|
import { type FC } from 'react'
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { Form, Input, Button, Row, Col, Select } from 'antd'
|
import { Form, Input, Row, Col, Select, InputNumber, Radio } from 'antd'
|
||||||
import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons';
|
import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons';
|
||||||
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||||
import VariableSelect from '../VariableSelect'
|
import VariableSelect from '../VariableSelect'
|
||||||
@@ -11,6 +11,23 @@ interface AssignmentListProps {
|
|||||||
options: Suggestion[];
|
options: Suggestion[];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const operationsObj = {
|
||||||
|
number: [
|
||||||
|
{ value: 'cover', label: 'workflow.config.assigner.cover' },
|
||||||
|
{ value: 'clear', label: 'workflow.config.assigner.clear' },
|
||||||
|
{ value: 'assign', label: 'workflow.config.assigner.assign' },
|
||||||
|
{ value: 'add', label: 'workflow.config.assigner.add' },
|
||||||
|
{ value: 'subtract', label: 'workflow.config.assigner.subtract' },
|
||||||
|
{ value: 'multiply', label: 'workflow.config.assigner.multiply' },
|
||||||
|
{ value: 'divide', label: 'workflow.config.assigner.divide' },
|
||||||
|
],
|
||||||
|
default: [
|
||||||
|
{ value: 'cover', label: 'workflow.config.assigner.cover' },
|
||||||
|
{ value: 'clear', label: 'workflow.config.assigner.clear' },
|
||||||
|
{ value: 'assign', label: 'workflow.config.assigner.assign' },
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
const AssignmentList: FC<AssignmentListProps> = ({
|
const AssignmentList: FC<AssignmentListProps> = ({
|
||||||
parentName,
|
parentName,
|
||||||
options = [],
|
options = [],
|
||||||
@@ -27,6 +44,11 @@ const AssignmentList: FC<AssignmentListProps> = ({
|
|||||||
<PlusOutlined onClick={() => add({ operation: 'cover'})} />
|
<PlusOutlined onClick={() => add({ operation: 'cover'})} />
|
||||||
</div>
|
</div>
|
||||||
{fields.map(({ key, name, ...restField }) => {
|
{fields.map(({ key, name, ...restField }) => {
|
||||||
|
const variableSelector = form.getFieldValue([parentName, name, 'variable_selector']);
|
||||||
|
const selectedOption = options.find(option => `{{${option.value}}}` === variableSelector);
|
||||||
|
const dataType = selectedOption?.dataType;
|
||||||
|
const operationOptions = dataType === 'number' ? operationsObj.number : operationsObj.default;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div key={key} className="rb:mb-4">
|
<div key={key} className="rb:mb-4">
|
||||||
<Row gutter={12} className="rb:mb-2!">
|
<Row gutter={12} className="rb:mb-2!">
|
||||||
@@ -50,11 +72,10 @@ const AssignmentList: FC<AssignmentListProps> = ({
|
|||||||
noStyle
|
noStyle
|
||||||
>
|
>
|
||||||
<Select
|
<Select
|
||||||
options={[
|
options={operationOptions.map(op => ({
|
||||||
{ value: 'cover', label: t('workflow.config.assigner.cover') },
|
...op,
|
||||||
{ value: 'clear', label: t('workflow.config.assigner.clear') },
|
label: t(op.label)
|
||||||
{ value: 'assign', label: t('workflow.config.assigner.assign') },
|
}))}
|
||||||
]}
|
|
||||||
popupMatchSelectWidth={false}
|
popupMatchSelectWidth={false}
|
||||||
onChange={() => {
|
onChange={() => {
|
||||||
form.setFieldValue([parentName, name, 'value'], undefined);
|
form.setFieldValue([parentName, name, 'value'], undefined);
|
||||||
@@ -77,20 +98,31 @@ const AssignmentList: FC<AssignmentListProps> = ({
|
|||||||
{...restField}
|
{...restField}
|
||||||
name={[name, 'value']}
|
name={[name, 'value']}
|
||||||
noStyle
|
noStyle
|
||||||
rules={[{ required: true, message: 'Missing last name' }]}
|
|
||||||
>
|
>
|
||||||
{operation === 'assign' ? (
|
{operation === 'assign'
|
||||||
<Input.TextArea
|
? <>
|
||||||
placeholder={t('common.pleaseEnter')}
|
{dataType === 'number'
|
||||||
rows={3}
|
? <InputNumber
|
||||||
/>
|
placeholder={t('common.pleaseEnter')}
|
||||||
) : (
|
className="rb:w-full!"
|
||||||
<VariableSelect
|
/>
|
||||||
|
: dataType === 'boolean'
|
||||||
|
? <Radio.Group block>
|
||||||
|
<Radio.Button value={true}>True</Radio.Button>
|
||||||
|
<Radio.Button value={false}>False</Radio.Button>
|
||||||
|
</Radio.Group>
|
||||||
|
: <Input.TextArea
|
||||||
|
placeholder={t('common.pleaseEnter')}
|
||||||
|
rows={3}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
</>
|
||||||
|
: <VariableSelect
|
||||||
placeholder={t('common.pleaseSelect')}
|
placeholder={t('common.pleaseSelect')}
|
||||||
options={options}
|
options={dataType ? options.filter(vo => vo.dataType === dataType) : options}
|
||||||
popupMatchSelectWidth={false}
|
popupMatchSelectWidth={false}
|
||||||
/>
|
/>
|
||||||
)}
|
}
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
);
|
);
|
||||||
}}
|
}}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { type FC } from 'react'
|
import { type FC } from 'react'
|
||||||
import clsx from 'clsx'
|
import clsx from 'clsx'
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { Form, Button, Select, Space, Row, Col, Divider } from 'antd'
|
import { Form, Button, Select, Space, Row, Col, Divider, InputNumber, Radio, type SelectProps } from 'antd'
|
||||||
import { DeleteOutlined } from '@ant-design/icons';
|
import { DeleteOutlined } from '@ant-design/icons';
|
||||||
|
|
||||||
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||||
@@ -9,37 +9,48 @@ import VariableSelect from '../VariableSelect'
|
|||||||
import Editor from '../../Editor'
|
import Editor from '../../Editor'
|
||||||
|
|
||||||
interface CaseListProps {
|
interface CaseListProps {
|
||||||
value?: Array<{ logical_operator: 'and' | 'or'; expressions: { left: string; comparison_operator: string; right: string; }[] }>;
|
value?: Array<{ logical_operator: 'and' | 'or'; expressions: { left: string; comparison_operator: string; right: string; input_type?: string; }[] }>;
|
||||||
onChange?: (value: Array<{ logical_operator: 'and' | 'or'; expressions: { left: string; comparison_operator: string; right: string; }[] }>) => void;
|
onChange?: (value: Array<{ logical_operator: 'and' | 'or'; expressions: { left: string; comparison_operator: string; right: string; }[] }>) => void;
|
||||||
options: Suggestion[];
|
options: Suggestion[];
|
||||||
name: string;
|
name: string;
|
||||||
selectedNode?: any;
|
selectedNode?: any;
|
||||||
graphRef?: any;
|
graphRef?: any;
|
||||||
}
|
}
|
||||||
const operatorList = [
|
const operatorsObj: { [key: string]: SelectProps['options'] } = {
|
||||||
"empty",
|
default: [
|
||||||
"not_empty",
|
{ value: 'empty', label: 'workflow.config.if-else.empty' },
|
||||||
"contains",
|
{ value: 'not_empty', label: 'workflow.config.if-else.not_empty' },
|
||||||
"not_contains",
|
{ value: 'contains', label: 'workflow.config.if-else.contains' },
|
||||||
"startwith",
|
{ value: 'not_contains', label: 'workflow.config.if-else.not_contains' },
|
||||||
"endwith",
|
{ value: 'startwith', label: 'workflow.config.if-else.startwith' },
|
||||||
"eq",
|
{ value: 'endwith', label: 'workflow.config.if-else.endwith' },
|
||||||
"ne",
|
{ value: 'eq', label: 'workflow.config.if-else.eq' },
|
||||||
"lt",
|
{ value: 'ne', label: 'workflow.config.if-else.ne' },
|
||||||
"le",
|
],
|
||||||
"gt",
|
number: [
|
||||||
"ge"
|
{ value: 'eq', label: 'workflow.config.if-else.num.eq' },
|
||||||
]
|
{ value: 'ne', label: 'workflow.config.if-else.num.ne' },
|
||||||
|
{ value: 'lt', label: 'workflow.config.if-else.num.lt' },
|
||||||
|
{ value: 'le', label: 'workflow.config.if-else.num.le' },
|
||||||
|
{ value: 'gt', label: 'workflow.config.if-else.num.gt' },
|
||||||
|
{ value: 'ge', label: 'workflow.config.if-else.num.ge' },
|
||||||
|
{ value: 'empty', label: 'workflow.config.if-else.empty' },
|
||||||
|
{ value: 'not_empty', label: 'workflow.config.if-else.not_empty' },
|
||||||
|
],
|
||||||
|
boolean: [
|
||||||
|
{ value: 'eq', label: 'workflow.config.if-else.boolean.eq' },
|
||||||
|
{ value: 'ne', label: 'workflow.config.if-else.boolean.ne' },
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
const CaseList: FC<CaseListProps> = ({
|
const CaseList: FC<CaseListProps> = ({
|
||||||
value = [],
|
|
||||||
options,
|
options,
|
||||||
name,
|
name,
|
||||||
onChange,
|
|
||||||
selectedNode,
|
selectedNode,
|
||||||
graphRef
|
graphRef
|
||||||
}) => {
|
}) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const form = Form.useFormInstance();
|
||||||
|
|
||||||
const updateNodePorts = (caseCount: number, removedCaseIndex?: number) => {
|
const updateNodePorts = (caseCount: number, removedCaseIndex?: number) => {
|
||||||
if (!selectedNode || !graphRef?.current) return;
|
if (!selectedNode || !graphRef?.current) return;
|
||||||
@@ -175,29 +186,49 @@ const CaseList: FC<CaseListProps> = ({
|
|||||||
});
|
});
|
||||||
}, 50);
|
}, 50);
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleChangeLogicalOperator = (index: number) => {
|
const handleChangeLogicalOperator = (index: number) => {
|
||||||
const newValue = [...value]
|
const currentValue = form.getFieldValue([name, index, 'logical_operator']);
|
||||||
newValue[index] = {
|
form.setFieldValue([name, index, 'logical_operator'], currentValue === 'and' ? 'or' : 'and');
|
||||||
...newValue[index],
|
};
|
||||||
logical_operator: newValue[index].logical_operator === 'and' ? 'or' : 'and'
|
|
||||||
}
|
const handleLeftFieldChange = (caseIndex: number, conditionIndex: number, newValue: string) => {
|
||||||
onChange && onChange(newValue)
|
form.setFieldsValue({
|
||||||
}
|
[name]: {
|
||||||
|
[caseIndex]: {
|
||||||
|
expressions: {
|
||||||
|
[conditionIndex]: {
|
||||||
|
left: newValue,
|
||||||
|
comparison_operator: undefined,
|
||||||
|
right: undefined,
|
||||||
|
input_type: undefined
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
const handleAddCase = (addCaseFunc: Function) => {
|
const handleAddCase = (addCaseFunc: Function) => {
|
||||||
addCaseFunc({ logical_operator: 'and', expressions: [] });
|
addCaseFunc({ logical_operator: 'and', expressions: [] });
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
updateNodePorts((value?.length || 0) + 1);
|
const currentCases = form.getFieldValue(name) || [];
|
||||||
|
updateNodePorts(currentCases.length);
|
||||||
}, 100);
|
}, 100);
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleRemoveCase = (removeCaseFunc: Function, fieldName: number, caseIndex: number) => {
|
const handleRemoveCase = (removeCaseFunc: Function, fieldName: number, caseIndex: number) => {
|
||||||
removeCaseFunc(fieldName);
|
removeCaseFunc(fieldName);
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
updateNodePorts((value?.length || 1) - 1, caseIndex);
|
const currentCases = form.getFieldValue(name) || [];
|
||||||
|
updateNodePorts(currentCases.length, caseIndex);
|
||||||
}, 100);
|
}, 100);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const handleInputTypeChange = (caseIndex: number, conditionIndex: number) => {
|
||||||
|
form.setFieldValue([name, caseIndex, 'expressions', conditionIndex, 'right'], undefined);
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Form.List name={name}>
|
<Form.List name={name}>
|
||||||
@@ -218,7 +249,7 @@ const CaseList: FC<CaseListProps> = ({
|
|||||||
<Space>
|
<Space>
|
||||||
<Button
|
<Button
|
||||||
type="dashed"
|
type="dashed"
|
||||||
onClick={() => addCondition()}
|
onClick={() => addCondition({})}
|
||||||
size="small"
|
size="small"
|
||||||
>
|
>
|
||||||
+ {t('workflow.config.addCase')}
|
+ {t('workflow.config.addCase')}
|
||||||
@@ -234,15 +265,23 @@ const CaseList: FC<CaseListProps> = ({
|
|||||||
<div className="rb:absolute rb:w-3 rb:left-2 rb:top-15 rb:bottom-6 rb:z-10 rb:border rb:border-[#DFE4ED] rb:rounded-l-md rb:border-r-0"></div>
|
<div className="rb:absolute rb:w-3 rb:left-2 rb:top-15 rb:bottom-6 rb:z-10 rb:border rb:border-[#DFE4ED] rb:rounded-l-md rb:border-r-0"></div>
|
||||||
<div className="rb:absolute rb:z-10 rb:left-0 rb:top-[50%] rb:transform-[translateY(-50%)]]">
|
<div className="rb:absolute rb:z-10 rb:left-0 rb:top-[50%] rb:transform-[translateY(-50%)]]">
|
||||||
<Form.Item name={[caseField.name, 'logical_operator']} noStyle >
|
<Form.Item name={[caseField.name, 'logical_operator']} noStyle >
|
||||||
<Button size="small" className="rb:cursor-pointer" onClick={() => handleChangeLogicalOperator(caseIndex)}>{value?.[caseIndex].logical_operator}</Button>
|
<Button size="small" className="rb:cursor-pointer" onClick={() => handleChangeLogicalOperator(caseIndex)}>{logicalOperator}</Button>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
}
|
}
|
||||||
{conditionFields.map((conditionField, conditionIndex) => {
|
{conditionFields.map((conditionField, conditionIndex) => {
|
||||||
const currentOperator = value?.[caseIndex]?.expressions?.[conditionIndex]?.comparison_operator;
|
const cases = form.getFieldValue(name) || [];
|
||||||
|
const currentCase = cases[caseIndex] || {};
|
||||||
|
const currentExpression = currentCase.expressions?.[conditionIndex] || {};
|
||||||
|
const currentOperator = currentExpression.comparison_operator;
|
||||||
const hideRightField = currentOperator === 'empty' || currentOperator === 'not_empty';
|
const hideRightField = currentOperator === 'empty' || currentOperator === 'not_empty';
|
||||||
|
const leftFieldValue = currentExpression.left;
|
||||||
|
const leftFieldOption = options.find(option => `{{${option.value}}}` === leftFieldValue);
|
||||||
|
const leftFieldType = leftFieldOption?.dataType;
|
||||||
|
const operatorList = operatorsObj[leftFieldType || 'default'] || operatorsObj.default || [];
|
||||||
|
const inputType = leftFieldType === 'number' ? currentExpression.input_type : undefined;
|
||||||
|
const logicalOperator = currentCase.logical_operator;
|
||||||
return (
|
return (
|
||||||
<div key={conditionField.key} className={clsx({
|
<div key={conditionField.key} className={clsx({
|
||||||
"rb:mb-3": conditionIndex !== conditionFields.length - 1
|
"rb:mb-3": conditionIndex !== conditionFields.length - 1
|
||||||
@@ -257,18 +296,20 @@ const CaseList: FC<CaseListProps> = ({
|
|||||||
size="small"
|
size="small"
|
||||||
allowClear={false}
|
allowClear={false}
|
||||||
popupMatchSelectWidth={false}
|
popupMatchSelectWidth={false}
|
||||||
|
onChange={(val) => handleLeftFieldChange(caseIndex, conditionIndex, val)}
|
||||||
/>
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
</Col>
|
</Col>
|
||||||
<Col span={8}>
|
<Col span={8}>
|
||||||
<Form.Item name={[conditionField.name, 'comparison_operator']} noStyle>
|
<Form.Item name={[conditionField.name, 'comparison_operator']} noStyle>
|
||||||
<Select
|
<Select
|
||||||
options={operatorList.map(key => ({
|
options={operatorList.map(vo => ({
|
||||||
value: key,
|
...vo,
|
||||||
label: t(`workflow.config.if-else.${key}`)
|
label: t(String(vo?.label || ''))
|
||||||
}))}
|
}))}
|
||||||
size="small"
|
size="small"
|
||||||
popupMatchSelectWidth={false}
|
popupMatchSelectWidth={false}
|
||||||
|
placeholder={t('common.pleaseSelect')}
|
||||||
/>
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
</Col>
|
</Col>
|
||||||
@@ -280,11 +321,48 @@ const CaseList: FC<CaseListProps> = ({
|
|||||||
</Col>
|
</Col>
|
||||||
</Row>
|
</Row>
|
||||||
|
|
||||||
{!hideRightField && (
|
{!hideRightField && <>
|
||||||
<Form.Item name={[conditionField.name, 'right']} noStyle>
|
{leftFieldType === 'number'
|
||||||
<Editor options={options} />
|
? <Row>
|
||||||
</Form.Item>
|
<Col span={12}>
|
||||||
)}
|
<Form.Item name={[conditionField.name, 'input_type']} noStyle>
|
||||||
|
<Select
|
||||||
|
placeholder={t('common.pleaseSelect')}
|
||||||
|
options={[{ value: 'Variable', label: 'Variable' }, { value: 'Constant', label: 'Constant' }]}
|
||||||
|
popupMatchSelectWidth={false}
|
||||||
|
variant="borderless"
|
||||||
|
onChange={() => handleInputTypeChange(caseIndex, conditionIndex)}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={12}>
|
||||||
|
<Form.Item name={[conditionField.name, 'right']} noStyle>
|
||||||
|
{inputType === 'Variable'
|
||||||
|
?
|
||||||
|
<VariableSelect
|
||||||
|
placeholder={t('common.pleaseSelect')}
|
||||||
|
options={options.filter(vo => vo.dataType === 'number')}
|
||||||
|
allowClear={false}
|
||||||
|
popupMatchSelectWidth={false}
|
||||||
|
variant="borderless"
|
||||||
|
/>
|
||||||
|
: <InputNumber placeholder={t('common.pleaseEnter')}
|
||||||
|
variant="borderless" className="rb:w-full!" />
|
||||||
|
}
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
: <Form.Item name={[conditionField.name, 'right']} noStyle>
|
||||||
|
{leftFieldType === 'boolean'
|
||||||
|
? <Radio.Group block>
|
||||||
|
<Radio.Button value={true}>True</Radio.Button>
|
||||||
|
<Radio.Button value={false}>False</Radio.Button>
|
||||||
|
</Radio.Group>
|
||||||
|
: <Editor options={options} />
|
||||||
|
}
|
||||||
|
</Form.Item>
|
||||||
|
}
|
||||||
|
</>}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import { type FC } from 'react'
|
import { type FC } from 'react'
|
||||||
import clsx from 'clsx'
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { Form, Button, Select, Space, Row, Col, Divider } from 'antd'
|
import { Form, Button, Select, Row, Col, InputNumber, Radio, type SelectProps } from 'antd'
|
||||||
import { DeleteOutlined } from '@ant-design/icons';
|
import { DeleteOutlined } from '@ant-design/icons';
|
||||||
|
|
||||||
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||||
@@ -10,7 +9,7 @@ import Editor from '../../Editor'
|
|||||||
|
|
||||||
interface Case {
|
interface Case {
|
||||||
logical_operator: 'and' | 'or';
|
logical_operator: 'and' | 'or';
|
||||||
expressions: Array<{ left: string; comparison_operator: string; right: string; }>
|
expressions: Array<{ left: string; comparison_operator: string; right: string; input_type: string; }>
|
||||||
}
|
}
|
||||||
|
|
||||||
interface CaseListProps {
|
interface CaseListProps {
|
||||||
@@ -22,36 +21,63 @@ interface CaseListProps {
|
|||||||
graphRef?: any;
|
graphRef?: any;
|
||||||
addBtnText?: string;
|
addBtnText?: string;
|
||||||
}
|
}
|
||||||
const operatorList = [
|
const operatorsObj: { [key: string]: SelectProps['options'] } = {
|
||||||
"empty",
|
default: [
|
||||||
"not_empty",
|
{ value: 'empty', label: 'workflow.config.if-else.empty' },
|
||||||
"contains",
|
{ value: 'not_empty', label: 'workflow.config.if-else.not_empty' },
|
||||||
"not_contains",
|
{ value: 'contains', label: 'workflow.config.if-else.contains' },
|
||||||
"startwith",
|
{ value: 'not_contains', label: 'workflow.config.if-else.not_contains' },
|
||||||
"endwith",
|
{ value: 'startwith', label: 'workflow.config.if-else.startwith' },
|
||||||
"eq",
|
{ value: 'endwith', label: 'workflow.config.if-else.endwith' },
|
||||||
"ne",
|
{ value: 'eq', label: 'workflow.config.if-else.eq' },
|
||||||
"lt",
|
{ value: 'ne', label: 'workflow.config.if-else.ne' },
|
||||||
"le",
|
],
|
||||||
"gt",
|
number: [
|
||||||
"ge"
|
{ value: 'eq', label: 'workflow.config.if-else.num.eq' },
|
||||||
]
|
{ value: 'ne', label: 'workflow.config.if-else.num.ne' },
|
||||||
|
{ value: 'lt', label: 'workflow.config.if-else.num.lt' },
|
||||||
|
{ value: 'le', label: 'workflow.config.if-else.num.le' },
|
||||||
|
{ value: 'gt', label: 'workflow.config.if-else.num.gt' },
|
||||||
|
{ value: 'ge', label: 'workflow.config.if-else.num.ge' },
|
||||||
|
{ value: 'empty', label: 'workflow.config.if-else.empty' },
|
||||||
|
{ value: 'not_empty', label: 'workflow.config.if-else.not_empty' },
|
||||||
|
],
|
||||||
|
boolean: [
|
||||||
|
{ value: 'eq', label: 'workflow.config.if-else.boolean.eq' },
|
||||||
|
{ value: 'ne', label: 'workflow.config.if-else.boolean.ne' },
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
const ConditionList: FC<CaseListProps> = ({
|
const ConditionList: FC<CaseListProps> = ({
|
||||||
value,
|
|
||||||
options,
|
options,
|
||||||
parentName,
|
parentName,
|
||||||
onChange,
|
|
||||||
}) => {
|
}) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const form = Form.useFormInstance();
|
||||||
|
|
||||||
|
const handleLeftFieldChange = (index: number, newValue: string) => {
|
||||||
|
form.setFieldsValue({
|
||||||
|
[parentName]: {
|
||||||
|
expressions: {
|
||||||
|
[index]: {
|
||||||
|
left: newValue,
|
||||||
|
comparison_operator: undefined,
|
||||||
|
right: undefined,
|
||||||
|
input_type: undefined
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleInputTypeChange = (index: number) => {
|
||||||
|
form.setFieldValue([parentName, 'expressions', index, 'right'], undefined);
|
||||||
|
};
|
||||||
|
|
||||||
const handleChangeLogicalOperator = () => {
|
const handleChangeLogicalOperator = () => {
|
||||||
if (!value) return;
|
const currentValue = form.getFieldValue([parentName, 'logical_operator']);
|
||||||
onChange && onChange({
|
form.setFieldValue([parentName, 'logical_operator'], currentValue === 'and' ? 'or' : 'and');
|
||||||
logical_operator: value.logical_operator === 'and' ? 'or' : 'and',
|
};
|
||||||
expressions: value.expressions || []
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Form.List name={[parentName, 'expressions']}>
|
<Form.List name={[parentName, 'expressions']}>
|
||||||
@@ -59,8 +85,16 @@ const ConditionList: FC<CaseListProps> = ({
|
|||||||
<div>
|
<div>
|
||||||
<div className="rb:relative">
|
<div className="rb:relative">
|
||||||
{fields.map((field, index) => {
|
{fields.map((field, index) => {
|
||||||
const currentOperator = value?.expressions?.[index]?.comparison_operator;
|
const expressions = form.getFieldValue([parentName, 'expressions']) || [];
|
||||||
|
const currentExpression = expressions[index] || {};
|
||||||
|
const currentOperator = currentExpression.comparison_operator;
|
||||||
const hideRightField = currentOperator === 'empty' || currentOperator === 'not_empty';
|
const hideRightField = currentOperator === 'empty' || currentOperator === 'not_empty';
|
||||||
|
const leftFieldValue = currentExpression.left;
|
||||||
|
const leftFieldOption = options.find(option => `{{${option.value}}}` === leftFieldValue);
|
||||||
|
const leftFieldType = leftFieldOption?.dataType;
|
||||||
|
const operatorList = operatorsObj[leftFieldType || 'default'] || operatorsObj.default || [];
|
||||||
|
const inputType = leftFieldType === 'number' ? currentExpression.input_type : undefined;
|
||||||
|
const logicalOperator = form.getFieldValue([parentName, 'logical_operator']);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div key={field.key} className="rb:mb-3">
|
<div key={field.key} className="rb:mb-3">
|
||||||
@@ -68,7 +102,7 @@ const ConditionList: FC<CaseListProps> = ({
|
|||||||
<div className="rb:absolute rb:w-3 rb:left-2 rb:top-3.75 rb:bottom-3.75 rb:z-10 rb:border rb:border-[#DFE4ED] rb:rounded-l-md rb:border-r-0"></div>
|
<div className="rb:absolute rb:w-3 rb:left-2 rb:top-3.75 rb:bottom-3.75 rb:z-10 rb:border rb:border-[#DFE4ED] rb:rounded-l-md rb:border-r-0"></div>
|
||||||
<div className="rb:absolute rb:z-10 rb:left-0 rb:top-[50%] rb:transform-[translateY(-50%)]]">
|
<div className="rb:absolute rb:z-10 rb:left-0 rb:top-[50%] rb:transform-[translateY(-50%)]]">
|
||||||
<Form.Item name={[parentName, 'logical_operator']} noStyle >
|
<Form.Item name={[parentName, 'logical_operator']} noStyle >
|
||||||
<Button size="small" className="rb:cursor-pointer" onClick={handleChangeLogicalOperator}>{value?.logical_operator}</Button>
|
<Button size="small" className="rb:cursor-pointer" onClick={handleChangeLogicalOperator}>{logicalOperator}</Button>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
</div>
|
</div>
|
||||||
</>)}
|
</>)}
|
||||||
@@ -82,6 +116,7 @@ const ConditionList: FC<CaseListProps> = ({
|
|||||||
size="small"
|
size="small"
|
||||||
allowClear={false}
|
allowClear={false}
|
||||||
popupMatchSelectWidth={false}
|
popupMatchSelectWidth={false}
|
||||||
|
onChange={(val) => handleLeftFieldChange(index, val)}
|
||||||
/>
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
</Col>
|
</Col>
|
||||||
@@ -89,9 +124,9 @@ const ConditionList: FC<CaseListProps> = ({
|
|||||||
<Col span={8}>
|
<Col span={8}>
|
||||||
<Form.Item name={[field.name, 'comparison_operator']} noStyle>
|
<Form.Item name={[field.name, 'comparison_operator']} noStyle>
|
||||||
<Select
|
<Select
|
||||||
options={operatorList.map(key => ({
|
options={operatorList.map(vo => ({
|
||||||
value: key,
|
...vo,
|
||||||
label: t(`workflow.config.if-else.${key}`)
|
label: t(String(vo?.label || ''))
|
||||||
}))}
|
}))}
|
||||||
size="small"
|
size="small"
|
||||||
popupMatchSelectWidth={false}
|
popupMatchSelectWidth={false}
|
||||||
@@ -104,14 +139,53 @@ const ConditionList: FC<CaseListProps> = ({
|
|||||||
onClick={() => remove(field.name)}
|
onClick={() => remove(field.name)}
|
||||||
/>
|
/>
|
||||||
</Col>
|
</Col>
|
||||||
|
|
||||||
{!hideRightField && (
|
{!hideRightField && <>
|
||||||
<Col span={24}>
|
{leftFieldType === 'number'
|
||||||
<Form.Item name={[field.name, 'right']} noStyle>
|
? <Col span={24}><Row>
|
||||||
<Editor options={options} />
|
<Col span={12}>
|
||||||
</Form.Item>
|
<Form.Item name={[field.name, 'input_type']} noStyle>
|
||||||
</Col>
|
<Select
|
||||||
)}
|
placeholder={t('common.pleaseSelect')}
|
||||||
|
options={[{ value: 'Variable', label: 'Variable' }, { value: 'Constant', label: 'Constant' }]}
|
||||||
|
popupMatchSelectWidth={false}
|
||||||
|
variant="borderless"
|
||||||
|
className="rb:w-full!"
|
||||||
|
onChange={() => handleInputTypeChange(index)}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={12}>
|
||||||
|
<Form.Item name={[field.name, 'right']} noStyle>
|
||||||
|
{inputType === 'Variable'
|
||||||
|
?
|
||||||
|
<VariableSelect
|
||||||
|
placeholder={t('common.pleaseSelect')}
|
||||||
|
options={options.filter(vo => vo.dataType === 'number')}
|
||||||
|
allowClear={false}
|
||||||
|
popupMatchSelectWidth={false}
|
||||||
|
variant="borderless"
|
||||||
|
className="rb:w-full!"
|
||||||
|
/>
|
||||||
|
: <InputNumber placeholder={t('common.pleaseEnter')}
|
||||||
|
variant="borderless" className="rb:w-full!" />
|
||||||
|
}
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
</Row></Col>
|
||||||
|
: <Col span={24}>
|
||||||
|
<Form.Item name={[field.name, 'right']} noStyle>
|
||||||
|
{leftFieldType === 'boolean'
|
||||||
|
? <Radio.Group block>
|
||||||
|
<Radio.Button value={true}>True</Radio.Button>
|
||||||
|
<Radio.Button value={false}>False</Radio.Button>
|
||||||
|
</Radio.Group>
|
||||||
|
: <Editor options={options} />
|
||||||
|
}
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
}
|
||||||
|
</>}
|
||||||
|
|
||||||
</Row>
|
</Row>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ const CycleVarsList: FC<CycleVarsListProps> = ({
|
|||||||
label: `${childData.name || childData.type}.${key}`,
|
label: `${childData.name || childData.type}.${key}`,
|
||||||
type: 'output',
|
type: 'output',
|
||||||
dataType: 'string',
|
dataType: 'string',
|
||||||
value: `{{${childData.id}.${key}}}`,
|
value: `${childData.id}.${key}`,
|
||||||
nodeData: childData
|
nodeData: childData
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ const GroupVariableList: FC<GroupVariableListProps> = ({
|
|||||||
<Row gutter={12} className="rb:mb-2!">
|
<Row gutter={12} className="rb:mb-2!">
|
||||||
<Col span={12}>
|
<Col span={12}>
|
||||||
<Form.Item
|
<Form.Item
|
||||||
name={[name,0, 'key']}
|
|
||||||
noStyle
|
noStyle
|
||||||
>
|
>
|
||||||
{t('workflow.config.var-aggregator.variable')}
|
{t('workflow.config.var-aggregator.variable')}
|
||||||
@@ -34,9 +33,8 @@ const GroupVariableList: FC<GroupVariableListProps> = ({
|
|||||||
</Row>
|
</Row>
|
||||||
|
|
||||||
<Form.Item
|
<Form.Item
|
||||||
name={[name, 0, 'value']}
|
name={name}
|
||||||
noStyle
|
noStyle
|
||||||
rules={[{ required: true, message: 'Missing last name' }]}
|
|
||||||
>
|
>
|
||||||
<VariableSelect
|
<VariableSelect
|
||||||
placeholder={t('common.pleaseSelect')}
|
placeholder={t('common.pleaseSelect')}
|
||||||
@@ -76,7 +74,6 @@ const GroupVariableList: FC<GroupVariableListProps> = ({
|
|||||||
{...restField}
|
{...restField}
|
||||||
name={[name, 'value']}
|
name={[name, 'value']}
|
||||||
noStyle
|
noStyle
|
||||||
rules={[{ required: true, message: 'Missing last name' }]}
|
|
||||||
>
|
>
|
||||||
<VariableSelect
|
<VariableSelect
|
||||||
placeholder={t('common.pleaseSelect')}
|
placeholder={t('common.pleaseSelect')}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { useState, useEffect } from 'react';
|
import { useState, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { Button, Select, Table } from 'antd';
|
import { Button, Select, Table } from 'antd';
|
||||||
import { PlusOutlined, DeleteOutlined } from '@ant-design/icons';
|
import { PlusOutlined, DeleteOutlined } from '@ant-design/icons';
|
||||||
@@ -33,104 +33,90 @@ const EditableTable: React.FC<EditableTableProps> = ({
|
|||||||
const [rows, setRows] = useState<TableRow[]>([]);
|
const [rows, setRows] = useState<TableRow[]>([]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
console.log('EditableTable value', value)
|
|
||||||
if (Array.isArray(value)) {
|
if (Array.isArray(value)) {
|
||||||
setRows([...value])
|
setRows([...value])
|
||||||
} else if (value && Object.keys(value).length > 0) {
|
} else if (value && Object.keys(value).length > 0) {
|
||||||
// Only update if rows are empty or significantly different
|
setRows(Object.entries(value).map(([key, val], index) => ({
|
||||||
const valueEntries = Object.entries(value)
|
key: index.toString(),
|
||||||
if (rows.length === 0 || rows.length !== valueEntries.length) {
|
name: key || '',
|
||||||
setRows(valueEntries.map(([key, val], index) => {
|
value: val || '',
|
||||||
console.log('val', val)
|
type: typeOptions.length > 0 ? typeOptions[0].value : undefined
|
||||||
return {
|
})))
|
||||||
key: index.toString(),
|
|
||||||
name: key || '',
|
|
||||||
value: val || '',
|
|
||||||
type: typeOptions.length > 0 ? typeOptions[0].value : undefined
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
setRows([])
|
setRows([])
|
||||||
}
|
}
|
||||||
}, [JSON.stringify(value), typeOptions.length])
|
}, [value, typeOptions])
|
||||||
|
|
||||||
const handleChange = (key: string, field: 'name' | 'value' | 'type', val: string) => {
|
const handleChange = (key: string, field: 'name' | 'value' | 'type', val: string) => {
|
||||||
const newRows = [...rows.map(row =>
|
const newRows = rows.map(row =>
|
||||||
row.key === key ? { ...row, [field]: val } : row
|
row.key === key ? { ...row, [field]: val } : row
|
||||||
)];
|
);
|
||||||
|
|
||||||
setRows(newRows);
|
setRows(newRows);
|
||||||
onChange?.(newRows);
|
onChange?.(newRows);
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleAdd = () => {
|
const handleAdd = () => {
|
||||||
const newKey = Date.now().toString();
|
const newRow: TableRow = {
|
||||||
if (typeOptions.length) {
|
key: Date.now().toString(),
|
||||||
setRows([...rows, { key: newKey, name: '', value: '', type: typeOptions[0].value }]);
|
name: '',
|
||||||
} else {
|
value: '',
|
||||||
setRows([...rows, { key: newKey, name: '', value: '' }]);
|
...(typeOptions.length > 0 && { type: typeOptions[0].value })
|
||||||
}
|
};
|
||||||
|
const newRows = [...rows, newRow];
|
||||||
|
setRows(newRows);
|
||||||
|
onChange?.(newRows);
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleDelete = (key: string, index: number) => {
|
const handleDelete = (key: string) => {
|
||||||
console.log('index', index)
|
const newRows = rows.filter(row => row.key !== key);
|
||||||
|
setRows(newRows);
|
||||||
if (rows.length === 1) {
|
onChange?.(newRows);
|
||||||
setRows([]);
|
|
||||||
onChange?.([]);
|
|
||||||
} else {
|
|
||||||
const newRows = rows.filter(row => row.key !== key);
|
|
||||||
setRows(newRows);
|
|
||||||
onChange?.(newRows);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const columns = typeOptions?.length > 0 ? [
|
const columns = useMemo(() => {
|
||||||
{
|
const baseColumns = [
|
||||||
title: t('workflow.config.name'),
|
{
|
||||||
dataIndex: 'name',
|
title: typeOptions.length > 0 ? t('workflow.config.name') : '键',
|
||||||
width: '45%',
|
dataIndex: 'name',
|
||||||
render: (text: string, record: TableRow) => (
|
width: typeOptions.length > 0 ? '35%' : '45%',
|
||||||
<Editor
|
render: (text: string, record: TableRow) => (
|
||||||
options={options}
|
<Editor
|
||||||
value={text}
|
options={options}
|
||||||
height={32}
|
value={text}
|
||||||
variant="outlined"
|
height={32}
|
||||||
onChange={(value) => handleChange(record.key, 'name', value)}
|
variant="outlined"
|
||||||
/>
|
onChange={(value) => handleChange(record.key, 'name', value || '')}
|
||||||
),
|
/>
|
||||||
},
|
),
|
||||||
{
|
}
|
||||||
title: t('workflow.config.type'),
|
];
|
||||||
dataIndex: 'type',
|
|
||||||
width: '20%',
|
if (typeOptions.length > 0) {
|
||||||
render: (text: string, record: TableRow) => (
|
baseColumns.push({
|
||||||
<Select
|
title: t('workflow.config.type'),
|
||||||
value={text}
|
dataIndex: 'type',
|
||||||
options={typeOptions}
|
width: '20%',
|
||||||
onChange={(value) => {
|
render: (text: string, record: TableRow) => (
|
||||||
console.log('value record', value)
|
<Select
|
||||||
handleChange(record.key, 'type', value)
|
value={text}
|
||||||
}}
|
options={typeOptions}
|
||||||
/>
|
onChange={(value) => handleChange(record.key, 'type', value)}
|
||||||
),
|
/>
|
||||||
},
|
),
|
||||||
{
|
});
|
||||||
title: t('workflow.config.value'),
|
}
|
||||||
|
|
||||||
|
baseColumns.push({
|
||||||
|
title: typeOptions.length > 0 ? t('workflow.config.value') : '值',
|
||||||
dataIndex: 'value',
|
dataIndex: 'value',
|
||||||
width: '45%',
|
width: typeOptions.length > 0 ? '35%' : '45%',
|
||||||
render: (text: string, record: TableRow) => {
|
render: (text: string, record: TableRow) => {
|
||||||
if (record.type === 'file') {
|
if (record.type === 'file') {
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<VariableSelect
|
<VariableSelect
|
||||||
options={options}
|
options={options}
|
||||||
value={text}
|
value={text}
|
||||||
onChange={(value) => {
|
onChange={(value) => handleChange(record.key, 'value', value || '')}
|
||||||
console.log('value record', value)
|
|
||||||
handleChange(record.key, 'value', value)
|
|
||||||
}}
|
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -140,78 +126,41 @@ const EditableTable: React.FC<EditableTableProps> = ({
|
|||||||
value={text}
|
value={text}
|
||||||
height={32}
|
height={32}
|
||||||
variant="outlined"
|
variant="outlined"
|
||||||
onChange={(value) => {
|
onChange={(value) => handleChange(record.key, 'value', value || '')}
|
||||||
console.log('value record', value)
|
|
||||||
handleChange(record.key, 'value', value)
|
|
||||||
}}
|
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
},
|
});
|
||||||
{
|
|
||||||
|
baseColumns.push({
|
||||||
title: '',
|
title: '',
|
||||||
|
dataIndex: 'actions',
|
||||||
width: '10%',
|
width: '10%',
|
||||||
render: (_: any, record: TableRow, index: number) => (
|
render: (_: any, record: TableRow) => (
|
||||||
<Button
|
<Button
|
||||||
type="text"
|
type="text"
|
||||||
icon={<DeleteOutlined />}
|
icon={<DeleteOutlined />}
|
||||||
onClick={() => handleDelete(record.key, index)}
|
onClick={() => handleDelete(record.key)}
|
||||||
/>
|
/>
|
||||||
),
|
),
|
||||||
},
|
});
|
||||||
] : [
|
|
||||||
{
|
return baseColumns;
|
||||||
title: '键',
|
}, [typeOptions, options, t]);
|
||||||
dataIndex: 'name',
|
|
||||||
width: '45%',
|
|
||||||
render: (text: string, record: TableRow) => (
|
|
||||||
<Editor
|
|
||||||
options={options}
|
|
||||||
value={text}
|
|
||||||
height={32}
|
|
||||||
variant="outlined"
|
|
||||||
onChange={(value) => handleChange(record.key, 'name', value)}
|
|
||||||
/>
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
title: '值',
|
|
||||||
dataIndex: 'value',
|
|
||||||
width: '45%',
|
|
||||||
render: (text: string, record: TableRow) => (
|
|
||||||
<Editor
|
|
||||||
options={options}
|
|
||||||
value={text}
|
|
||||||
height={32}
|
|
||||||
variant="outlined"
|
|
||||||
onChange={(value) => handleChange(record.key, 'value', value)}
|
|
||||||
/>
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
title: '',
|
|
||||||
width: '10%',
|
|
||||||
render: (_: any, record: TableRow, index: number) => (
|
|
||||||
<Button
|
|
||||||
type="text"
|
|
||||||
icon={<DeleteOutlined />}
|
|
||||||
onClick={() => handleDelete(record.key, index)}
|
|
||||||
/>
|
|
||||||
),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="rb:mb-4">
|
<div className="rb:mb-4">
|
||||||
{title && <div className="rb:flex rb:items-center rb:mb-2 rb:justify-between">
|
{title && (
|
||||||
<div className="rb:font-medium">{title}</div>
|
<div className="rb:flex rb:items-center rb:mb-2 rb:justify-between">
|
||||||
<Button
|
<div className="rb:font-medium">{title}</div>
|
||||||
type="text"
|
<Button
|
||||||
icon={<PlusOutlined />}
|
type="text"
|
||||||
onClick={handleAdd}
|
icon={<PlusOutlined />}
|
||||||
size="small"
|
onClick={handleAdd}
|
||||||
/>
|
size="small"
|
||||||
</div>}
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
<Table
|
<Table
|
||||||
columns={columns}
|
columns={columns}
|
||||||
dataSource={rows}
|
dataSource={rows}
|
||||||
@@ -220,11 +169,11 @@ const EditableTable: React.FC<EditableTableProps> = ({
|
|||||||
locale={{ emptyText: <Empty size={88} /> }}
|
locale={{ emptyText: <Empty size={88} /> }}
|
||||||
scroll={{ x: 'max-content' }}
|
scroll={{ x: 'max-content' }}
|
||||||
/>
|
/>
|
||||||
{!title &&
|
{!title && (
|
||||||
<Button type="dashed" onClick={handleAdd} block className='rb:mt-1'>
|
<Button type="dashed" onClick={handleAdd} block className='rb:mt-1'>
|
||||||
+{t('common.add')}
|
+{t('common.add')}
|
||||||
</Button>
|
</Button>
|
||||||
}
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { type FC, useEffect, useRef } from "react";
|
import { type FC, useRef } from "react";
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { Form, Row, Col, Select, Button, Divider, InputNumber, Switch, Input, Slider } from 'antd'
|
import { Form, Row, Col, Select, Button, Divider, InputNumber, Switch, Input } from 'antd'
|
||||||
import Editor from '../../Editor'
|
import Editor from '../../Editor'
|
||||||
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||||
import AuthConfigModal from './AuthConfigModal'
|
import AuthConfigModal from './AuthConfigModal'
|
||||||
|
|||||||
@@ -128,29 +128,32 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
|
|||||||
<List
|
<List
|
||||||
grid={{ gutter: 12, column: 1 }}
|
grid={{ gutter: 12, column: 1 }}
|
||||||
dataSource={knowledgeList}
|
dataSource={knowledgeList}
|
||||||
renderItem={(item) => (
|
renderItem={(item) => {
|
||||||
<List.Item>
|
if (!item.id) return null
|
||||||
<div key={item.id} className="rb:flex rb:items-center rb:justify-between rb:p-[12px_16px] rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg">
|
return (
|
||||||
<div className="rb:font-medium rb:leading-4">
|
<List.Item>
|
||||||
{item.name}
|
<div key={item.id} className="rb:flex rb:items-center rb:justify-between rb:p-[12px_16px] rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg">
|
||||||
<Tag color={item.status === 1 ? 'success' : item.status === 0 ? 'default' : 'error'} className="rb:ml-2">
|
<div className="rb:font-medium rb:leading-4">
|
||||||
{item.status === 1 ? t('common.enable') : item.status === 0 ? t('common.disabled') : t('common.deleted')}
|
{item.name}
|
||||||
</Tag>
|
<Tag color={item.status === 1 ? 'success' : item.status === 0 ? 'default' : 'error'} className="rb:ml-2">
|
||||||
<div className="rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-5">{t('application.contains', {include_count: item.doc_num})}</div>
|
{item.status === 1 ? t('common.enable') : item.status === 0 ? t('common.disabled') : t('common.deleted')}
|
||||||
|
</Tag>
|
||||||
|
<div className="rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-5">{t('application.contains', {include_count: item.doc_num})}</div>
|
||||||
|
</div>
|
||||||
|
<Space size={12}>
|
||||||
|
<div
|
||||||
|
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/editBorder.svg')] rb:hover:bg-[url('@/assets/images/editBg.svg')]"
|
||||||
|
onClick={() => handleEditKnowledge(item)}
|
||||||
|
></div>
|
||||||
|
<div
|
||||||
|
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/deleteBorder.svg')] rb:hover:bg-[url('@/assets/images/deleteBg.svg')]"
|
||||||
|
onClick={() => handleDeleteKnowledge(item.id)}
|
||||||
|
></div>
|
||||||
|
</Space>
|
||||||
</div>
|
</div>
|
||||||
<Space size={12}>
|
</List.Item>
|
||||||
<div
|
)
|
||||||
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/editBorder.svg')] rb:hover:bg-[url('@/assets/images/editBg.svg')]"
|
}}
|
||||||
onClick={() => handleEditKnowledge(item)}
|
|
||||||
></div>
|
|
||||||
<div
|
|
||||||
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/deleteBorder.svg')] rb:hover:bg-[url('@/assets/images/deleteBg.svg')]"
|
|
||||||
onClick={() => handleDeleteKnowledge(item.id)}
|
|
||||||
></div>
|
|
||||||
</Space>
|
|
||||||
</div>
|
|
||||||
</List.Item>
|
|
||||||
)}
|
|
||||||
/>
|
/>
|
||||||
}
|
}
|
||||||
{/* 全局设置 */}
|
{/* 全局设置 */}
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
import React from 'react';
|
import React from 'react';
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { MinusCircleOutlined } from '@ant-design/icons';
|
import { MinusCircleOutlined } from '@ant-design/icons';
|
||||||
import { Button, Form, Input, Space } from 'antd';
|
import { Button, Form, Input, Space, Row, Col } from 'antd';
|
||||||
|
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||||
|
import VariableSelect from '../VariableSelect'
|
||||||
|
|
||||||
interface MappingListProps {
|
interface MappingListProps {
|
||||||
name: string;
|
name: string;
|
||||||
|
options: Suggestion[];
|
||||||
}
|
}
|
||||||
const MappingList: React.FC<MappingListProps> = ({ name }) => {
|
const MappingList: React.FC<MappingListProps> = ({ name, options }) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@@ -14,23 +17,33 @@ const MappingList: React.FC<MappingListProps> = ({ name }) => {
|
|||||||
{(fields, { add, remove }) => (
|
{(fields, { add, remove }) => (
|
||||||
<>
|
<>
|
||||||
{fields.map(({ key, name, ...restField }) => (
|
{fields.map(({ key, name, ...restField }) => (
|
||||||
<Space key={key} style={{ display: 'flex', marginBottom: 8 }} align="baseline">
|
<Row gutter={12} className="rb:mb-2">
|
||||||
<Form.Item
|
<Col span={10}>
|
||||||
{...restField}
|
<Form.Item
|
||||||
name={[name, 'name']}
|
{...restField}
|
||||||
noStyle
|
name={[name, 'name']}
|
||||||
>
|
noStyle
|
||||||
<Input placeholder={t('common.pleaseEnter')} />
|
>
|
||||||
</Form.Item>
|
<Input placeholder={t('common.pleaseEnter')} />
|
||||||
<Form.Item
|
</Form.Item>
|
||||||
{...restField}
|
</Col>
|
||||||
name={[name, 'value']}
|
<Col span={12}>
|
||||||
noStyle
|
<Form.Item
|
||||||
>
|
{...restField}
|
||||||
<Input placeholder={t('common.pleaseEnter')} />
|
name={[name, 'value']}
|
||||||
</Form.Item>
|
noStyle
|
||||||
<MinusCircleOutlined onClick={() => remove(name)} />
|
>
|
||||||
</Space>
|
<VariableSelect
|
||||||
|
placeholder={t('common.pleaseSelect')}
|
||||||
|
options={options}
|
||||||
|
popupMatchSelectWidth={false}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
</Col>
|
||||||
|
<Col span={2}>
|
||||||
|
<MinusCircleOutlined onClick={() => remove(name)} />
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
))}
|
))}
|
||||||
<Form.Item>
|
<Form.Item>
|
||||||
<Button type="dashed" onClick={() => add()} block>
|
<Button type="dashed" onClick={() => add()} block>
|
||||||
|
|||||||
@@ -29,15 +29,15 @@ const MessageEditor: FC<TextareaProps> = ({
|
|||||||
}) => {
|
}) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const form = Form.useFormInstance();
|
const form = Form.useFormInstance();
|
||||||
const values = form.getFieldsValue()
|
const values = Form.useWatch([], form);
|
||||||
|
|
||||||
// 检查是否已经使用了context变量,将已使用的context设置为disabled
|
// 检查是否已经使用了context变量,将已使用的context设置为disabled
|
||||||
const processedOptions = useMemo(() => {
|
const processedOptions = useMemo(() => {
|
||||||
if (!isArray || !values[parentName]) return options;
|
if (!isArray || !values?.[parentName]) return options;
|
||||||
|
|
||||||
// 获取所有消息内容
|
// 获取所有消息内容
|
||||||
const allContents = values[parentName]
|
const allContents = values[parentName]
|
||||||
.map((msg: any) => msg.content || '')
|
.map((msg: any) => msg?.content || '')
|
||||||
.join(' ');
|
.join(' ');
|
||||||
|
|
||||||
// 将已使用的context变量标记为disabled
|
// 将已使用的context变量标记为disabled
|
||||||
@@ -50,83 +50,74 @@ const MessageEditor: FC<TextareaProps> = ({
|
|||||||
}, [options, values, parentName, isArray]);
|
}, [options, values, parentName, isArray]);
|
||||||
|
|
||||||
const handleAdd = (add: FormListOperation['add']) => {
|
const handleAdd = (add: FormListOperation['add']) => {
|
||||||
const list = values[parentName];
|
const list = values?.[parentName] || [];
|
||||||
const lastRole = list[list.length - 1].role
|
const lastRole = list.length > 0 ? list[list.length - 1]?.role : 'ASSISTANT';
|
||||||
|
|
||||||
add({
|
add({
|
||||||
role: lastRole === 'USER' ? 'ASSISTANT' : 'USER',
|
role: lastRole === 'USER' ? 'ASSISTANT' : 'USER',
|
||||||
content: undefined
|
content: ''
|
||||||
})
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!isArray) {
|
||||||
|
return (
|
||||||
|
<Space size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white">
|
||||||
|
<Row>
|
||||||
|
<Col span={12}>
|
||||||
|
{title ?? t('workflow.answerDesc')}
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
<Form.Item name={parentName} noStyle>
|
||||||
|
<Editor placeholder={placeholder} options={processedOptions} />
|
||||||
|
</Form.Item>
|
||||||
|
</Space>
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div>
|
<Form.List name={parentName}>
|
||||||
{isArray
|
{(fields, { add, remove }) => (
|
||||||
? <Form.List name={parentName}>
|
<Space size={12} direction="vertical" className="rb:w-full">
|
||||||
{(fields, { add, remove }) => (
|
{fields.map(({ key, name, ...restField }) => {
|
||||||
<Space size={12} direction="vertical" className="rb:w-full">
|
const currentRole = (values?.[parentName]?.[name]?.role || 'USER').toUpperCase();
|
||||||
{fields.map(({ key, name, ...restField }) => {
|
|
||||||
const currentRole = (values[parentName]?.[key].role || 'USER').toUpperCase()
|
return (
|
||||||
|
<Space key={key} size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white">
|
||||||
return (
|
<Row>
|
||||||
<Space key={key} size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white">
|
<Col span={12}>
|
||||||
<Row>
|
<Form.Item {...restField} name={[name, 'role']} noStyle>
|
||||||
<Col span={12}>
|
{currentRole === 'SYSTEM' ? (
|
||||||
<Form.Item
|
<Input disabled />
|
||||||
{...restField}
|
) : (
|
||||||
name={[name, 'role']}
|
<Select
|
||||||
noStyle
|
options={roleOptions}
|
||||||
>
|
disabled={currentRole === 'SYSTEM'}
|
||||||
{currentRole === 'SYSTEM'
|
/>
|
||||||
? <Input disabled />
|
)}
|
||||||
:
|
|
||||||
<Select
|
|
||||||
options={roleOptions}
|
|
||||||
disabled={currentRole === 'SYSTEM'}
|
|
||||||
/>
|
|
||||||
}
|
|
||||||
</Form.Item>
|
|
||||||
</Col>
|
|
||||||
{currentRole !== 'SYSTEM' && <Col span={12}>
|
|
||||||
<div className="rb:h-full rb:flex rb:justify-end rb:items-center">
|
|
||||||
<MinusCircleOutlined onClick={() => remove(name)} />
|
|
||||||
</div>
|
|
||||||
</Col>}
|
|
||||||
</Row>
|
|
||||||
<Form.Item
|
|
||||||
{...restField}
|
|
||||||
name={[name, 'content']}
|
|
||||||
noStyle
|
|
||||||
>
|
|
||||||
<Editor placeholder={placeholder} options={processedOptions} />
|
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
</Space>
|
</Col>
|
||||||
)
|
{currentRole !== 'SYSTEM' && (
|
||||||
})}
|
<Col span={12}>
|
||||||
<Form.Item>
|
<div className="rb:h-full rb:flex rb:justify-end rb:items-center">
|
||||||
<Button type="dashed" onClick={() => handleAdd(add)} block>
|
<MinusCircleOutlined onClick={() => remove(name)} />
|
||||||
+{t('workflow.addMessage')}
|
</div>
|
||||||
</Button>
|
</Col>
|
||||||
</Form.Item>
|
)}
|
||||||
</Space >
|
</Row>
|
||||||
)}
|
<Form.Item {...restField} name={[name, 'content']} noStyle>
|
||||||
</Form.List>
|
<Editor placeholder={placeholder} options={processedOptions} />
|
||||||
:
|
</Form.Item>
|
||||||
<Space size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white">
|
</Space>
|
||||||
<Row>
|
);
|
||||||
<Col span={12}>
|
})}
|
||||||
{title ?? t('workflow.answerDesc')}
|
<Form.Item>
|
||||||
</Col>
|
<Button type="dashed" onClick={() => handleAdd(add)} block>
|
||||||
</Row>
|
+{t('workflow.addMessage')}
|
||||||
<Form.Item
|
</Button>
|
||||||
name={parentName}
|
|
||||||
noStyle
|
|
||||||
>
|
|
||||||
<Editor placeholder={placeholder} options={processedOptions} />
|
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
</Space>
|
</Space>
|
||||||
}
|
)}
|
||||||
</div>
|
</Form.List>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -85,15 +85,6 @@ const Properties: FC<PropertiesProps> = ({
|
|||||||
const { id, knowledge_retrieval, group, group_names, ...rest } = values
|
const { id, knowledge_retrieval, group, group_names, ...rest } = values
|
||||||
const { knowledge_bases = [], ...restKnowledgeConfig } = (knowledge_retrieval as any) || {}
|
const { knowledge_bases = [], ...restKnowledgeConfig } = (knowledge_retrieval as any) || {}
|
||||||
|
|
||||||
let groupNames: Record<string, string[]> | string[] = {}
|
|
||||||
|
|
||||||
if (group && group_names?.length) {
|
|
||||||
group_names.forEach(vo => {
|
|
||||||
(groupNames as Record<string, string[]>)[vo.key] = vo.value
|
|
||||||
})
|
|
||||||
} else if (!group) {
|
|
||||||
groupNames = group_names?.[0]?.value || []
|
|
||||||
}
|
|
||||||
let allRest = {
|
let allRest = {
|
||||||
...rest,
|
...rest,
|
||||||
...restKnowledgeConfig,
|
...restKnowledgeConfig,
|
||||||
@@ -107,7 +98,14 @@ const Properties: FC<PropertiesProps> = ({
|
|||||||
|
|
||||||
Object.keys(values).forEach(key => {
|
Object.keys(values).forEach(key => {
|
||||||
if (selectedNode.data?.config?.[key]) {
|
if (selectedNode.data?.config?.[key]) {
|
||||||
selectedNode.data.config[key].defaultValue = values[key]
|
// Create a deep copy to avoid reference sharing between nodes
|
||||||
|
if (!selectedNode.data.config[key]) {
|
||||||
|
selectedNode.data.config[key] = {};
|
||||||
|
}
|
||||||
|
selectedNode.data.config[key] = {
|
||||||
|
...selectedNode.data.config[key],
|
||||||
|
defaultValue: values[key]
|
||||||
|
};
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -194,7 +192,7 @@ const Properties: FC<PropertiesProps> = ({
|
|||||||
|
|
||||||
const allPreviousNodeIds = getAllPreviousNodes(selectedNode.id);
|
const allPreviousNodeIds = getAllPreviousNodes(selectedNode.id);
|
||||||
const childNodeIds = getChildNodes(selectedNode.id);
|
const childNodeIds = getChildNodes(selectedNode.id);
|
||||||
console.log('childNodeIds', childNodeIds)
|
console.log('childNodeIds', selectedNode, childNodeIds)
|
||||||
const allRelevantNodeIds = [...allPreviousNodeIds, ...childNodeIds];
|
const allRelevantNodeIds = [...allPreviousNodeIds, ...childNodeIds];
|
||||||
|
|
||||||
allRelevantNodeIds.forEach(nodeId => {
|
allRelevantNodeIds.forEach(nodeId => {
|
||||||
@@ -219,7 +217,7 @@ const Properties: FC<PropertiesProps> = ({
|
|||||||
label: variable.name,
|
label: variable.name,
|
||||||
type: 'variable',
|
type: 'variable',
|
||||||
dataType: variable.type,
|
dataType: variable.type,
|
||||||
value: `{{${nodeId}.${variable.name}}}`,
|
value: `${node.getData().id}.${variable.name}`,
|
||||||
nodeData: nodeData,
|
nodeData: nodeData,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -249,7 +247,7 @@ const Properties: FC<PropertiesProps> = ({
|
|||||||
label: 'output',
|
label: 'output',
|
||||||
type: 'variable',
|
type: 'variable',
|
||||||
dataType: 'String',
|
dataType: 'String',
|
||||||
value: `${nodeId}.output`,
|
value: `${node.getData().id}.output`,
|
||||||
nodeData: nodeData,
|
nodeData: nodeData,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -263,7 +261,104 @@ const Properties: FC<PropertiesProps> = ({
|
|||||||
label: 'message',
|
label: 'message',
|
||||||
type: 'variable',
|
type: 'variable',
|
||||||
dataType: 'array[object]',
|
dataType: 'array[object]',
|
||||||
value: `${nodeId}.message`,
|
value: `${node.getData().id}.message`,
|
||||||
|
nodeData: nodeData,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
break
|
||||||
|
case 'parameter-extractor':
|
||||||
|
const successKey = `${nodeId}___is_success`;
|
||||||
|
const reasonKey = `${nodeId}___reason`;
|
||||||
|
if (!addedKeys.has(successKey)) {
|
||||||
|
addedKeys.add(successKey);
|
||||||
|
variableList.push({
|
||||||
|
key: successKey,
|
||||||
|
label: '__is_success',
|
||||||
|
type: 'variable',
|
||||||
|
dataType: 'number',
|
||||||
|
value: `${node.getData().id}.__is_success`,
|
||||||
|
nodeData: nodeData,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if (!addedKeys.has(reasonKey)) {
|
||||||
|
addedKeys.add(reasonKey);
|
||||||
|
variableList.push({
|
||||||
|
key: reasonKey,
|
||||||
|
label: '__reason',
|
||||||
|
type: 'variable',
|
||||||
|
dataType: 'string',
|
||||||
|
value: `${node.getData().id}.__reason`,
|
||||||
|
nodeData: nodeData,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// Add params variables
|
||||||
|
const paramsList = nodeData.config?.params?.defaultValue || [];
|
||||||
|
paramsList.forEach((param: any) => {
|
||||||
|
if (!param || !param?.name) return;
|
||||||
|
const paramKey = `${nodeId}_${param.name}`;
|
||||||
|
if (!addedKeys.has(paramKey)) {
|
||||||
|
addedKeys.add(paramKey);
|
||||||
|
variableList.push({
|
||||||
|
key: paramKey,
|
||||||
|
label: param.name,
|
||||||
|
type: 'variable',
|
||||||
|
dataType: param.type || 'string',
|
||||||
|
value: `${node.getData().id}.${param.name}`,
|
||||||
|
nodeData: nodeData,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
break
|
||||||
|
case 'var-aggregator':
|
||||||
|
const varAggregatorKey = `${nodeId}_output`;
|
||||||
|
if (!addedKeys.has(varAggregatorKey)) {
|
||||||
|
addedKeys.add(varAggregatorKey);
|
||||||
|
variableList.push({
|
||||||
|
key: varAggregatorKey,
|
||||||
|
label: 'output',
|
||||||
|
type: 'variable',
|
||||||
|
dataType: 'string',
|
||||||
|
value: `${node.getData().id}.output`,
|
||||||
|
nodeData: nodeData,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
break
|
||||||
|
case 'http-request':
|
||||||
|
const httpBodyKey = `${nodeId}_body`;
|
||||||
|
const httpStatusKey = `${nodeId}_status_code`;
|
||||||
|
if (!addedKeys.has(httpBodyKey)) {
|
||||||
|
addedKeys.add(httpBodyKey);
|
||||||
|
variableList.push({
|
||||||
|
key: httpBodyKey,
|
||||||
|
label: 'body',
|
||||||
|
type: 'variable',
|
||||||
|
dataType: 'string',
|
||||||
|
value: `${node.getData().id}.body`,
|
||||||
|
nodeData: nodeData,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if (!addedKeys.has(httpStatusKey)) {
|
||||||
|
addedKeys.add(httpStatusKey);
|
||||||
|
variableList.push({
|
||||||
|
key: httpStatusKey,
|
||||||
|
label: 'status_code',
|
||||||
|
type: 'variable',
|
||||||
|
dataType: 'number',
|
||||||
|
value: `${node.getData().id}.status_code`,
|
||||||
|
nodeData: nodeData,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
break
|
||||||
|
case 'jinja-render':
|
||||||
|
const jinjaOutputKey = `${nodeId}_output`;
|
||||||
|
if (!addedKeys.has(jinjaOutputKey)) {
|
||||||
|
addedKeys.add(jinjaOutputKey);
|
||||||
|
variableList.push({
|
||||||
|
key: jinjaOutputKey,
|
||||||
|
label: 'output',
|
||||||
|
type: 'variable',
|
||||||
|
dataType: 'string',
|
||||||
|
value: `${node.getData().id}.output`,
|
||||||
nodeData: nodeData,
|
nodeData: nodeData,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -283,7 +378,7 @@ const Properties: FC<PropertiesProps> = ({
|
|||||||
label: variable.name,
|
label: variable.name,
|
||||||
type: 'variable',
|
type: 'variable',
|
||||||
dataType: variable.type,
|
dataType: variable.type,
|
||||||
value: `conversation.${variable.name}`,
|
value: `conv.${variable.name}`,
|
||||||
nodeData: { type: 'CONVERSATION', name: 'CONVERSATION', icon: '' },
|
nodeData: { type: 'CONVERSATION', name: 'CONVERSATION', icon: '' },
|
||||||
group: 'CONVERSATION'
|
group: 'CONVERSATION'
|
||||||
});
|
});
|
||||||
@@ -387,7 +482,7 @@ const Properties: FC<PropertiesProps> = ({
|
|||||||
label: 'context',
|
label: 'context',
|
||||||
type: 'variable',
|
type: 'variable',
|
||||||
dataType: 'String',
|
dataType: 'String',
|
||||||
value: `{{context}}`,
|
value: `context`,
|
||||||
nodeData: selectedNode.getData(),
|
nodeData: selectedNode.getData(),
|
||||||
isContext: true,
|
isContext: true,
|
||||||
});
|
});
|
||||||
@@ -476,7 +571,7 @@ const Properties: FC<PropertiesProps> = ({
|
|||||||
<Form.Item key={key} name={key}
|
<Form.Item key={key} name={key}
|
||||||
label={t(`workflow.config.${selectedNode?.data?.type}.${key}`)}
|
label={t(`workflow.config.${selectedNode?.data?.type}.${key}`)}
|
||||||
>
|
>
|
||||||
<MappingList name={key} />
|
<MappingList name={key} options={variableList} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
)
|
)
|
||||||
@@ -583,7 +678,7 @@ const Properties: FC<PropertiesProps> = ({
|
|||||||
? <Input.TextArea placeholder={t('common.pleaseEnter')} />
|
? <Input.TextArea placeholder={t('common.pleaseEnter')} />
|
||||||
: config.type === 'select'
|
: config.type === 'select'
|
||||||
? <Select
|
? <Select
|
||||||
options={config.needTranslation ? config.options?.map(vo => ({ ...vo, label: t(vo.label) })) : config.options}
|
options={config.needTranslation ? config.options?.map(vo => ({ ...vo, label: t(vo.label) })) : config.options}
|
||||||
placeholder={t('common.pleaseSelect')}
|
placeholder={t('common.pleaseSelect')}
|
||||||
/>
|
/>
|
||||||
: config.type === 'inputNumber'
|
: config.type === 'inputNumber'
|
||||||
@@ -635,7 +730,7 @@ const Properties: FC<PropertiesProps> = ({
|
|||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
: config.type === 'switch'
|
: config.type === 'switch'
|
||||||
? <Switch />
|
? <Switch onChange={key === 'group' ? () => { form.setFieldValue('group_names', []) } : undefined} />
|
||||||
: config.type === 'categoryList'
|
: config.type === 'categoryList'
|
||||||
? <CategoryList parentName={key} selectedNode={selectedNode} graphRef={graphRef} />
|
? <CategoryList parentName={key} selectedNode={selectedNode} graphRef={graphRef} />
|
||||||
: config.type === 'conditionList'
|
: config.type === 'conditionList'
|
||||||
|
|||||||
@@ -39,6 +39,9 @@ import processEvolutionIcon from '@/assets/images/workflow/process_evolution.png
|
|||||||
import questionClassifierIcon from '@/assets/images/workflow/question-classifier.png'
|
import questionClassifierIcon from '@/assets/images/workflow/question-classifier.png'
|
||||||
import breakIcon from '@/assets/images/workflow/break.png'
|
import breakIcon from '@/assets/images/workflow/break.png'
|
||||||
import assignerIcon from '@/assets/images/workflow/assigner.png'
|
import assignerIcon from '@/assets/images/workflow/assigner.png'
|
||||||
|
import memoryReadIcon from '@/assets/images/workflow/memory-read.png'
|
||||||
|
import memoryWriteIcon from '@/assets/images/workflow/memory-write.png'
|
||||||
|
|
||||||
import { memoryConfigListUrl } from '@/api/memory'
|
import { memoryConfigListUrl } from '@/api/memory'
|
||||||
|
|
||||||
import { getModelListUrl } from '@/api/models'
|
import { getModelListUrl } from '@/api/models'
|
||||||
@@ -159,6 +162,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
|||||||
},
|
},
|
||||||
text: {
|
text: {
|
||||||
type: 'variableList',
|
type: 'variableList',
|
||||||
|
filterLoopIterationVars: true
|
||||||
},
|
},
|
||||||
params: {
|
params: {
|
||||||
type: 'paramList',
|
type: 'paramList',
|
||||||
@@ -174,8 +178,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
|||||||
{
|
{
|
||||||
category: "cognitiveUpgrading",
|
category: "cognitiveUpgrading",
|
||||||
nodes: [
|
nodes: [
|
||||||
{
|
{ type: "memory-read", icon: memoryReadIcon,
|
||||||
type: "memory-read", icon: memoryEnhancementIcon,
|
|
||||||
config: {
|
config: {
|
||||||
message: {
|
message: {
|
||||||
type: 'messageEditor',
|
type: 'messageEditor',
|
||||||
@@ -198,7 +201,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{ type: "memory-write", icon: memoryEnhancementIcon,
|
{ type: "memory-write", icon: memoryWriteIcon,
|
||||||
config: {
|
config: {
|
||||||
message: {
|
message: {
|
||||||
type: 'messageEditor',
|
type: 'messageEditor',
|
||||||
@@ -272,6 +275,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
|||||||
},
|
},
|
||||||
parallel: {
|
parallel: {
|
||||||
type: 'switch',
|
type: 'switch',
|
||||||
|
defaultValue: false
|
||||||
},
|
},
|
||||||
parallel_count: {
|
parallel_count: {
|
||||||
type: 'slider',
|
type: 'slider',
|
||||||
@@ -284,6 +288,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
|||||||
},
|
},
|
||||||
flatten: { // 扁平化输出
|
flatten: { // 扁平化输出
|
||||||
type: 'switch',
|
type: 'switch',
|
||||||
|
defaultValue: false
|
||||||
},
|
},
|
||||||
output: {
|
output: {
|
||||||
type: 'variableList',
|
type: 'variableList',
|
||||||
@@ -304,6 +309,13 @@ export const nodeLibrary: NodeLibrary[] = [
|
|||||||
expressions: []
|
expressions: []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
max_loop: {
|
||||||
|
type: 'slider',
|
||||||
|
min: 1,
|
||||||
|
max: 100,
|
||||||
|
step: 1,
|
||||||
|
defaultValue: 10
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{ type: "cycle-start", icon: loopIcon },
|
{ type: "cycle-start", icon: loopIcon },
|
||||||
@@ -317,7 +329,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
|||||||
},
|
},
|
||||||
group_names: {
|
group_names: {
|
||||||
type: 'groupVariableList',
|
type: 'groupVariableList',
|
||||||
defaultValue: [{ key: 'Group1', value: []}]
|
defaultValue: [],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -382,7 +394,8 @@ export const nodeLibrary: NodeLibrary[] = [
|
|||||||
defaultValue: {}
|
defaultValue: {}
|
||||||
},
|
},
|
||||||
retry: {
|
retry: {
|
||||||
type: 'define',
|
type: 'switch',
|
||||||
|
defaultValue: false
|
||||||
},
|
},
|
||||||
error_handle: {
|
error_handle: {
|
||||||
type: 'define',
|
type: 'define',
|
||||||
|
|||||||
@@ -94,9 +94,7 @@ export const useWorkflowGraph = ({
|
|||||||
const { group_names, group } = config
|
const { group_names, group } = config
|
||||||
nodeLibraryConfig.config[key].defaultValue = group
|
nodeLibraryConfig.config[key].defaultValue = group
|
||||||
? Object.entries(group_names as Record<string, any>).map(([key, value]) => ({ key, value }))
|
? Object.entries(group_names as Record<string, any>).map(([key, value]) => ({ key, value }))
|
||||||
: [{ key: 'Group1', value: group_names }]
|
: group_names
|
||||||
|
|
||||||
console.log('group_names', nodeLibraryConfig.config)
|
|
||||||
} else if (nodeLibraryConfig.config && nodeLibraryConfig.config[key] && config[key]) {
|
} else if (nodeLibraryConfig.config && nodeLibraryConfig.config[key] && config[key]) {
|
||||||
nodeLibraryConfig.config[key].defaultValue = config[key]
|
nodeLibraryConfig.config[key].defaultValue = config[key]
|
||||||
}
|
}
|
||||||
@@ -832,7 +830,7 @@ export const useWorkflowGraph = ({
|
|||||||
|
|
||||||
// 创建干净的节点数据,只保留必要的字段
|
// 创建干净的节点数据,只保留必要的字段
|
||||||
const cleanNodeData = {
|
const cleanNodeData = {
|
||||||
id: `${dragData.type}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
|
id: `${dragData.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
|
||||||
name: t(`workflow.${dragData.type}`),
|
name: t(`workflow.${dragData.type}`),
|
||||||
...nodeLibraryConfig
|
...nodeLibraryConfig
|
||||||
};
|
};
|
||||||
@@ -842,6 +840,7 @@ export const useWorkflowGraph = ({
|
|||||||
...graphNodeLibrary[dragData.type],
|
...graphNodeLibrary[dragData.type],
|
||||||
x: point.x - 150,
|
x: point.x - 150,
|
||||||
y: point.y - 100,
|
y: point.y - 100,
|
||||||
|
id: cleanNodeData.id,
|
||||||
data: { ...cleanNodeData, isGroup: true },
|
data: { ...cleanNodeData, isGroup: true },
|
||||||
});
|
});
|
||||||
} else if (dragData.type === 'if-else') {
|
} else if (dragData.type === 'if-else') {
|
||||||
@@ -850,6 +849,7 @@ export const useWorkflowGraph = ({
|
|||||||
...graphNodeLibrary[dragData.type],
|
...graphNodeLibrary[dragData.type],
|
||||||
x: point.x - 100,
|
x: point.x - 100,
|
||||||
y: point.y - 60,
|
y: point.y - 60,
|
||||||
|
id: cleanNodeData.id,
|
||||||
data: { ...cleanNodeData },
|
data: { ...cleanNodeData },
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
@@ -858,6 +858,7 @@ export const useWorkflowGraph = ({
|
|||||||
...(graphNodeLibrary[dragData.type] || graphNodeLibrary.default),
|
...(graphNodeLibrary[dragData.type] || graphNodeLibrary.default),
|
||||||
x: point.x - 60,
|
x: point.x - 60,
|
||||||
y: point.y - 20,
|
y: point.y - 20,
|
||||||
|
id: cleanNodeData.id,
|
||||||
data: { ...cleanNodeData },
|
data: { ...cleanNodeData },
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -881,7 +882,15 @@ export const useWorkflowGraph = ({
|
|||||||
|
|
||||||
if (data.config) {
|
if (data.config) {
|
||||||
Object.keys(data.config).forEach(key => {
|
Object.keys(data.config).forEach(key => {
|
||||||
if (data.config[key] && 'defaultValue' in data.config[key] && key !== 'knowledge_retrieval') {
|
if (data.config[key] && 'defaultValue' in data.config[key] && key === 'group_names') {
|
||||||
|
let group_names = data.config.group.defaultValue ? {} : data.config[key].defaultValue
|
||||||
|
if (data.config.group.defaultValue) {
|
||||||
|
data.config[key].defaultValue.map((vo: any) => {
|
||||||
|
group_names[vo.key] = vo.value
|
||||||
|
})
|
||||||
|
}
|
||||||
|
itemConfig[key] = group_names
|
||||||
|
} else if (data.config[key] && 'defaultValue' in data.config[key] && key !== 'knowledge_retrieval') {
|
||||||
itemConfig[key] = data.config[key].defaultValue
|
itemConfig[key] = data.config[key].defaultValue
|
||||||
} else if (key === 'knowledge_retrieval' && data.config[key] && 'defaultValue' in data.config[key]) {
|
} else if (key === 'knowledge_retrieval' && data.config[key] && 'defaultValue' in data.config[key]) {
|
||||||
const { knowledge_bases } = data.config[key].defaultValue
|
const { knowledge_bases } = data.config[key].defaultValue
|
||||||
@@ -910,7 +919,7 @@ export const useWorkflowGraph = ({
|
|||||||
const sourceCell = graphRef.current?.getCellById(edge.getSourceCellId());
|
const sourceCell = graphRef.current?.getCellById(edge.getSourceCellId());
|
||||||
const targetCell = graphRef.current?.getCellById(edge.getTargetCellId());
|
const targetCell = graphRef.current?.getCellById(edge.getTargetCellId());
|
||||||
const sourcePortId = edge.getSourcePortId();
|
const sourcePortId = edge.getSourcePortId();
|
||||||
|
|
||||||
// 过滤无效连线:源节点或目标节点不存在,或者是add-node类型
|
// 过滤无效连线:源节点或目标节点不存在,或者是add-node类型
|
||||||
if (!sourceCell?.getData()?.id || !targetCell?.getData()?.id ||
|
if (!sourceCell?.getData()?.id || !targetCell?.getData()?.id ||
|
||||||
sourceCell?.getData()?.type === 'add-node' || targetCell?.getData()?.type === 'add-node') {
|
sourceCell?.getData()?.type === 'add-node' || targetCell?.getData()?.type === 'add-node') {
|
||||||
|
|||||||
Reference in New Issue
Block a user