Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

This commit is contained in:
yujiangping
2026-01-06 19:27:09 +08:00
61 changed files with 2052 additions and 1375 deletions

View File

@@ -728,9 +728,23 @@ async def draft_run_compare(
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
# 获取 agent_cfg.model_parameters如果是 ModelParameters 对象则转为字典
agent_model_params = agent_cfg.model_parameters
if hasattr(agent_model_params, 'model_dump'):
agent_model_params = agent_model_params.model_dump()
elif not isinstance(agent_model_params, dict):
agent_model_params = {}
# 获取 model_item.model_parameters如果是 ModelParameters 对象则转为字典
item_model_params = model_item.model_parameters
if hasattr(item_model_params, 'model_dump'):
item_model_params = item_model_params.model_dump()
elif not isinstance(item_model_params, dict):
item_model_params = {}
merged_parameters = {
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
**(agent_model_params or {}),
**(item_model_params or {})
}
model_configs.append({

View File

@@ -108,16 +108,23 @@ async def get_prompt_opt(
service = PromptOptimizerService(db)
async def event_generator():
async for chunk in service.optimize_prompt(
tenant_id=current_user.tenant_id,
model_id=data.model_id,
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
user_require=data.message
):
# chunk 是 prompt 的增量内容
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n"
yield "event:start\ndata: {}\n\n"
try:
async for chunk in service.optimize_prompt(
tenant_id=current_user.tenant_id,
model_id=data.model_id,
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
user_require=data.message
):
# chunk 是 prompt 的增量内容
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
except Exception as e:
yield f"event:error\ndata: {json.dumps(
{"error": str(e)}
)}\n\n"
yield "event:end\ndata: {}\n\n"
return StreamingResponse(
event_generator(),

View File

@@ -1,4 +1,5 @@
import hashlib
import json
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request
@@ -18,7 +19,7 @@ from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
router = APIRouter(prefix="/public/share", tags=["Public Share"])
logger = get_business_logger()
@@ -288,7 +289,7 @@ async def chat(
password = None # Token 认证不需要密码
# end_user_id = user_id
other_id = user_id
# 提前验证和准备(在流式响应开始前完成)
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
from app.models.app_model import AppType
@@ -364,6 +365,9 @@ async def chat(
config = release.config or {}
if not config.get("sub_agents"):
raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING)
elif app_type == AppType.WORKFLOW:
# Multi-Agent 类型:验证多 Agent 配置
pass
else:
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
@@ -392,6 +396,8 @@ async def chat(
if app_type == AppType.AGENT:
# 流式返回
agent_config = agent_config_4_app_release(release)
if payload.stream:
# async def event_generator():
# async for event in service.chat_stream(
@@ -424,7 +430,7 @@ async def chat(
user_id= str(new_end_user.id), # 转换为字符串
variables=payload.variables,
web_search=payload.web_search,
config=payload.agent_config,
config=agent_config,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
@@ -467,6 +473,7 @@ async def chat(
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
# config = workflow_config_4_app_release(release)
config = multi_agent_config_4_app_release(release)
if payload.stream:
async def event_generator():
@@ -551,8 +558,71 @@ async def chat(
# )
# return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.WORKFLOW:
config = workflow_config_4_app_release(release)
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=new_end_user.id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id
):
event_type = event.get("event", "message")
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 多 Agent 非流式返回
result = await app_chat_service.workflow_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=new_end_user.id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id
)
logger.debug(
"工作流试运行返回结果",
extra={
"result_type": str(type(result)),
"has_response": "response" in result if isinstance(result, dict) else False
}
)
return success(
data=result,
msg="工作流任务执行成功"
)
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
pass

View File

@@ -1,4 +1,5 @@
"""App 服务接口 - 基于 API Key 认证"""
import json
from typing import Annotated
from fastapi import APIRouter, Depends, Request, Body
@@ -21,7 +22,7 @@ from app.schemas.api_key_schema import ApiKeyAuth
from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.conversation_service import ConversationService, get_conversation_service
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
from app.services.app_service import get_app_service, AppService
router = APIRouter(prefix="/app", tags=["V1 - App API"])
@@ -226,22 +227,29 @@ async def chat(
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.WORKFLOW:
# 多 Agent 流式返回
config = dict_to_workflow_config(app.current_release.config,app.id)
config = workflow_config_4_app_release(app.current_release)
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
user_id=new_end_user.id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=app.app_id,
workspace_id=workspace_id
):
yield event
event_type = event.get("event", "message")
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
return StreamingResponse(
event_generator(),
@@ -253,21 +261,32 @@ async def chat(
}
)
# 非流式返回
# 多 Agent 非流式返回
result = await app_chat_service.workflow_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
user_id=new_end_user.id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=web_search,
memory=memory,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
app_id=app.app_id,
workspace_id=workspace_id
)
logger.debug(
"工作流试运行返回结果",
extra={
"result_type": str(type(result)),
"has_response": "response" in result if isinstance(result, dict) else False
}
)
return success(
data=result,
msg="工作流任务执行成功"
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode

View File

@@ -11,9 +11,9 @@ from app.db import get_db
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.core.api_key_utils import timestamp_to_datetime
from app.services.user_memory_service import (
UserMemoryService,
analytics_node_statistics,
analytics_memory_types,
analytics_graph_data,
)
@@ -41,24 +41,27 @@ router = APIRouter(
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
async def get_memory_insight_report_api(
end_user_id: str, # 使用 end_user_id
end_user_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""获取缓存的记忆洞察报告"""
api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}")
) -> dict:
"""
获取缓存的记忆洞察报告
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
如需生成新的洞察报告,请使用专门的生成接口。
"""
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# 调用服务层获取缓存数据
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
if result["is_cached"]:
# 缓存存在,返回缓存数据
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
return success(data=result, msg="查询成功")
else:
# 缓存不存在,返回提示消息
api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}")
return success(data=result, msg="查询成功")
return success(data=result, msg="数据尚未生成")
except Exception as e:
api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e))
@@ -66,24 +69,27 @@ async def get_memory_insight_report_api(
@router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api(
end_user_id: str, # 使用 end_user_id
end_user_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""获取缓存的用户摘要"""
api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, user={current_user.username}")
) -> dict:
"""
获取缓存的用户摘要
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
如需生成新的用户摘要,请使用专门的生成接口。
"""
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# 调用服务层获取缓存数据
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
if result["is_cached"]:
# 缓存存在,返回缓存数据
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
return success(data=result, msg="查询成功")
else:
# 缓存不存在,返回提示消息
api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}")
return success(data=result, msg="查询成功")
return success(data=result, msg="数据尚未生成")
except Exception as e:
api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e))
@@ -351,7 +357,7 @@ async def update_end_user_profile(
if 'hire_date' in update_data:
hire_date_timestamp = update_data['hire_date']
if hire_date_timestamp is not None:
update_data['hire_date'] = UserMemoryService.timestamp_to_datetime(hire_date_timestamp)
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
# 如果是 None保持 None允许清空
for field, value in update_data.items():

View File

@@ -5,19 +5,16 @@ This module provides analytics and insights for the memory system.
Available functions:
- get_hot_memory_tags: Get hot memory tags by frequency
- MemoryInsight: Generate memory insight reports
- get_recent_activity_stats: Get recent activity statistics
- generate_user_summary: Generate user summary
Note: MemoryInsight and generate_user_summary have been moved to
app.services.user_memory_service for better architecture.
"""
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.memory_insight import MemoryInsight
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.core.memory.analytics.user_summary import generate_user_summary
__all__ = [
"get_hot_memory_tags",
"MemoryInsight",
"get_recent_activity_stats",
"generate_user_summary",
]

View File

@@ -1,327 +0,0 @@
"""
This module provides the MemoryInsight class for analyzing user memory data.
MemoryInsight 是一个工具类,提供基础的数据获取和分析功能:
- get_domain_distribution(): 获取记忆领域分布
- get_active_periods(): 获取活跃时段
- get_social_connections(): 获取社交关联
业务逻辑如生成洞察报告应该在服务层user_memory_service.py中实现。
This script can be executed directly to test the memory insight generation for a test user.
"""
import asyncio
import json
import os
import sys
from collections import Counter
from datetime import datetime
# To run this script directly, we need to add the src directory to the Python path
# to resolve the inconsistent imports in other modules.
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if src_path not in sys.path:
sys.path.insert(0, src_path)
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from pydantic import BaseModel, Field
#TODO: Fix this
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
# 定义用于LLM结构化输出的Pydantic模型
class TagClassification(BaseModel):
"""
Represents the classification of a tag into a specific domain.
"""
domain: str = Field(
...,
description="The domain the tag belongs to, chosen from the predefined list.",
examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"],
)
class InsightReport(BaseModel):
"""
Represents the final insight report generated by the LLM.
"""
report: str = Field(
...,
description="A comprehensive insight report in Chinese, summarizing the user's memory patterns.",
)
class MemoryInsight:
"""
Provides insights into user memories by analyzing various aspects of their data.
"""
def __init__(self, user_id: str):
self.user_id = user_id
self.neo4j_connector = Neo4jConnector()
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
async def close(self):
"""关闭数据库连接。"""
await self.neo4j_connector.close()
async def get_domain_distribution(self) -> dict[str, float]:
"""
Calculates the distribution of memory domains based on hot tags.
"""
hot_tags = await get_hot_memory_tags(self.user_id)
if not hot_tags:
return {}
domain_counts = Counter()
for tag, _ in hot_tags:
prompt = f"""请将以下标签归类到最合适的领域中。
可选领域及其关键词:
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
- 其他:确实无法归入以上任何类别的内容
标签: {tag}
分析步骤:
1. 仔细理解标签的核心含义和使用场景
2. 对比各个领域的关键词,找到最匹配的领域
3. 特别注意:
- 历史、科学、文化等知识性内容应归类为"学习"
- 学校、课程、考试等正式教育场景应归类为"教育"
- 只有在标签完全不属于上述9个具体领域时才选择"其他"
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
请直接返回最合适的领域名称。"""
messages = [
{"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
{"role": "user", "content": prompt}
]
# 直接调用并等待结果
classification = await self.llm_client.response_structured(
messages=messages,
response_model=TagClassification,
)
if classification and hasattr(classification, 'domain') and classification.domain:
domain_counts[classification.domain] += 1
total_tags = sum(domain_counts.values())
if total_tags == 0:
return {}
domain_distribution = {
domain: count / total_tags for domain, count in domain_counts.items()
}
return dict(
sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True)
)
async def get_active_periods(self) -> list[int]:
"""
Identifies the top 2 most active months for the user.
Only returns months if there is valid and diverse time data.
This method checks if the time data represents real user memory timestamps
rather than auto-generated system timestamps by verifying:
1. Time data exists and is parseable
2. Time data is distributed across multiple months (not concentrated in 1-2 months)
"""
query = f"""
MATCH (d:Dialogue)
WHERE d.group_id = '{self.user_id}' AND d.created_at IS NOT NULL AND d.created_at <> ''
RETURN d.created_at AS creation_time
"""
records = await self.neo4j_connector.execute_query(query)
if not records:
return []
month_counts = Counter()
valid_dates_count = 0
for record in records:
creation_time_str = record.get("creation_time")
if not creation_time_str:
continue
try:
# 尝试解析时间字符串
dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
month_counts[dt_object.month] += 1
valid_dates_count += 1
except (ValueError, TypeError, AttributeError):
# 如果解析失败,跳过这条记录
continue
# 如果没有有效的时间数据,返回空列表
if not month_counts or valid_dates_count == 0:
return []
# 检查时间分布是否过于集中(可能是批量导入的数据)
# 如果超过80%的数据集中在1-2个月认为这是系统时间戳而非真实时间
unique_months = len(month_counts)
if unique_months <= 2:
# 只有1-2个月有数据很可能是批量导入
most_common_count = month_counts.most_common(1)[0][1]
if most_common_count / valid_dates_count > 0.8:
# 超过80%集中在一个月,认为是系统时间戳
return []
# 如果时间分布较为分散3个月以上认为是真实时间数据
if unique_months >= 3:
most_common_months = month_counts.most_common(2)
return [month for month, _ in most_common_months]
# 2个月的情况检查是否分布均匀
if unique_months == 2:
counts = list(month_counts.values())
# 如果两个月的数据量相差不大比例在0.3-3之间认为是真实数据
ratio = min(counts) / max(counts)
if ratio > 0.3:
most_common_months = month_counts.most_common(2)
return [month for month, _ in most_common_months]
# 其他情况返回空列表
return []
async def get_social_connections(self) -> dict | None:
"""
Finds the user with whom the most memories are shared.
使用 Chunk-Statement 的 CONTAINS 关系,因为系统中不创建 Dialogue-Statement 的 MENTIONS 关系。
"""
# 通过 Chunk 和 Statement 的 CONTAINS 关系来查找共同记忆
query = f"""
MATCH (c1:Chunk {{group_id: '{self.user_id}'}})
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
WHERE common_statements > 0
RETURN other_user_id, common_statements
ORDER BY common_statements DESC
LIMIT 1
"""
records = await self.neo4j_connector.execute_query(query)
if not records or not records[0].get("other_user_id"):
return None
most_connected_user = records[0]["other_user_id"]
common_memories_count = records[0]["common_statements"]
# 使用 Chunk 的时间范围
time_range_query = f"""
MATCH (c:Chunk)
WHERE c.group_id IN ['{self.user_id}', '{most_connected_user}']
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
"""
time_records = await self.neo4j_connector.execute_query(time_range_query)
start_year, end_year = "N/A", "N/A"
if time_records and time_records[0]["start_time"]:
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
return {
"user_id": most_connected_user,
"common_memories_count": common_memories_count,
"time_range": f"{start_year}-{end_year}",
}
async def close(self):
"""
Closes the database connection.
"""
await self.neo4j_connector.close()
async def main():
"""
Initializes and runs the memory insight analysis for a test user.
"""
# 默认从环境变量读取
test_user_id = DEFAULT_GROUP_ID
print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n")
try:
# 使用服务层函数生成报告
from app.services.user_memory_service import analytics_memory_insight_report
result = await analytics_memory_insight_report(end_user_id=test_user_id)
report = result.get("report", "")
print("--- 记忆洞察报告 ---")
print(report)
print("---------------------")
# 将结果写入统一的 User-Dashboard.json使用全局配置路径
try:
from app.core.config import settings
settings.ensure_memory_output_dir()
output_dir = settings.MEMORY_OUTPUT_DIR
try:
os.makedirs(output_dir, exist_ok=True)
except Exception:
pass
dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
existing = {}
if os.path.exists(dashboard_path):
with open(dashboard_path, "r", encoding="utf-8") as rf:
existing = json.load(rf)
existing["memory_insight"] = {
"group_id": test_user_id,
"report": report
}
with open(dashboard_path, "w", encoding="utf-8") as wf:
json.dump(existing, wf, ensure_ascii=False, indent=2)
print(f"已写入 {dashboard_path} -> memory_insight")
except Exception as e:
print(f"写入 User-Dashboard.json 失败: {e}")
except Exception as e:
print(f"生成报告时出错: {e}")
if __name__ == "__main__":
# This setup allows running the async main function
if sys.platform.startswith('win') and sys.version_info >= (3, 8):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
asyncio.run(main())

View File

@@ -1,157 +0,0 @@
"""
Generate a concise "关于我" style user summary using data from Neo4j
and the existing LLM configuration (mirrors hot_memory_tags.py setup).
Usage:
python -m analytics.user_summary --user_id <group_id>
"""
import asyncio
import json
import os
import sys
from dataclasses import dataclass
from typing import List, Tuple
# Ensure absolute imports work whether executed directly or via module
try:
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
src_path = os.path.join(project_root, 'src')
if src_path not in sys.path:
sys.path.insert(0, src_path)
if project_root not in sys.path:
sys.path.insert(0, project_root)
except Exception:
pass
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
#TODO: Fix this
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
@dataclass
class StatementRecord:
statement: str
created_at: str | None
class UserSummary:
"""Builds a textual user summary for a given user/group id."""
def __init__(self, user_id: str):
self.user_id = user_id
self.connector = Neo4jConnector()
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
async def close(self):
await self.connector.close()
async def _get_recent_statements(self, limit: int = 80) -> List[StatementRecord]: # TODO Used by user_memory_service
"""Fetch recent statements authored by the user/group for context."""
query = (
"MATCH (s:Statement) "
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
"RETURN s.statement AS statement, s.created_at AS created_at "
"ORDER BY created_at DESC LIMIT $limit"
)
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
records: List[StatementRecord] = []
for r in rows:
try:
records.append(StatementRecord(statement=r.get("statement", ""), created_at=r.get("created_at")))
except Exception:
continue
return records
async def _get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
"""Reuse hot tag logic to get meaningful entities and their frequencies."""
# get_hot_memory_tags internally filters out non-meaningful nouns with LLM
return await get_hot_memory_tags(self.user_id, limit=limit) # TODO Used by user_memory_service
async def generate_user_summary(user_id: str | None = None) -> str: # TODO useless
"""
生成用户摘要的便捷函数
Args:
user_id: 可选的用户ID
Returns:
用户摘要字符串
"""
# 导入服务层函数
from app.services.user_memory_service import analytics_user_summary
# 调用服务层函数
result = await analytics_user_summary(user_id)
return result.get("summary", "")
if __name__ == "__main__":
print("开始生成用户摘要…")
try:
# 直接使用 runtime.json 中的 group_id
summary = asyncio.run(generate_user_summary())
print("\n— 用户摘要 —\n")
print(summary)
# 将结果写入统一的 User-Dashboard.json
try:
from app.core.config import settings
settings.ensure_memory_output_dir()
output_dir = settings.MEMORY_OUTPUT_DIR
try:
os.makedirs(output_dir, exist_ok=True)
except Exception:
pass
dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
existing = {}
if os.path.exists(dashboard_path):
with open(dashboard_path, "r", encoding="utf-8") as rf:
existing = json.load(rf)
existing["user_summary"] = {
"group_id": DEFAULT_GROUP_ID,
"summary": summary
}
with open(dashboard_path, "w", encoding="utf-8") as wf:
json.dump(existing, wf, ensure_ascii=False, indent=2)
print(f"已写入 {dashboard_path} -> user_summary")
except Exception as e:
print(f"写入 User-Dashboard.json 失败: {e}")
except Exception as e:
print(f"生成摘要失败: {e}")
print("请检查: 1) Neo4j 是否可用2) config.json 与 .env 的 LLM/Neo4j 配置是否正确3) 数据是否包含该用户的内容。")

View File

@@ -85,33 +85,21 @@ Example Output:
===End of Example===
===Reflection Process===
===Internal Quality Checks (DO NOT OUTPUT)===
After generating the profile, perform the following self-review steps:
Before generating your final output, internally verify:
1. All content is grounded in provided data (no fabrication)
2. Format follows the specified structure with correct headers
3. Tone is objective, third-person, and neutral
4. All four sections are complete and within character limits
**Step 1: Data Grounding Check**
- Verify all statements are supported by the provided entities and statements
- Ensure no fabricated or speculated information is included
- Confirm all claims can be traced back to the input data
**Step 2: Format Compliance**
- Verify each section follows the specified format with section headers
- Check character count limits for each section
- Ensure proper use of section markers (【】)
**Step 3: Tone and Style Review**
- Confirm objective third-person perspective is maintained
- Check for excessive adjectives or empty phrases
- Verify neutral and restrained tone throughout
**Step 4: Completeness Check**
- Ensure all four sections are present and complete
- Verify each section addresses its specific focus area
- Confirm the one-sentence summary effectively captures the user's essence
**IMPORTANT: These checks are for your internal use only. DO NOT include them in your output.**
===Output Requirements===
**CRITICAL: Your response must ONLY contain the four sections below. Do not include any reflection, self-review, or meta-commentary.**
**LANGUAGE REQUIREMENT:**
- The output language should ALWAYS be Chinese (Simplified)
- All section content must be in Chinese
@@ -122,3 +110,5 @@ After generating the profile, perform the following self-review steps:
- Content follows immediately after the header
- Sections are separated by blank lines
- Strictly adhere to character limits for each section
- **DO NOT include any text after the 【一句话总结】 section**
- **DO NOT output reflection steps, self-review, or verification notes**

View File

@@ -29,7 +29,7 @@ class WorkflowState(TypedDict):
# Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list
looping: bool
looping: Annotated[bool, lambda x, y: x and y]
# Input variables (passed from configured variables)
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)

View File

@@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode):
retries -= 1
if retries > 0:
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
elif self.typed_config.error_handle.method == HttpErrorHandle.NONE:
raise e
except Exception as e:
raise RuntimeError(f"HTTP request node exception: {e}")
else:
match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning error response"
)
return HttpRequestNodeOutput(
body="",
status_code=resp.status_code,
headers=resp.headers,
).model_dump()
case HttpErrorHandle.DEFAULT:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning default result"
@@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode):
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
)
return "ERROR"
raise RuntimeError("http request failed")

View File

@@ -203,15 +203,16 @@ class KnowledgeRetrievalNode(BaseNode):
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold)
# Deduplicate hybrid retrieval results
# Deduplicate hy brid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
vector_service.reranker = self.get_reranker_model()
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
case _:
raise RuntimeError("Unknown retrieval type")
vector_service.reranker = self.get_reranker_model()
# TODO其他重排序方式支持
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
logger.info(
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
)
return [chunk.model_dump() for chunk in final_rs]
return [chunk.page_content for chunk in final_rs]

View File

@@ -1,5 +1,7 @@
"""LLM 节点配置"""
from typing import Any
from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
@@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti
class MessageConfig(BaseModel):
"""消息配置"""
role: str = Field(
...,
description="消息角色system, user, assistant"
)
content: str = Field(
...,
description="消息内容,支持模板变量,如:{{ sys.message }}"
)
@field_validator("role")
@classmethod
def validate_role(cls, v: str) -> str:
@@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig):
1. 简单模式:使用 prompt 字段
2. 消息模式:使用 messages 字段(推荐)
"""
model_id: str = Field(
...,
description="模型配置 ID"
)
context: Any = Field(
default="",
description="上下文"
)
# 简单模式
prompt: str | None = Field(
default=None,
description="提示词模板(简单模式),支持变量引用"
)
# 消息模式(推荐)
messages: list[MessageConfig] | None = Field(
default=None,
description="消息列表(消息模式),支持多轮对话"
)
# 模型参数
temperature: float | None = Field(
default=0.7,
@@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig):
le=2.0,
description="温度参数,控制输出的随机性"
)
max_tokens: int | None = Field(
default=1000,
ge=1,
le=32000,
description="最大生成 token 数"
)
top_p: float | None = Field(
default=None,
ge=0.0,
le=1.0,
description="Top-p 采样参数"
)
frequency_penalty: float | None = Field(
default=None,
ge=-2.0,
le=2.0,
description="频率惩罚"
)
presence_penalty: float | None = Field(
default=None,
ge=-2.0,
le=2.0,
description="存在惩罚"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
@@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig):
],
description="输出变量定义(自动生成,通常不需要修改)"
)
@field_validator("messages", "prompt")
@classmethod
def validate_input_mode(cls, v, info):
"""验证输入模式prompt 和 messages 至少有一个"""
# 这个验证在 model_validator 中更合适
return v
class Config:
json_schema_extra = {
"examples": [

View File

@@ -5,15 +5,17 @@ LLM 节点实现
"""
import logging
import re
from typing import Any
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.db import get_db_context
from app.models import ModelType
from app.services.model_service import ModelConfigService
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
@@ -63,8 +65,15 @@ class LLMNode(BaseNode):
- user/human: 用户消息HumanMessage
- ai/assistant: AI 消息AIMessage
"""
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = LLMNodeConfig(**self.config)
def _render_context(self, message,state):
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
return re.sub(r"{{context}}", context, message)
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑)
Args:
@@ -76,15 +85,16 @@ class LLMNode(BaseNode):
# 1. 处理消息格式(优先使用 messages
messages_config = self.config.get("messages")
if messages_config:
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.get("role", "user").lower()
content_template = msg_config.get("content", "")
content_template = self._render_context(content_template, state)
content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象
if role == "system":
messages.append(SystemMessage(content=content))
@@ -95,7 +105,7 @@ class LLMNode(BaseNode):
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append(HumanMessage(content=content))
prompt_or_messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
@@ -106,17 +116,17 @@ class LLMNode(BaseNode):
model_id = self.config.get("model_id")
if not model_id:
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
# 3. 在 with 块内完成所有数据库操作和数据提取
with get_db_context() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
# 在 Session 关闭前提取所有需要的数据
api_config = config.api_keys[0]
model_name = api_config.model_name
@@ -124,26 +134,26 @@ class LLMNode(BaseNode):
api_key = api_config.api_key
api_base = api_config.api_base
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
extra_params = {"streaming": stream} if stream else {}
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
provider=provider,
api_key=api_key,
base_url=api_base,
extra_params=extra_params
),
),
type=ModelType(model_type)
)
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
"""非流式执行 LLM 调用
@@ -153,10 +163,10 @@ class LLMNode(BaseNode):
Returns:
LLM 响应消息
"""
llm, prompt_or_messages = self._prepare_llm(state,True)
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages)
# 提取内容
@@ -164,16 +174,16 @@ class LLMNode(BaseNode):
content = response.content
else:
content = str(response)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据
return response if isinstance(response, AIMessage) else AIMessage(content=content)
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)"""
_, prompt_or_messages = self._prepare_llm(state)
return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"messages": [
@@ -186,13 +196,13 @@ class LLMNode(BaseNode):
"max_tokens": self.config.get("max_tokens")
}
}
def _extract_output(self, business_result: Any) -> str:
"""从 AIMessage 中提取文本内容"""
if isinstance(business_result, AIMessage):
return business_result.content
return str(business_result)
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从 AIMessage 中提取 token 使用情况"""
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
@@ -204,7 +214,7 @@ class LLMNode(BaseNode):
"total_tokens": usage.get('total_tokens', 0)
}
return None
async def execute_stream(self, state: WorkflowState):
"""流式执行 LLM 调用
@@ -215,26 +225,26 @@ class LLMNode(BaseNode):
文本片段chunk或完成标记
"""
from langgraph.config import get_stream_writer
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 检查是否有注入的 End 节点前缀配置
writer = get_stream_writer()
end_prefix = getattr(self, '_end_node_prefix', None)
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
if end_prefix:
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
if end_prefix:
# 渲染前缀(可能包含其他变量)
try:
rendered_prefix = self._render_template(end_prefix, state)
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
# 提前发送 End 节点的前缀(使用 "message" 类型)
writer({
"type": "message", # End 相关的内容都是 message 类型
@@ -246,12 +256,12 @@ class LLMNode(BaseNode):
})
except Exception as e:
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
# 累积完整响应
full_response = ""
last_chunk = None
chunk_count = 0
# 调用 LLM流式支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages):
# 提取内容
@@ -259,18 +269,18 @@ class LLMNode(BaseNode):
content = chunk.content
else:
content = str(chunk)
# 只有当内容不为空时才处理
if content:
full_response += content
last_chunk = chunk
chunk_count += 1
# 流式返回每个文本片段
yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage):
final_message = AIMessage(
@@ -279,6 +289,6 @@ class LLMNode(BaseNode):
)
else:
final_message = AIMessage(content=full_response)
# yield 完成标记
yield {"__final__": True, "result": final_message}

View File

@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
return await MemoryAgentService().read_memory(
group_id=end_user_id,
message=self.typed_config.message,
message=self._render_template(self.typed_config.message, state),
config_id=self.typed_config.config_id,
search_switch=self.typed_config.search_switch,
history=[],
@@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode):
return await MemoryAgentService().write_memory(
group_id=end_user_id,
message=self.typed_config.message,
message=self._render_template(self.typed_config.message, state),
config_id=self.typed_config.config_id,
db=db,
storage_type="neo4j",

View File

@@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode):
category_map[category_name] = case_tag
return category_map
async def execute(self, state: WorkflowState) -> str:
async def execute(self, state: WorkflowState) -> dict:
"""执行问题分类"""
question = self.typed_config.input_variable
supplement_prompt = self.typed_config.user_supplement_prompt or ""
@@ -79,7 +79,15 @@ class QuestionClassifierNode(BaseNode):
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count}"
)
# 若分类列表为空返回默认unknown分支否则返回CASE1
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
if category_count > 0:
return {
"class_name": category_names[0],
"output": DEFAULT_EMPTY_QUESTION_CASE
}
return {
"class_name": "unknown",
"output": DEFAULT_EMPTY_QUESTION_CASE
}
try:
llm = self._get_llm_instance()
@@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode):
log_supplement = supplement_prompt if supplement_prompt else ""
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
return f"CASE{category_names.index(category) + 1}"
return {
"class_name": category,
"output": f"CASE{category_names.index(category) + 1}",
}
except Exception as e:
logger.error(
f"节点 {self.node_id} 分类执行异常:{str(e)}",
@@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode):
)
# 异常时返回默认分支,保证工作流容错性
if category_count > 0:
return DEFAULT_EMPTY_QUESTION_CASE
return "unknown"
return {
"class_name": category_names[0],
"output": DEFAULT_EMPTY_QUESTION_CASE
}
return {
"class_name": "unknown",
"output": DEFAULT_EMPTY_QUESTION_CASE
}

View File

@@ -1,4 +1,6 @@
from pydantic import Field
from typing import Any
from app.core.workflow.nodes.base_config import BaseNodeConfig
@@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig):
"""工具节点配置"""
tool_id: str = Field(..., description="工具ID")
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")

View File

@@ -1,5 +1,5 @@
import logging
import uuid
import re
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
@@ -9,6 +9,8 @@ from app.db import get_db_read
logger = logging.getLogger(__name__)
TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
class ToolNode(BaseNode):
"""工具节点"""
@@ -25,25 +27,33 @@ class ToolNode(BaseNode):
# 如果没有租户ID尝试从工作流ID获取
if not tenant_id:
workflow_id = self.get_variable("sys.workflow_id", state)
if workflow_id:
workspace_id = self.get_variable("sys.workspace_id", state)
if workspace_id:
from app.repositories.tool_repository import ToolRepository
with get_db_read() as db:
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
if not tenant_id:
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
# logger.error(f"节点 {self.node_id} 缺少租户ID")
# return {"error": "缺少租户ID"}
logger.error(f"节点 {self.node_id} 缺少租户ID")
return {
"success": False,
"data": "缺少租户ID"
}
# 渲染工具参数
rendered_parameters = {}
for param_name, param_template in self.typed_config.tool_parameters.items():
rendered_value = self._render_template(param_template, state)
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
try:
rendered_value = self._render_template(param_template, state)
except Exception as e:
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
else:
# 非模板参数(数字/布尔/普通字符串)直接保留原值
rendered_value = param_template
rendered_parameters[param_name] = rendered_value
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
print(self.typed_config.tool_id)
# 执行工具
with get_db_read() as db:
@@ -54,7 +64,7 @@ class ToolNode(BaseNode):
tenant_id=tenant_id,
user_id=user_id
)
print(result)
if result.success:
logger.info(f"节点 {self.node_id} 工具执行成功")
return {
@@ -66,7 +76,7 @@ class ToolNode(BaseNode):
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
return {
"success": False,
"error": result.error,
"data": result.error,
"error_code": result.error_code,
"execution_time": result.execution_time
}

View File

@@ -87,10 +87,11 @@ class WorkflowValidator:
return graphs
@classmethod
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]:
"""验证工作流配置
Args:
publish: 发布验证标识
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
Returns:
@@ -114,7 +115,7 @@ class WorkflowValidator:
graphs = cls.get_subgraph(workflow_config)
logger.info(graphs)
for graph in graphs:
for index, graph in enumerate(graphs):
nodes = graph.get("nodes", [])
edges = graph.get("edges", [])
variables = graph.get("variables", [])
@@ -125,10 +126,11 @@ class WorkflowValidator:
elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个)
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
if index == len(graphs) - 1:
# 2. 验证 主图end 节点(至少一个)
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes]
@@ -159,15 +161,17 @@ class WorkflowValidator:
elif target not in node_id_set:
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
if publish:
# 仅在发布时验证所有节点可达
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
@@ -288,7 +292,7 @@ class WorkflowValidator:
(is_valid, errors): 是否有效和错误列表
"""
# 先执行基础验证
is_valid, errors = WorkflowValidator.validate(workflow_config)
is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True)
if not is_valid:
return False, errors

View File

@@ -4,6 +4,8 @@ from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
from sqlalchemy.dialects.postgresql import UUID, JSON
from sqlalchemy.orm import relationship
from app.db import Base
from app.models.multi_agent_model import PydanticType
from app.schemas import ModelParameters
class AgentConfig(Base):
@@ -17,14 +19,17 @@ class AgentConfig(Base):
# Agent 行为配置
system_prompt = Column(Text, nullable=True, comment="系统提示词")
default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True, comment="默认模型配置ID")
# 结构化配置(直接存储 JSON
model_parameters = Column(JSON, nullable=True, comment="模型参数配置temperature、max_tokens等")
# model_parameters = Column(JSON, nullable=True, comment="模型参数配置temperature、max_tokens等")
model_parameters = Column(PydanticType(ModelParameters), nullable=True,
comment="模型参数配置temperature、max_tokens等")
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
memory = Column(JSON, nullable=True, comment="记忆配置")
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
# 多 Agent 相关字段
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
agent_domain = Column(String(50), comment="专业领域: customer_service|technical_support|sales 等")
@@ -41,4 +46,4 @@ class AgentConfig(Base):
parent_agent = relationship("AgentConfig", remote_side=[id], backref="sub_agents")
def __repr__(self):
return f"<AgentConfig(id={self.id}, app_id={self.app_id})>"
return f"<AgentConfig(id={self.id}, app_id={self.app_id})>"

View File

@@ -38,6 +38,33 @@ class ToolRepository:
return result[0] if result else None
@staticmethod
def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]:
"""
根据空间ID获取tenant_id
Args:
db: 数据库会话
workspace_id: 空间ID
Returns:
tenant_id或None
"""
from app.models.workspace_model import Workspace
tenant_id = db.query(Workspace.tenant_id).filter(
Workspace.id == workspace_id
).scalar()
if tenant_id is not None and not isinstance(tenant_id, uuid.UUID):
# 兼容数据库中字段类型不匹配的情况(比如存储为字符串)
try:
tenant_id = uuid.UUID(tenant_id)
except (ValueError, TypeError):
return None
return tenant_id
@staticmethod
def find_by_tenant(
db: Session,

View File

@@ -86,7 +86,12 @@ class AgentConfigConverter:
# 1. 解析模型参数配置
if model_parameters:
from app.schemas.app_schema import ModelParameters
result["model_parameters"] = ModelParameters(**model_parameters)
if isinstance(model_parameters, ModelParameters):
result["model_parameters"] = model_parameters
elif isinstance(model_parameters, dict):
result["model_parameters"] = ModelParameters(**model_parameters)
else:
result["model_parameters"] = ModelParameters()
# 2. 解析知识库检索配置
if knowledge_retrieval:

View File

@@ -9,15 +9,18 @@ from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.agent.langchain_agent import LangChainAgent
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig
from app.db import get_db, get_db_context
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.workflow_service import WorkflowService
logger = get_business_logger()
@@ -184,7 +187,7 @@ class AppChatService:
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.get("system_prompt", "")
system_prompt = config.system_prompt
if variables:
system_prompt_rendered = render_prompt_message(
system_prompt,
@@ -197,7 +200,7 @@ class AppChatService:
tools = []
# 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval")
knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
@@ -208,13 +211,13 @@ class AppChatService:
# 添加长期记忆工具
memory_flag = False
if memory:
memory_config = config.get("memory", {})
memory_config = config.memory
if memory_config.get("enabled") and user_id:
memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools = config.get("tools")
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
@@ -230,7 +233,7 @@ class AppChatService:
)
# 获取模型参数
model_parameters = config.get("model_parameters", {})
model_parameters = config.model_parameters
# 创建 LangChain Agent
agent = LangChainAgent(
@@ -479,7 +482,9 @@ class AppChatService:
self,
message: str,
conversation_id: uuid.UUID,
config: AgentConfig,
config: WorkflowConfig,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
web_search: bool = False,
@@ -488,281 +493,159 @@ class AppChatService:
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""聊天(非流式)"""
workflow_service = WorkflowService(self.db)
start_time = time.time()
config_id = None
input_data = {"message":message, "variables": variables,
"conversation_id": str(conversation_id)}
inconfig = workflow_service.get_workflow_config(app_id)
if variables is None:
variables = {}
# 2. 创建执行记录
execution = workflow_service.create_execution(
workflow_config_id=inconfig.id,
app_id=app_id,
trigger_type="manual",
triggered_by=None,
conversation_id=conversation_id,
input_data=input_data
)
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.get("system_prompt", "")
if variables:
system_prompt_rendered = render_prompt_message(
system_prompt,
PromptMessageRole.USER,
variables
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow
try:
# 更新状态为运行中
workflow_service.update_execution_status(execution.execution_id, "running")
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=user_id
)
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# 准备工具列表
tools = []
# 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval")
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool)
# 添加长期记忆工具
memory_flag = False
if memory == True:
memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id:
memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools = config.get("tools")
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
# 更新执行结果
if result.get("status") == "completed":
workflow_service.update_execution_status(
execution.execution_id,
"completed",
output_data=result.get("node_outputs", {})
)
else:
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=result.get("error")
)
# 获取模型参数
model_parameters = config.get("model_parameters", {})
# 返回增强的响应结构
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"),
"token_usage": result.get("token_usage")
}
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
)
# 加载历史消息
history = []
memory_config = {"enabled": True, 'max_history': 10}
if memory_config.get("enabled"):
messages = self.conversation_service.get_messages(
conversation_id=conversation_id,
limit=memory_config.get("max_history", 10)
except Exception as e:
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
raise BusinessException(
code=BizCode.INTERNAL_ERROR,
message=f"工作流执行失败: {str(e)}"
)
history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
# 调用 Agent
result = await agent.chat(
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
)
# 保存消息
self.conversation_service.save_conversation_messages(
conversation_id=conversation_id,
user_message=message,
assistant_message=result["content"]
)
elapsed_time = time.time() - start_time
return {
"conversation_id": conversation_id,
"message": result["content"],
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}),
"elapsed_time": elapsed_time
}
async def workflow_chat_stream(
self,
message: str,
conversation_id: uuid.UUID,
config: AgentConfig,
config: WorkflowConfig,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
workflow_service = WorkflowService(self.db)
input_data = {"message": message, "variables": variables,
"conversation_id": str(conversation_id)}
inconfig = workflow_service.get_workflow_config(app_id)
# 2. 创建执行记录
execution = workflow_service.create_execution(
workflow_config_id=inconfig.id,
app_id=app_id,
trigger_type="manual",
triggered_by=None,
conversation_id=conversation_id,
input_data=input_data
)
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
# 5. 流式执行工作流
try:
start_time = time.time()
config_id = None
# 更新状态为运行中
workflow_service.update_execution_status(execution.execution_id, "running")
if variables is None:
variables = {}
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.get("system_prompt", "")
if variables:
system_prompt_rendered = render_prompt_message(
system_prompt,
PromptMessageRole.USER,
variables
)
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# 准备工具列表
tools = []
# 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval")
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool)
# 添加长期记忆工具
memory_flag = False
if memory:
memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id:
memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools = config.get("tools")
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.get("model_parameters", {})
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True
)
# 加载历史消息
history = []
memory_config = {"enabled": True, 'max_history': 10}
if memory_config.get("enabled"):
messages = self.conversation_service.get_messages(
conversation_id=conversation_id,
limit=memory_config.get("max_history", 10)
)
history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
# 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
# 流式调用 Agent
full_content = ""
async for chunk in agent.chat_stream(
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件)
async for event in workflow_service._run_workflow_stream(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=user_id
):
full_content += chunk
# 发送消息块事件
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
# 直接转发 executor 的事件(已经是正确的格式)
yield event
elapsed_time = time.time() - start_time
# 保存消息
self.conversation_service.add_message(
conversation_id=conversation_id,
role="user",
content=message
)
self.conversation_service.add_message(
conversation_id=conversation_id,
role="assistant",
content=full_content,
meta_data={
"model": api_key_obj.model_name,
"usage": {}
}
)
# 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
logger.info(
"流式聊天完成",
extra={
"conversation_id": str(conversation_id),
"elapsed_time": elapsed_time,
"message_length": len(full_content)
}
)
except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出
logger.debug("流式聊天被中断")
raise
except Exception as e:
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
workflow_service.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
# 发送错误事件
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
yield {
"event": "error",
"data": {
"execution_id": execution.execution_id,
"error": str(e)
}
}
# ==================== 依赖注入函数 ====================
def get_app_chat_service(

View File

@@ -21,6 +21,7 @@ from app.core.exceptions import (
BusinessException,
)
from app.core.logging_config import get_business_logger
from app.core.workflow.validator import WorkflowValidator
from app.db import get_db
from app.models import App, AgentConfig, AppRelease, MultiAgentConfig, WorkflowConfig
from app.models.app_model import AppStatus, AppType
@@ -31,6 +32,7 @@ from app.schemas.workflow_schema import WorkflowConfigUpdate
from app.services.agent_config_converter import AgentConfigConverter
from app.models import AppShare, Workspace
from app.services.model_service import ModelApiKeyService
from app.services.workflow_service import WorkflowService
# 获取业务日志器
logger = get_business_logger()
@@ -1225,6 +1227,26 @@ class AppService:
"orchestration_mode": multi_agent_cfg.orchestration_mode
}
)
elif app.type == AppType.WORKFLOW:
service = WorkflowService(self.db)
workflow_cfg = service.get_workflow_config(app_id)
if not workflow_cfg:
raise BusinessException("应用缺少有效配置,无法发布", BizCode.CONFIG_MISSING)
config = {
"nodes": workflow_cfg.nodes,
"edges": workflow_cfg.edges,
"variables": workflow_cfg.variables,
"execution_config": workflow_cfg.execution_config,
"triggers": workflow_cfg.triggers
}
is_valid, errors = WorkflowValidator.validate_for_publish(config)
if not is_valid:
raise BusinessException("应用缺少有效配置,无法发布", BizCode.CONFIG_MISSING)
logger.info(
"应用发布配置准备完成"
)
now = datetime.datetime.now()
version = self._get_next_version(app_id)

View File

@@ -1293,6 +1293,7 @@ class MultiAgentOrchestrator:
conversation_id: 会话 ID
user_id: 用户 ID
Returns:
执行结果
"""

View File

@@ -231,9 +231,9 @@ class PromptOptimizerService:
if m:
prompt_index = m.start()
prompt_finished = True
yield {"type": "delta", "content": buffer[idx:prompt_index]}
yield {"content": buffer[idx:prompt_index]}
else:
yield {"type": "delta", "content": cache[idx:]}
yield {"content": cache[idx:]}
if len(cache) != 0:
idx = len(cache)
@@ -249,8 +249,8 @@ class PromptOptimizerService:
role=RoleType.ASSISTANT,
content=desc
)
yield {"type": "done", "desc": optim_result.get("desc")}
variables = self.parser_prompt_variables(optim_result.get("prompt"))
yield {"desc": optim_result.get("desc"), "variables": variables}
@staticmethod
def parser_prompt_variables(prompt: str):

View File

@@ -344,14 +344,16 @@ class ToolService:
break
if operation_param:
# 有多个操作
# 有多个操作,为每个操作生成具体参数
methods = []
for operation in operation_param.enum:
# 获取该操作的具体参数
operation_params = self._get_operation_specific_params(tool_instance, operation)
methods.append({
"method_id": f"{config.name}_{operation}",
"name": operation,
"description": f"{config.description} - {operation}",
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
"parameters": operation_params
})
return methods
else:
@@ -362,6 +364,243 @@ class ToolService:
"description": config.description,
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
}]
def _get_operation_specific_params(self, tool_instance: BaseTool, operation: str) -> List[Dict[str, Any]]:
"""获取特定操作的参数列表"""
# 对于datetime_tool根据操作类型返回相关参数
if hasattr(tool_instance, 'name') and tool_instance.name == 'datetime_tool':
return self._get_datetime_tool_params(operation)
# 对于json_tool根据操作类型返回相关参数
elif hasattr(tool_instance, 'name') and tool_instance.name == 'json_tool':
return self._get_json_tool_params(operation)
# 其他工具的默认处理返回除operation外的所有参数
return [{
"name": param.name,
"type": param.type.value,
"description": param.description,
"required": param.required,
"default": param.default,
"enum": param.enum,
"minimum": param.minimum,
"maximum": param.maximum,
"pattern": param.pattern
} for param in tool_instance.parameters if param.name != "operation"]
def _get_datetime_tool_params(self, operation: str) -> List[Dict[str, Any]]:
"""获取datetime_tool特定操作的参数"""
if operation == "now":
return [
{
"name": "to_timezone",
"type": "string",
"description": "目标时区UTC, Asia/Shanghai",
"required": False,
"default": "Asia/Shanghai"
},
{
"name": "output_format",
"type": "string",
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S",
"required": False,
"default": "%Y-%m-%d %H:%M:%S"
}
]
elif operation == "format":
return [
{
"name": "input_value",
"type": "string",
"description": "输入值(时间字符串或时间戳)",
"required": True
},
{
"name": "input_format",
"type": "string",
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S",
"required": False,
"default": "%Y-%m-%d %H:%M:%S"
},
{
"name": "output_format",
"type": "string",
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S",
"required": False,
"default": "%Y-%m-%d %H:%M:%S"
}
]
elif operation == "convert_timezone":
return [
{
"name": "input_value",
"type": "string",
"description": "输入值(时间字符串或时间戳)",
"required": True
},
{
"name": "input_format",
"type": "string",
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S",
"required": False,
"default": "%Y-%m-%d %H:%M:%S"
},
{
"name": "output_format",
"type": "string",
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S",
"required": False,
"default": "%Y-%m-%d %H:%M:%S"
},
{
"name": "from_timezone",
"type": "string",
"description": "源时区UTC, Asia/Shanghai",
"required": False,
"default": "Asia/Shanghai"
},
{
"name": "to_timezone",
"type": "string",
"description": "目标时区UTC, Asia/Shanghai",
"required": False,
"default": "Asia/Shanghai"
}
]
elif operation == "timestamp_to_datetime":
return [
{
"name": "input_value",
"type": "string",
"description": "输入值(时间字符串或时间戳)",
"required": True
},
{
"name": "output_format",
"type": "string",
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S",
"required": False,
"default": "%Y-%m-%d %H:%M:%S"
},
{
"name": "to_timezone",
"type": "string",
"description": "目标时区UTC, Asia/Shanghai",
"required": False,
"default": "Asia/Shanghai"
}
]
else:
# 默认返回所有参数除了operation
return [
{
"name": "input_value",
"type": "string",
"description": "输入值(时间字符串或时间戳)",
"required": False
},
{
"name": "input_format",
"type": "string",
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S",
"required": False,
"default": "%Y-%m-%d %H:%M:%S"
},
{
"name": "output_format",
"type": "string",
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S",
"required": False,
"default": "%Y-%m-%d %H:%M:%S"
},
{
"name": "from_timezone",
"type": "string",
"description": "源时区UTC, Asia/Shanghai",
"required": False,
"default": "Asia/Shanghai"
},
{
"name": "to_timezone",
"type": "string",
"description": "目标时区UTC, Asia/Shanghai",
"required": False,
"default": "Asia/Shanghai"
},
{
"name": "calculation",
"type": "string",
"description": "时间计算表达式(如:+1d, -2h, +30m",
"required": False
}
]
def _get_json_tool_params(self, operation: str) -> List[Dict[str, Any]]:
"""获取json_tool特定操作的参数"""
base_params = [
{
"name": "input_data",
"type": "string",
"description": "输入数据JSON字符串、YAML字符串或XML字符串",
"required": True
}
]
if operation == "insert":
return base_params + [
{
"name": "json_path",
"type": "string",
"description": "JSON路径表达式$.user.name或users[0].name",
"required": True
},
{
"name": "new_value",
"type": "string",
"description": "新值用于insert操作",
"required": True
}
]
elif operation == "replace":
return base_params + [
{
"name": "json_path",
"type": "string",
"description": "JSON路径表达式$.user.name或users[0].name",
"required": True
},
{
"name": "old_text",
"type": "string",
"description": "要替换的原文本用于replace操作",
"required": True
},
{
"name": "new_text",
"type": "string",
"description": "替换后的新文本用于replace操作",
"required": True
}
]
elif operation == "delete":
return base_params + [
{
"name": "json_path",
"type": "string",
"description": "JSON路径表达式$.user.name或users[0].name",
"required": True
}
]
elif operation == "parse":
return base_params + [
{
"name": "json_path",
"type": "string",
"description": "JSON路径表达式$.user.name或users[0].name",
"required": True
}
]
return base_params
async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
"""获取自定义工具的方法"""

View File

@@ -7,7 +7,6 @@ User Memory Service
import os
import uuid
from collections import Counter
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
@@ -22,7 +21,269 @@ from sqlalchemy.orm import Session
logger = get_logger(__name__)
# Neo4j connector instan
# Neo4j connector instance for analytics functions
_neo4j_connector = Neo4jConnector()
# Default LLM ID for fallback
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
# ============================================================================
# Internal Helper Classes
# ============================================================================
class TagClassification(BaseModel):
"""Represents the classification of a tag into a specific domain."""
domain: str = Field(
...,
description="The domain the tag belongs to, chosen from the predefined list.",
examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"],
)
def _get_llm_client_for_user(user_id: str):
"""
Get LLM client for a specific user based on their config.
Args:
user_id: User ID to get config for
Returns:
LLM client instance
"""
with get_db_context() as db:
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
return factory.get_llm_client(memory_config.llm_model_id)
else:
factory = MemoryClientFactory(db)
return factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
logger.warning(f"Failed to get user connected config, using default LLM: {e}")
factory = MemoryClientFactory(db)
return factory.get_llm_client(DEFAULT_LLM_ID)
class MemoryInsightHelper:
"""
Internal helper class for memory insight analysis.
Provides basic data retrieval and analysis functionality.
"""
def __init__(self, user_id: str):
self.user_id = user_id
self.neo4j_connector = Neo4jConnector()
self.llm_client = _get_llm_client_for_user(user_id)
async def close(self):
"""Close database connection."""
await self.neo4j_connector.close()
async def get_domain_distribution(self) -> dict[str, float]:
"""Calculate the distribution of memory domains based on hot tags."""
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
hot_tags = await get_hot_memory_tags(self.user_id)
if not hot_tags:
return {}
domain_counts = Counter()
for tag, _ in hot_tags:
prompt = f"""请将以下标签归类到最合适的领域中。
可选领域及其关键词:
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
- 其他:确实无法归入以上任何类别的内容
标签: {tag}
分析步骤:
1. 仔细理解标签的核心含义和使用场景
2. 对比各个领域的关键词,找到最匹配的领域
3. 特别注意:
- 历史、科学、文化等知识性内容应归类为"学习"
- 学校、课程、考试等正式教育场景应归类为"教育"
- 只有在标签完全不属于上述9个具体领域时才选择"其他"
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
请直接返回最合适的领域名称。"""
messages = [
{"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
{"role": "user", "content": prompt}
]
classification = await self.llm_client.response_structured(
messages=messages,
response_model=TagClassification,
)
if classification and hasattr(classification, 'domain') and classification.domain:
domain_counts[classification.domain] += 1
total_tags = sum(domain_counts.values())
if total_tags == 0:
return {}
domain_distribution = {
domain: count / total_tags for domain, count in domain_counts.items()
}
return dict(sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True))
async def get_active_periods(self) -> list[int]:
"""
Identify the top 2 most active months for the user.
Only returns months if there is valid and diverse time data.
"""
query = """
MATCH (d:Dialogue)
WHERE d.group_id = $group_id AND d.created_at IS NOT NULL AND d.created_at <> ''
RETURN d.created_at AS creation_time
"""
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
if not records:
return []
month_counts = Counter()
valid_dates_count = 0
for record in records:
creation_time_str = record.get("creation_time")
if not creation_time_str:
continue
try:
dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
month_counts[dt_object.month] += 1
valid_dates_count += 1
except (ValueError, TypeError, AttributeError):
continue
if not month_counts or valid_dates_count == 0:
return []
# Check if time distribution is too concentrated (likely batch imported data)
unique_months = len(month_counts)
if unique_months <= 2:
most_common_count = month_counts.most_common(1)[0][1]
if most_common_count / valid_dates_count > 0.8:
return []
if unique_months >= 3:
most_common_months = month_counts.most_common(2)
return [month for month, _ in most_common_months]
if unique_months == 2:
counts = list(month_counts.values())
ratio = min(counts) / max(counts)
if ratio > 0.3:
most_common_months = month_counts.most_common(2)
return [month for month, _ in most_common_months]
return []
async def get_social_connections(self) -> dict | None:
"""Find the user with whom the most memories are shared."""
query = """
MATCH (c1:Chunk {group_id: $group_id})
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
WHERE common_statements > 0
RETURN other_user_id, common_statements
ORDER BY common_statements DESC
LIMIT 1
"""
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
if not records or not records[0].get("other_user_id"):
return None
most_connected_user = records[0]["other_user_id"]
common_memories_count = records[0]["common_statements"]
time_range_query = """
MATCH (c:Chunk)
WHERE c.group_id IN [$user_id, $other_user_id]
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
"""
time_records = await self.neo4j_connector.execute_query(
time_range_query,
user_id=self.user_id,
other_user_id=most_connected_user
)
start_year, end_year = "N/A", "N/A"
if time_records and time_records[0]["start_time"]:
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
return {
"user_id": most_connected_user,
"common_memories_count": common_memories_count,
"time_range": f"{start_year}-{end_year}",
}
class UserSummaryHelper:
"""
Internal helper class for user summary generation.
Provides data retrieval functionality for user summary analysis.
"""
def __init__(self, user_id: str):
self.user_id = user_id
self.connector = Neo4jConnector()
self.llm = _get_llm_client_for_user(user_id)
async def close(self):
"""Close database connection."""
await self.connector.close()
async def get_recent_statements(self, limit: int = 80) -> List[Dict[str, Any]]:
"""Fetch recent statements authored by the user/group for context."""
query = (
"MATCH (s:Statement) "
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
"RETURN s.statement AS statement, s.created_at AS created_at "
"ORDER BY created_at DESC LIMIT $limit"
)
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
records = []
for r in rows:
try:
records.append({
"statement": r.get("statement", ""),
"created_at": r.get("created_at")
})
except Exception:
continue
return records
async def get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
"""Get meaningful entities and their frequencies using hot tag logic."""
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
return await get_hot_memory_tags(self.user_id, limit=limit)
# ============================================================================
# Service Class
# ============================================================================
# ============================================================================
# Service Class
# ============================================================================
class UserMemoryService:
@@ -601,7 +862,7 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None) ->
生成记忆洞察报告(四个维度)
这个函数包含完整的业务逻辑:
1. 使用 MemoryInsight 工具类获取基础数据(领域分布、活跃时段、社交关联)
1. 使用 MemoryInsightHelper 工具类获取基础数据(领域分布、活跃时段、社交关联)
2. 使用 Jinja2 模板渲染提示词
3. 调用 LLM 生成四个维度的自然语言报告
4. 解析并返回四个部分
@@ -620,7 +881,7 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None) ->
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
import re
insight = MemoryInsight(end_user_id)
insight = MemoryInsightHelper(end_user_id)
try:
# 1. 并行获取三个维度的数据
@@ -722,7 +983,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
生成用户摘要(包含四个部分)
这个函数包含完整的业务逻辑:
1. 使用 UserSummary 工具类获取基础数据(实体、语句)
1. 使用 UserSummaryHelper 工具类获取基础数据(实体、语句)
2. 使用 prompt_utils 渲染提示词
3. 调用 LLM 生成四部分内容:基本介绍、性格特点、核心价值观、一句话总结
@@ -737,20 +998,19 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
"one_sentence": str
}
"""
from app.core.memory.analytics.user_summary import UserSummary
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
import re
# 创建 UserSummary 实例
user_summary_tool = UserSummary(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123"))
# 创建 UserSummaryHelper 实例
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123"))
try:
# 1) 收集上下文数据
entities = await user_summary_tool._get_top_entities(limit=40)
statements = await user_summary_tool._get_recent_statements(limit=100)
entities = await user_summary_tool.get_top_entities(limit=40)
statements = await user_summary_tool.get_recent_statements(limit=100)
entity_lines = [f"{name} ({freq})" for name, freq in entities][:20]
statement_samples = [s.statement.strip() for s in statements if (s.statement or '').strip()][:20]
statement_samples = [s["statement"].strip() for s in statements if s.get("statement", "").strip()][:20]
# 2) 使用 prompt_utils 渲染提示词
user_prompt = await render_user_summary_prompt(
@@ -794,6 +1054,28 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
core_values = core_values_match.group(1).strip() if core_values_match else ""
one_sentence = one_sentence_match.group(1).strip() if one_sentence_match else ""
# 6) 清理可能包含的反思内容(防御性编程)
# 如果 LLM 仍然输出了反思内容,在这里过滤掉
def clean_reflection_content(text: str) -> str:
"""移除可能包含的反思内容"""
if not text:
return text
# 移除 "---" 之后的所有内容(通常是反思部分的开始)
if '---' in text:
text = text.split('---')[0].strip()
# 移除 "**Step" 开头的内容
if '**Step' in text:
text = text.split('**Step')[0].strip()
# 移除 "Self-Review" 相关内容
if 'Self-Review' in text or 'self-review' in text:
text = re.sub(r'[\-\*]*\s*Self-Review.*$', '', text, flags=re.IGNORECASE | re.DOTALL).strip()
return text
user_summary = clean_reflection_content(user_summary)
personality = clean_reflection_content(personality)
core_values = clean_reflection_content(core_values)
one_sentence = clean_reflection_content(one_sentence)
return {
"user_summary": user_summary,
"personality": personality,

View File

@@ -17,6 +17,7 @@ from app.core.workflow.validator import validate_workflow_config
from app.db import get_db, get_db_context
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories.end_user_repository import EndUserRepository
from app.services.multi_agent_service import convert_uuids_to_str
from app.repositories.workflow_repository import (
WorkflowConfigRepository,
WorkflowExecutionRepository,
@@ -364,7 +365,7 @@ class WorkflowService:
execution.status = status
if output_data is not None:
execution.output_data = output_data
execution.output_data = convert_uuids_to_str(output_data)
if error_message is not None:
execution.error_message = error_message
if error_node_id is not None:

View File

@@ -8,7 +8,7 @@ import uuid
from typing import Dict, Any, Optional
from datetime import datetime
from app.models import AppRelease
from app.models import AppRelease, WorkflowConfig
from app.models.agent_app_config_model import AgentConfig
from app.models.multi_agent_model import MultiAgentConfig
@@ -28,7 +28,7 @@ class AgentConfigProxy:
def agent_config_4_app_release(release: AppRelease ) -> AgentConfig:
config_dict = release.config
agent_config = AgentConfig(
app_id=release.app_id,
system_prompt=config_dict.get("system_prompt"),
@@ -45,10 +45,10 @@ def agent_config_4_app_release(release: AppRelease ) -> AgentConfig:
def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig:
config_dict = release.config
agent_config = MultiAgentConfig(
app_id=release.app_id,
app_id=release.app_id,
default_model_config_id=release.default_model_config_id,
model_parameters=config_dict.get("model_parameters"),
master_agent_id=config_dict.get("master_agent_id"),
@@ -58,11 +58,29 @@ def multi_agent_config_4_app_release(release: AppRelease ) -> MultiAgentConfig:
routing_rules=config_dict.get("routing_rules"),
execution_config=config_dict.get("execution_config", {}),
aggregation_strategy=config_dict.get("aggregation_strategy", "merge"),
)
return agent_config
def workflow_config_4_app_release(release: AppRelease ) -> WorkflowConfig:
config_dict = release.config
config = WorkflowConfig(
id=release.id,
app_id=release.app_id,
nodes=config_dict.get("nodes", []),
edges=config_dict.get("edges", []),
variables=config_dict.get("variables", []),
execution_config=config_dict.get("execution_config", {}),
triggers=config_dict.get("triggers", [])
)
return config
def dict_to_multi_agent_config(config_dict: Dict[str, Any], app_id: Optional[uuid.UUID] = None):
"""Convert dict to MultiAgentConfig model object

View File

@@ -1,5 +1,6 @@
import { request } from '@/utils/request'
import type { AiPromptForm } from '@/views/ApplicationConfig/types'
import { handleSSE, type SSEMessage } from '@/utils/stream'
export const createPromptSessions = () => {
return request.post(`/prompt/sessions`)
@@ -7,6 +8,6 @@ export const createPromptSessions = () => {
export const getPrompt = (session_id: string) => {
return request.get(`/prompt/sessions/${session_id}`)
}
export const updatePromptMessages = (session_id: string, data: AiPromptForm) => {
return request.post(`/prompt/sessions/${session_id}/messages`, data)
export const updatePromptMessages = (session_id: string, data: AiPromptForm, onMessage?: (data: SSEMessage[]) => void) => {
return handleSSE(`/prompt/sessions/${session_id}/messages`, data, onMessage)
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 936 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 568 B

View File

@@ -1224,6 +1224,8 @@ export const en = {
key_findings: 'Key Findings',
behavior_pattern: 'Behavior Pattern',
growth_trajectory: 'Growth Trajectory',
personality: 'Personality Traits',
core_values: 'Core Values',
},
space: {
createSpace: 'Create Space',
@@ -1799,12 +1801,20 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
"not_contains": 'Does Not Contain',
"startwith": 'Starts With',
"endwith": 'Ends With',
"eq": '==',
"ne": '!=',
"lt": '<',
"le": '<=',
"gt": '>',
"ge": '>=',
"eq": 'Equals',
"ne": 'Not Equals',
num: {
"eq": '=',
"ne": '',
"lt": '<',
"le": '≤',
"gt": '>',
"ge": '≥',
},
boolean: {
"eq": 'Is',
"ne": 'Is Not',
},
else_desc: 'Used to define the logic that should be executed when the if condition is not met.'
},
'http-request': {
@@ -1845,12 +1855,17 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
loop: {
cycle_vars: 'Loop Variables',
condition: 'Loop Termination Condition',
max_loop: 'Maximum Loop Count',
},
assigner: {
assignments: 'Variables',
cover: 'Overwrite',
cover: 'Override',
assign: 'Set',
clear: 'Clear'
clear: 'Clear',
add: '+=',
subtract: '-=',
multiply: '*=',
divide: '/=',
},
iteration: {
input: 'Input Variable',

View File

@@ -1305,6 +1305,8 @@ export const zh = {
key_findings: '关键发现',
behavior_pattern: '行为模式',
growth_trajectory: '成长轨迹',
personality: '性格特点',
core_values: '核心价值观',
},
space: {
createSpace: '创建空间',
@@ -1899,12 +1901,20 @@ export const zh = {
"not_contains": '不包含',
"startwith": '开始是',
"endwith": '结束是',
"eq": '==',
"ne": '!=',
"lt": '<',
"le": '<=',
"gt": '>',
"ge": '>=',
"eq": '',
"ne": '不是',
num: {
"eq": '=',
"ne": '',
"lt": '<',
"le": '≤',
"gt": '>',
"ge": '≥',
},
boolean: {
"eq": '是',
"ne": '不是',
},
else_desc: '用于定义当 if 条件不满足时应执行的逻辑。'
},
'http-request': {
@@ -1945,12 +1955,17 @@ export const zh = {
loop: {
cycle_vars: '循环变量',
condition: '循环终止条件',
max_loop: '最大循环次数',
},
assigner: {
assignments: '变量',
cover: '覆盖',
assign: '设置',
clear: '清空'
clear: '清空',
add: '+=',
subtract: '-=',
multiply: '*=',
divide: '/=',
},
iteration: {
input: '输入变量',

View File

@@ -16,6 +16,8 @@ import ConversationEmptyIcon from '@/assets/images/conversation/conversationEmpt
import type { ChatItem } from '@/components/Chat/types'
import CustomSelect from '@/components/CustomSelect'
import AiPromptVariableModal from './AiPromptVariableModal'
import { type SSEMessage } from '@/utils/stream'
import Editor from './Editor'
interface AiPromptModalProps {
refresh: (value: string) => void;
@@ -35,7 +37,8 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
const [variables, setVariables] = useState<string[]>([])
const [promptSession, setPromptSession] = useState<string | null>(null)
const aiPromptVariableModalRef = useRef<AiPromptVariableModalRef>(null)
const currentPromptRef = useRef<any>(null)
const editorRef = useRef<any>(null)
const currentPromptValueRef = useRef<string>('')
const values = Form.useWatch([], form)
@@ -78,16 +81,45 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
setChatList(prev => {
return [...prev, { role: 'user', content: messageContent}]
})
form.setFieldsValue({ message: undefined })
updatePromptMessages(promptSession, values)
.then(res => {
const response = res as { prompt: string; desc: string; variables: string[] }
form.setFieldsValue({ current_prompt: response.prompt })
setChatList(prev => {
return [...prev, { role: 'assistant', content: response.desc }]
})
setVariables(response.variables)
form.setFieldsValue({ message: undefined, current_prompt: undefined })
const handleStreamMessage = (data: SSEMessage[]) => {
data.map(item => {
const { content, desc, variables } = item.data as { content: string; desc: string; variables: string[] };
switch (item.event) {
case 'start':
currentPromptValueRef.current = ''
break;
case 'message':
if (content) {
currentPromptValueRef.current += content;
form.setFieldsValue({ current_prompt: currentPromptValueRef.current })
}
if (desc) {
setChatList(prev => {
return [...prev, { role: 'assistant', content: desc }]
})
}
if (variables) {
setVariables(variables)
}
break;
case 'end':
setLoading(false)
break
}
})
};
updatePromptMessages(promptSession, values, handleStreamMessage)
// .then(res => {
// const response = res as { prompt: string; desc: string; variables: string[] }
// form.setFieldsValue({ current_prompt: response.prompt })
// setChatList(prev => {
// return [...prev, { role: 'assistant', content: response.desc }]
// })
// setVariables(response.variables)
// })
.finally(() => {
setLoading(false)
})
@@ -101,18 +133,8 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
aiPromptVariableModalRef.current?.handleOpen()
}
const handleVariableApply = (value: string) => {
const textArea = currentPromptRef.current?.resizableTextArea?.textArea
if (textArea) {
const cursorPosition = textArea.selectionStart
const currentValue = values.current_prompt || ''
const newValue = currentValue.slice(0, cursorPosition) + value + currentValue.slice(cursorPosition)
form.setFieldValue('current_prompt', newValue)
// 设置新的光标位置
setTimeout(() => {
textArea.focus()
textArea.setSelectionRange(cursorPosition + value.length, cursorPosition + value.length)
}, 0)
if (editorRef.current?.insertText) {
editorRef.current.insertText(value)
} else {
form.setFieldValue('current_prompt', (values.current_prompt || '') + value)
}
@@ -191,7 +213,11 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
</Col>
</Row>
<Form.Item name="current_prompt">
<Input.TextArea ref={currentPromptRef} className="rb:bg-[#FBFDFF]! rb:h-100.5!" />
<Editor
ref={editorRef}
className="rb:h-100.5 "
onChange={(value) => form.setFieldValue('current_prompt', value)}
/>
</Form.Item>
<div className="rb:grid rb:grid-cols-2 rb:gap-4 rb:mt-6">
<Button block disabled={!values?.current_prompt} onClick={handleCopy}>{t('common.copy')}</Button>

View File

@@ -0,0 +1,91 @@
import {forwardRef, useImperativeHandle } from 'react';
import clsx from 'clsx';
import { LexicalComposer } from '@lexical/react/LexicalComposer';
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin';
import { ContentEditable } from '@lexical/react/LexicalContentEditable';
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary';
import { $getSelection } from 'lexical';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import InitialValuePlugin from './plugin/InitialValuePlugin'
import LineBreakPlugin from './plugin/LineBreakPlugin';
import InsertTextPlugin from './plugin/InsertTextPlugin';
export interface EditorRef {
insertText: (text: string) => void;
}
interface LexicalEditorProps {
className?: string;
placeholder?: string;
value?: string;
onChange?: (value: string) => void;
height?: number;
}
const theme = {
paragraph: 'editor-paragraph',
text: {
bold: 'editor-text-bold',
italic: 'editor-text-italic',
},
};
const EditorContent = forwardRef<EditorRef, LexicalEditorProps>(({
className = '',
value,
placeholder = "请输入内容...",
onChange,
}, ref) => {
const [editor] = useLexicalComposerContext();
useImperativeHandle(ref, () => ({
insertText: (text: string) => {
editor.update(() => {
const selection = $getSelection();
if (selection) {
selection.insertText(text);
}
});
}
}), [editor]);
return (
<div style={{ position: 'relative' }}>
<RichTextPlugin
contentEditable={
<ContentEditable
className={clsx("rb:outline-none rb:resize-none rb:text-[14px] rb:leading-5 rb:px-4 rb:py-5 rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:overflow-auto", className)}
/>
}
placeholder={
<div className="rb:absolute rb:px-4 rb:py-5 rb:text-[14px] rb:text-[#5B6167] rb:leading-5 rb:pointer-none">
{placeholder}
</div>
}
ErrorBoundary={LexicalErrorBoundary}
/>
<LineBreakPlugin onChange={onChange} />
<InitialValuePlugin value={value} />
<InsertTextPlugin />
</div>
);
});
const Editor = forwardRef<EditorRef, LexicalEditorProps>((props, ref) => {
const initialConfig = {
namespace: 'Editor',
theme,
nodes: [],
onError: (error: Error) => {
console.error(error);
},
};
return (
<LexicalComposer initialConfig={initialConfig}>
<EditorContent {...props} ref={ref} />
</LexicalComposer>
);
});
export default Editor;

View File

@@ -0,0 +1,25 @@
import { type FC, useEffect } from 'react';
import { $getRoot, $createParagraphNode, $createTextNode } from 'lexical';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
// 设置初始值的插件
const InitialValuePlugin: FC<{ value?: string }> = ({ value }) => {
const [editor] = useLexicalComposerContext();
useEffect(() => {
if (value) {
editor.update(() => {
const root = $getRoot();
root.clear();
const paragraph = $createParagraphNode();
const textNode = $createTextNode(value);
paragraph.append(textNode);
root.append(paragraph);
});
}
}, [editor, value]);
return null;
};
export default InitialValuePlugin

View File

@@ -0,0 +1,24 @@
import { forwardRef, useImperativeHandle } from 'react';
import { $getSelection } from 'lexical';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import type { EditorRef } from '../index'
// 插入文本的插件
const InsertTextPlugin = forwardRef<EditorRef>((_, ref) => {
const [editor] = useLexicalComposerContext();
useImperativeHandle(ref, () => ({
insertText: (text: string) => {
editor.update(() => {
const selection = $getSelection();
if (selection) {
selection.insertText(text);
}
});
}
}), [editor]);
return null;
});
export default InsertTextPlugin;

View File

@@ -0,0 +1,24 @@
import { type FC, useEffect } from 'react';
import { $getRoot } from 'lexical';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
// 处理换行的插件
const LineBreakPlugin: FC<{ onChange?: (value: string) => void }> = ({ onChange }) => {
const [editor] = useLexicalComposerContext();
useEffect(() => {
return editor.registerUpdateListener(({ editorState }) => {
editorState.read(() => {
const root = $getRoot();
const textContent = root.getTextContent();
// 将\n转换为实际换行
const processedContent = textContent.replace(/\\n/g, '\n');
onChange?.(processedContent);
});
});
}, [editor, onChange]);
return null;
};
export default LineBreakPlugin;

View File

@@ -4,10 +4,9 @@ import {
Col,
Tag,
List,
Space
Flex
} from 'antd';
import { EyeOutlined } from '@ant-design/icons';
import clsx from 'clsx'
import { useTranslation } from 'react-i18next';
import dayjs, { type Dayjs } from 'dayjs'
@@ -103,9 +102,9 @@ const Inner: React.FC<{ getStatusTag: (status: string) => ReactNode }> = ({ getS
<div className="rb:h-full rb:flex rb:flex-col rb:justify-between">
<div className="rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167]">
{t(`tool.${item.config_data.tool_class}_features`)} <br />
<Space size={4} className="rb:mt-2">
<Flex gap={4} wrap className="rb:mt-2 rb:w-full">
{InnerConfigData[item.config_data.tool_class].features.map(vo => <Tag key={vo} color="default">{ t(`tool.${vo}`) }</Tag>) }
</Space>
</Flex>
{item.config_data.tool_class === 'DateTimeTool'
? <div className="rb:mt-3 rb:bg-[#F0F3F8] rb:px-3 rb:py-2.5 rb:rounded-md">

View File

@@ -5,16 +5,25 @@ import { Skeleton } from 'antd';
import RbCard from '@/components/RbCard/Card'
import Empty from '@/components/Empty';
import RbAlert from '@/components/RbAlert';
import {
getUserSummary,
} from '@/api/memory'
import type { AboutMeRef } from '../types'
interface Data {
user_summary: string;
personality: string;
core_values: string;
one_sentence: string;
[key: string]: string;
}
const AboutMe = forwardRef<AboutMeRef>((_props, ref) => {
const { t } = useTranslation()
const { id } = useParams()
const [loading, setLoading] = useState<boolean>(false)
const [data, setData] = useState<string | null>(null)
const [data, setData] = useState<Data>({} as Data)
useEffect(() => {
if (!id) return
@@ -27,7 +36,7 @@ const AboutMe = forwardRef<AboutMeRef>((_props, ref) => {
setLoading(true)
getUserSummary(id)
.then((res) => {
setData((res as { summary?: string }).summary || null)
setData((res as Data) || null)
})
.finally(() => {
setLoading(false)
@@ -44,10 +53,29 @@ const AboutMe = forwardRef<AboutMeRef>((_props, ref) => {
>
{loading
? <Skeleton className="rb:mt-4" />
: data
? <div className="rb:font-regular rb:leading-5 rb:text-[#5B6167]">
{data || '-'}
</div>
: Object.keys(data).filter(key => data[key] !== null).length > 0
? <>
{data.user_summary &&
<div className="rb:font-regular rb:leading-5 rb:text-[#5B6167]">
{data.user_summary}
</div>
}
{data.personality && <>
<div className="rb:pt-4 rb:font-medium rb:leading-5 rb:mb-2">{t('userMemory.personality')}</div>
<div className="rb:font-regular rb:leading-5 rb:text-[#5B6167]">
{data.personality}
</div>
</>}
{data.core_values && <>
<div className="rb:pt-4 rb:font-medium rb:leading-5 rb:mb-2">{t('userMemory.core_values')}</div>
<div className="rb:font-regular rb:leading-5 rb:text-[#5B6167]">
{data.core_values}
</div>
</>}
{data.one_sentence &&
<RbAlert className="rb:mt-4">{data.one_sentence}</RbAlert>
}
</>
: <Empty size={88} className="rb:mt-12 rb:mb-20.25" />
}
</RbCard>

View File

@@ -26,7 +26,6 @@ const ChatVariableModal = forwardRef<ChatVariableModalRef, ChatVariableModalProp
const [form] = Form.useForm<ChatVariable>();
const [loading, setLoading] = useState(false)
const [editIndex, setEditIndex] = useState<number | undefined>(undefined)
const typeValue = Form.useWatch('type', form);
// 封装取消方法,添加关闭弹窗逻辑
const handleClose = () => {

View File

@@ -14,18 +14,23 @@ const CharacterCountPlugin = ({ setCount, onChange }: { setCount: (count: number
let serializedContent = '';
// Traverse all nodes and serialize properly
const paragraphs: string[] = [];
root.getChildren().forEach(child => {
if ($isParagraphNode(child)) {
let paragraphContent = '';
child.getChildren().forEach(node => {
if ($isVariableNode(node)) {
serializedContent += node.getTextContent();
paragraphContent += node.getTextContent();
} else {
serializedContent += node.getTextContent();
paragraphContent += node.getTextContent();
}
});
paragraphs.push(paragraphContent);
}
});
serializedContent = paragraphs.join('\n');
setCount(serializedContent.length);
onChange?.(serializedContent);
});

View File

@@ -26,6 +26,7 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
parts.forEach(part => {
const match = part.match(/^\{\{([^.]+)\.([^}]+)\}\}$/);
const contextMatch = part.match(/^\{\{context\}\}$/);
const conversationMatch = part.match(/^\{\{conv\.([^}]+)\}\}$/);
// 匹配{{context}}格式
if (contextMatch) {
@@ -38,6 +39,20 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
return
}
// 匹配{{conv.xx}}格式
if (conversationMatch) {
const [_, variableName] = conversationMatch;
const conversationSuggestion = options.find(s =>
s.group === 'CONVERSATION' && s.label === variableName
);
if (conversationSuggestion) {
paragraph.append($createVariableNode(conversationSuggestion));
} else {
paragraph.append($createTextNode(part));
}
return
}
// 匹配普通变量{{nodeId.label}}格式
if (match) {
const [_, nodeId, label] = match;

View File

@@ -13,13 +13,15 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
const handleNodeSelect = (selectedNodeType: any) => {
const parentBBox = node.getBBox();
const cycleId = data.cycle;
const id = `${selectedNodeType.type.replace(/-/g, '_') }_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const newNode = graph.addNode({
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
x: parentBBox.x,
y: parentBBox.y,
id,
data: {
id: `${selectedNodeType.type}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
id,
type: selectedNodeType.type,
icon: selectedNodeType.icon,
name: t(`workflow.${selectedNodeType.type}`),

View File

@@ -75,12 +75,15 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
const parentBBox = node.getBBox();
const centerX = parentBBox.x + 24; // 默认节点宽度的一半
const centerY = parentBBox.y + 50; // 默认节点高度的一半
const cycleStartNodeId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const cycleStartNode = graph.addNode({
...graphNodeLibrary.cycleStart,
x: centerX,
y: centerY,
id: cycleStartNodeId,
data: {
id: cycleStartNodeId,
type: 'cycle-start',
parentId: node.id,
isDefault: true, // 标记为默认节点,不可删除

View File

@@ -43,12 +43,14 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
const newY = sourceBBox.y;
// 创建新节点
const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const newNode = graph.addNode({
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
x: newX,
y: newY,
id,
data: {
id: `${selectedNodeType.type}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
id,
type: selectedNodeType.type,
icon: selectedNodeType.icon,
name: t(`workflow.${selectedNodeType.type}`),

View File

@@ -1,6 +1,6 @@
import { type FC } from 'react'
import { useTranslation } from 'react-i18next';
import { Form, Input, Button, Row, Col, Select } from 'antd'
import { Form, Input, Row, Col, Select, InputNumber, Radio } from 'antd'
import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons';
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
import VariableSelect from '../VariableSelect'
@@ -11,6 +11,23 @@ interface AssignmentListProps {
options: Suggestion[];
}
const operationsObj = {
number: [
{ value: 'cover', label: 'workflow.config.assigner.cover' },
{ value: 'clear', label: 'workflow.config.assigner.clear' },
{ value: 'assign', label: 'workflow.config.assigner.assign' },
{ value: 'add', label: 'workflow.config.assigner.add' },
{ value: 'subtract', label: 'workflow.config.assigner.subtract' },
{ value: 'multiply', label: 'workflow.config.assigner.multiply' },
{ value: 'divide', label: 'workflow.config.assigner.divide' },
],
default: [
{ value: 'cover', label: 'workflow.config.assigner.cover' },
{ value: 'clear', label: 'workflow.config.assigner.clear' },
{ value: 'assign', label: 'workflow.config.assigner.assign' },
],
}
const AssignmentList: FC<AssignmentListProps> = ({
parentName,
options = [],
@@ -27,6 +44,11 @@ const AssignmentList: FC<AssignmentListProps> = ({
<PlusOutlined onClick={() => add({ operation: 'cover'})} />
</div>
{fields.map(({ key, name, ...restField }) => {
const variableSelector = form.getFieldValue([parentName, name, 'variable_selector']);
const selectedOption = options.find(option => `{{${option.value}}}` === variableSelector);
const dataType = selectedOption?.dataType;
const operationOptions = dataType === 'number' ? operationsObj.number : operationsObj.default;
return (
<div key={key} className="rb:mb-4">
<Row gutter={12} className="rb:mb-2!">
@@ -50,11 +72,10 @@ const AssignmentList: FC<AssignmentListProps> = ({
noStyle
>
<Select
options={[
{ value: 'cover', label: t('workflow.config.assigner.cover') },
{ value: 'clear', label: t('workflow.config.assigner.clear') },
{ value: 'assign', label: t('workflow.config.assigner.assign') },
]}
options={operationOptions.map(op => ({
...op,
label: t(op.label)
}))}
popupMatchSelectWidth={false}
onChange={() => {
form.setFieldValue([parentName, name, 'value'], undefined);
@@ -77,20 +98,31 @@ const AssignmentList: FC<AssignmentListProps> = ({
{...restField}
name={[name, 'value']}
noStyle
rules={[{ required: true, message: 'Missing last name' }]}
>
{operation === 'assign' ? (
<Input.TextArea
placeholder={t('common.pleaseEnter')}
rows={3}
/>
) : (
<VariableSelect
{operation === 'assign'
? <>
{dataType === 'number'
? <InputNumber
placeholder={t('common.pleaseEnter')}
className="rb:w-full!"
/>
: dataType === 'boolean'
? <Radio.Group block>
<Radio.Button value={true}>True</Radio.Button>
<Radio.Button value={false}>False</Radio.Button>
</Radio.Group>
: <Input.TextArea
placeholder={t('common.pleaseEnter')}
rows={3}
/>
}
</>
: <VariableSelect
placeholder={t('common.pleaseSelect')}
options={options}
options={dataType ? options.filter(vo => vo.dataType === dataType) : options}
popupMatchSelectWidth={false}
/>
)}
}
</Form.Item>
);
}}

View File

@@ -1,7 +1,7 @@
import { type FC } from 'react'
import clsx from 'clsx'
import { useTranslation } from 'react-i18next';
import { Form, Button, Select, Space, Row, Col, Divider } from 'antd'
import { Form, Button, Select, Space, Row, Col, Divider, InputNumber, Radio, type SelectProps } from 'antd'
import { DeleteOutlined } from '@ant-design/icons';
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
@@ -9,37 +9,48 @@ import VariableSelect from '../VariableSelect'
import Editor from '../../Editor'
interface CaseListProps {
value?: Array<{ logical_operator: 'and' | 'or'; expressions: { left: string; comparison_operator: string; right: string; }[] }>;
value?: Array<{ logical_operator: 'and' | 'or'; expressions: { left: string; comparison_operator: string; right: string; input_type?: string; }[] }>;
onChange?: (value: Array<{ logical_operator: 'and' | 'or'; expressions: { left: string; comparison_operator: string; right: string; }[] }>) => void;
options: Suggestion[];
name: string;
selectedNode?: any;
graphRef?: any;
}
const operatorList = [
"empty",
"not_empty",
"contains",
"not_contains",
"startwith",
"endwith",
"eq",
"ne",
"lt",
"le",
"gt",
"ge"
]
const operatorsObj: { [key: string]: SelectProps['options'] } = {
default: [
{ value: 'empty', label: 'workflow.config.if-else.empty' },
{ value: 'not_empty', label: 'workflow.config.if-else.not_empty' },
{ value: 'contains', label: 'workflow.config.if-else.contains' },
{ value: 'not_contains', label: 'workflow.config.if-else.not_contains' },
{ value: 'startwith', label: 'workflow.config.if-else.startwith' },
{ value: 'endwith', label: 'workflow.config.if-else.endwith' },
{ value: 'eq', label: 'workflow.config.if-else.eq' },
{ value: 'ne', label: 'workflow.config.if-else.ne' },
],
number: [
{ value: 'eq', label: 'workflow.config.if-else.num.eq' },
{ value: 'ne', label: 'workflow.config.if-else.num.ne' },
{ value: 'lt', label: 'workflow.config.if-else.num.lt' },
{ value: 'le', label: 'workflow.config.if-else.num.le' },
{ value: 'gt', label: 'workflow.config.if-else.num.gt' },
{ value: 'ge', label: 'workflow.config.if-else.num.ge' },
{ value: 'empty', label: 'workflow.config.if-else.empty' },
{ value: 'not_empty', label: 'workflow.config.if-else.not_empty' },
],
boolean: [
{ value: 'eq', label: 'workflow.config.if-else.boolean.eq' },
{ value: 'ne', label: 'workflow.config.if-else.boolean.ne' },
]
}
const CaseList: FC<CaseListProps> = ({
value = [],
options,
name,
onChange,
selectedNode,
graphRef
}) => {
const { t } = useTranslation();
const form = Form.useFormInstance();
const updateNodePorts = (caseCount: number, removedCaseIndex?: number) => {
if (!selectedNode || !graphRef?.current) return;
@@ -175,29 +186,49 @@ const CaseList: FC<CaseListProps> = ({
});
}, 50);
};
const handleChangeLogicalOperator = (index: number) => {
const newValue = [...value]
newValue[index] = {
...newValue[index],
logical_operator: newValue[index].logical_operator === 'and' ? 'or' : 'and'
}
onChange && onChange(newValue)
}
const currentValue = form.getFieldValue([name, index, 'logical_operator']);
form.setFieldValue([name, index, 'logical_operator'], currentValue === 'and' ? 'or' : 'and');
};
const handleLeftFieldChange = (caseIndex: number, conditionIndex: number, newValue: string) => {
form.setFieldsValue({
[name]: {
[caseIndex]: {
expressions: {
[conditionIndex]: {
left: newValue,
comparison_operator: undefined,
right: undefined,
input_type: undefined
}
}
}
}
});
};
const handleAddCase = (addCaseFunc: Function) => {
addCaseFunc({ logical_operator: 'and', expressions: [] });
setTimeout(() => {
updateNodePorts((value?.length || 0) + 1);
const currentCases = form.getFieldValue(name) || [];
updateNodePorts(currentCases.length);
}, 100);
};
const handleRemoveCase = (removeCaseFunc: Function, fieldName: number, caseIndex: number) => {
removeCaseFunc(fieldName);
setTimeout(() => {
updateNodePorts((value?.length || 1) - 1, caseIndex);
const currentCases = form.getFieldValue(name) || [];
updateNodePorts(currentCases.length, caseIndex);
}, 100);
};
const handleInputTypeChange = (caseIndex: number, conditionIndex: number) => {
form.setFieldValue([name, caseIndex, 'expressions', conditionIndex, 'right'], undefined);
};
return (
<>
<Form.List name={name}>
@@ -218,7 +249,7 @@ const CaseList: FC<CaseListProps> = ({
<Space>
<Button
type="dashed"
onClick={() => addCondition()}
onClick={() => addCondition({})}
size="small"
>
+ {t('workflow.config.addCase')}
@@ -234,15 +265,23 @@ const CaseList: FC<CaseListProps> = ({
<div className="rb:absolute rb:w-3 rb:left-2 rb:top-15 rb:bottom-6 rb:z-10 rb:border rb:border-[#DFE4ED] rb:rounded-l-md rb:border-r-0"></div>
<div className="rb:absolute rb:z-10 rb:left-0 rb:top-[50%] rb:transform-[translateY(-50%)]]">
<Form.Item name={[caseField.name, 'logical_operator']} noStyle >
<Button size="small" className="rb:cursor-pointer" onClick={() => handleChangeLogicalOperator(caseIndex)}>{value?.[caseIndex].logical_operator}</Button>
<Button size="small" className="rb:cursor-pointer" onClick={() => handleChangeLogicalOperator(caseIndex)}>{logicalOperator}</Button>
</Form.Item>
</div>
</>
}
{conditionFields.map((conditionField, conditionIndex) => {
const currentOperator = value?.[caseIndex]?.expressions?.[conditionIndex]?.comparison_operator;
const cases = form.getFieldValue(name) || [];
const currentCase = cases[caseIndex] || {};
const currentExpression = currentCase.expressions?.[conditionIndex] || {};
const currentOperator = currentExpression.comparison_operator;
const hideRightField = currentOperator === 'empty' || currentOperator === 'not_empty';
const leftFieldValue = currentExpression.left;
const leftFieldOption = options.find(option => `{{${option.value}}}` === leftFieldValue);
const leftFieldType = leftFieldOption?.dataType;
const operatorList = operatorsObj[leftFieldType || 'default'] || operatorsObj.default || [];
const inputType = leftFieldType === 'number' ? currentExpression.input_type : undefined;
const logicalOperator = currentCase.logical_operator;
return (
<div key={conditionField.key} className={clsx({
"rb:mb-3": conditionIndex !== conditionFields.length - 1
@@ -257,18 +296,20 @@ const CaseList: FC<CaseListProps> = ({
size="small"
allowClear={false}
popupMatchSelectWidth={false}
onChange={(val) => handleLeftFieldChange(caseIndex, conditionIndex, val)}
/>
</Form.Item>
</Col>
<Col span={8}>
<Form.Item name={[conditionField.name, 'comparison_operator']} noStyle>
<Select
options={operatorList.map(key => ({
value: key,
label: t(`workflow.config.if-else.${key}`)
options={operatorList.map(vo => ({
...vo,
label: t(String(vo?.label || ''))
}))}
size="small"
popupMatchSelectWidth={false}
placeholder={t('common.pleaseSelect')}
/>
</Form.Item>
</Col>
@@ -280,11 +321,48 @@ const CaseList: FC<CaseListProps> = ({
</Col>
</Row>
{!hideRightField && (
<Form.Item name={[conditionField.name, 'right']} noStyle>
<Editor options={options} />
</Form.Item>
)}
{!hideRightField && <>
{leftFieldType === 'number'
? <Row>
<Col span={12}>
<Form.Item name={[conditionField.name, 'input_type']} noStyle>
<Select
placeholder={t('common.pleaseSelect')}
options={[{ value: 'Variable', label: 'Variable' }, { value: 'Constant', label: 'Constant' }]}
popupMatchSelectWidth={false}
variant="borderless"
onChange={() => handleInputTypeChange(caseIndex, conditionIndex)}
/>
</Form.Item>
</Col>
<Col span={12}>
<Form.Item name={[conditionField.name, 'right']} noStyle>
{inputType === 'Variable'
?
<VariableSelect
placeholder={t('common.pleaseSelect')}
options={options.filter(vo => vo.dataType === 'number')}
allowClear={false}
popupMatchSelectWidth={false}
variant="borderless"
/>
: <InputNumber placeholder={t('common.pleaseEnter')}
variant="borderless" className="rb:w-full!" />
}
</Form.Item>
</Col>
</Row>
: <Form.Item name={[conditionField.name, 'right']} noStyle>
{leftFieldType === 'boolean'
? <Radio.Group block>
<Radio.Button value={true}>True</Radio.Button>
<Radio.Button value={false}>False</Radio.Button>
</Radio.Group>
: <Editor options={options} />
}
</Form.Item>
}
</>}
</div>
</div>
)

View File

@@ -1,7 +1,6 @@
import { type FC } from 'react'
import clsx from 'clsx'
import { useTranslation } from 'react-i18next';
import { Form, Button, Select, Space, Row, Col, Divider } from 'antd'
import { Form, Button, Select, Row, Col, InputNumber, Radio, type SelectProps } from 'antd'
import { DeleteOutlined } from '@ant-design/icons';
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
@@ -10,7 +9,7 @@ import Editor from '../../Editor'
interface Case {
logical_operator: 'and' | 'or';
expressions: Array<{ left: string; comparison_operator: string; right: string; }>
expressions: Array<{ left: string; comparison_operator: string; right: string; input_type: string; }>
}
interface CaseListProps {
@@ -22,36 +21,63 @@ interface CaseListProps {
graphRef?: any;
addBtnText?: string;
}
const operatorList = [
"empty",
"not_empty",
"contains",
"not_contains",
"startwith",
"endwith",
"eq",
"ne",
"lt",
"le",
"gt",
"ge"
]
const operatorsObj: { [key: string]: SelectProps['options'] } = {
default: [
{ value: 'empty', label: 'workflow.config.if-else.empty' },
{ value: 'not_empty', label: 'workflow.config.if-else.not_empty' },
{ value: 'contains', label: 'workflow.config.if-else.contains' },
{ value: 'not_contains', label: 'workflow.config.if-else.not_contains' },
{ value: 'startwith', label: 'workflow.config.if-else.startwith' },
{ value: 'endwith', label: 'workflow.config.if-else.endwith' },
{ value: 'eq', label: 'workflow.config.if-else.eq' },
{ value: 'ne', label: 'workflow.config.if-else.ne' },
],
number: [
{ value: 'eq', label: 'workflow.config.if-else.num.eq' },
{ value: 'ne', label: 'workflow.config.if-else.num.ne' },
{ value: 'lt', label: 'workflow.config.if-else.num.lt' },
{ value: 'le', label: 'workflow.config.if-else.num.le' },
{ value: 'gt', label: 'workflow.config.if-else.num.gt' },
{ value: 'ge', label: 'workflow.config.if-else.num.ge' },
{ value: 'empty', label: 'workflow.config.if-else.empty' },
{ value: 'not_empty', label: 'workflow.config.if-else.not_empty' },
],
boolean: [
{ value: 'eq', label: 'workflow.config.if-else.boolean.eq' },
{ value: 'ne', label: 'workflow.config.if-else.boolean.ne' },
]
}
const ConditionList: FC<CaseListProps> = ({
value,
options,
parentName,
onChange,
}) => {
const { t } = useTranslation();
const form = Form.useFormInstance();
const handleLeftFieldChange = (index: number, newValue: string) => {
form.setFieldsValue({
[parentName]: {
expressions: {
[index]: {
left: newValue,
comparison_operator: undefined,
right: undefined,
input_type: undefined
}
}
}
});
};
const handleInputTypeChange = (index: number) => {
form.setFieldValue([parentName, 'expressions', index, 'right'], undefined);
};
const handleChangeLogicalOperator = () => {
if (!value) return;
onChange && onChange({
logical_operator: value.logical_operator === 'and' ? 'or' : 'and',
expressions: value.expressions || []
})
}
const currentValue = form.getFieldValue([parentName, 'logical_operator']);
form.setFieldValue([parentName, 'logical_operator'], currentValue === 'and' ? 'or' : 'and');
};
return (
<>
<Form.List name={[parentName, 'expressions']}>
@@ -59,8 +85,16 @@ const ConditionList: FC<CaseListProps> = ({
<div>
<div className="rb:relative">
{fields.map((field, index) => {
const currentOperator = value?.expressions?.[index]?.comparison_operator;
const expressions = form.getFieldValue([parentName, 'expressions']) || [];
const currentExpression = expressions[index] || {};
const currentOperator = currentExpression.comparison_operator;
const hideRightField = currentOperator === 'empty' || currentOperator === 'not_empty';
const leftFieldValue = currentExpression.left;
const leftFieldOption = options.find(option => `{{${option.value}}}` === leftFieldValue);
const leftFieldType = leftFieldOption?.dataType;
const operatorList = operatorsObj[leftFieldType || 'default'] || operatorsObj.default || [];
const inputType = leftFieldType === 'number' ? currentExpression.input_type : undefined;
const logicalOperator = form.getFieldValue([parentName, 'logical_operator']);
return (
<div key={field.key} className="rb:mb-3">
@@ -68,7 +102,7 @@ const ConditionList: FC<CaseListProps> = ({
<div className="rb:absolute rb:w-3 rb:left-2 rb:top-3.75 rb:bottom-3.75 rb:z-10 rb:border rb:border-[#DFE4ED] rb:rounded-l-md rb:border-r-0"></div>
<div className="rb:absolute rb:z-10 rb:left-0 rb:top-[50%] rb:transform-[translateY(-50%)]]">
<Form.Item name={[parentName, 'logical_operator']} noStyle >
<Button size="small" className="rb:cursor-pointer" onClick={handleChangeLogicalOperator}>{value?.logical_operator}</Button>
<Button size="small" className="rb:cursor-pointer" onClick={handleChangeLogicalOperator}>{logicalOperator}</Button>
</Form.Item>
</div>
</>)}
@@ -82,6 +116,7 @@ const ConditionList: FC<CaseListProps> = ({
size="small"
allowClear={false}
popupMatchSelectWidth={false}
onChange={(val) => handleLeftFieldChange(index, val)}
/>
</Form.Item>
</Col>
@@ -89,9 +124,9 @@ const ConditionList: FC<CaseListProps> = ({
<Col span={8}>
<Form.Item name={[field.name, 'comparison_operator']} noStyle>
<Select
options={operatorList.map(key => ({
value: key,
label: t(`workflow.config.if-else.${key}`)
options={operatorList.map(vo => ({
...vo,
label: t(String(vo?.label || ''))
}))}
size="small"
popupMatchSelectWidth={false}
@@ -104,14 +139,53 @@ const ConditionList: FC<CaseListProps> = ({
onClick={() => remove(field.name)}
/>
</Col>
{!hideRightField && (
<Col span={24}>
<Form.Item name={[field.name, 'right']} noStyle>
<Editor options={options} />
</Form.Item>
</Col>
)}
{!hideRightField && <>
{leftFieldType === 'number'
? <Col span={24}><Row>
<Col span={12}>
<Form.Item name={[field.name, 'input_type']} noStyle>
<Select
placeholder={t('common.pleaseSelect')}
options={[{ value: 'Variable', label: 'Variable' }, { value: 'Constant', label: 'Constant' }]}
popupMatchSelectWidth={false}
variant="borderless"
className="rb:w-full!"
onChange={() => handleInputTypeChange(index)}
/>
</Form.Item>
</Col>
<Col span={12}>
<Form.Item name={[field.name, 'right']} noStyle>
{inputType === 'Variable'
?
<VariableSelect
placeholder={t('common.pleaseSelect')}
options={options.filter(vo => vo.dataType === 'number')}
allowClear={false}
popupMatchSelectWidth={false}
variant="borderless"
className="rb:w-full!"
/>
: <InputNumber placeholder={t('common.pleaseEnter')}
variant="borderless" className="rb:w-full!" />
}
</Form.Item>
</Col>
</Row></Col>
: <Col span={24}>
<Form.Item name={[field.name, 'right']} noStyle>
{leftFieldType === 'boolean'
? <Radio.Group block>
<Radio.Button value={true}>True</Radio.Button>
<Radio.Button value={false}>False</Radio.Button>
</Radio.Group>
: <Editor options={options} />
}
</Form.Item>
</Col>
}
</>}
</Row>
</div>

View File

@@ -65,7 +65,7 @@ const CycleVarsList: FC<CycleVarsListProps> = ({
label: `${childData.name || childData.type}.${key}`,
type: 'output',
dataType: 'string',
value: `{{${childData.id}.${key}}}`,
value: `${childData.id}.${key}`,
nodeData: childData
});
}

View File

@@ -25,7 +25,6 @@ const GroupVariableList: FC<GroupVariableListProps> = ({
<Row gutter={12} className="rb:mb-2!">
<Col span={12}>
<Form.Item
name={[name,0, 'key']}
noStyle
>
{t('workflow.config.var-aggregator.variable')}
@@ -34,9 +33,8 @@ const GroupVariableList: FC<GroupVariableListProps> = ({
</Row>
<Form.Item
name={[name, 0, 'value']}
name={name}
noStyle
rules={[{ required: true, message: 'Missing last name' }]}
>
<VariableSelect
placeholder={t('common.pleaseSelect')}
@@ -76,7 +74,6 @@ const GroupVariableList: FC<GroupVariableListProps> = ({
{...restField}
name={[name, 'value']}
noStyle
rules={[{ required: true, message: 'Missing last name' }]}
>
<VariableSelect
placeholder={t('common.pleaseSelect')}

View File

@@ -1,4 +1,4 @@
import { useState, useEffect } from 'react';
import { useState, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next'
import { Button, Select, Table } from 'antd';
import { PlusOutlined, DeleteOutlined } from '@ant-design/icons';
@@ -33,104 +33,90 @@ const EditableTable: React.FC<EditableTableProps> = ({
const [rows, setRows] = useState<TableRow[]>([]);
useEffect(() => {
console.log('EditableTable value', value)
if (Array.isArray(value)) {
setRows([...value])
} else if (value && Object.keys(value).length > 0) {
// Only update if rows are empty or significantly different
const valueEntries = Object.entries(value)
if (rows.length === 0 || rows.length !== valueEntries.length) {
setRows(valueEntries.map(([key, val], index) => {
console.log('val', val)
return {
key: index.toString(),
name: key || '',
value: val || '',
type: typeOptions.length > 0 ? typeOptions[0].value : undefined
}
}))
}
setRows(Object.entries(value).map(([key, val], index) => ({
key: index.toString(),
name: key || '',
value: val || '',
type: typeOptions.length > 0 ? typeOptions[0].value : undefined
})))
} else {
setRows([])
}
}, [JSON.stringify(value), typeOptions.length])
}, [value, typeOptions])
const handleChange = (key: string, field: 'name' | 'value' | 'type', val: string) => {
const newRows = [...rows.map(row =>
const newRows = rows.map(row =>
row.key === key ? { ...row, [field]: val } : row
)];
);
setRows(newRows);
onChange?.(newRows);
};
const handleAdd = () => {
const newKey = Date.now().toString();
if (typeOptions.length) {
setRows([...rows, { key: newKey, name: '', value: '', type: typeOptions[0].value }]);
} else {
setRows([...rows, { key: newKey, name: '', value: '' }]);
}
const newRow: TableRow = {
key: Date.now().toString(),
name: '',
value: '',
...(typeOptions.length > 0 && { type: typeOptions[0].value })
};
const newRows = [...rows, newRow];
setRows(newRows);
onChange?.(newRows);
};
const handleDelete = (key: string, index: number) => {
console.log('index', index)
if (rows.length === 1) {
setRows([]);
onChange?.([]);
} else {
const newRows = rows.filter(row => row.key !== key);
setRows(newRows);
onChange?.(newRows);
}
const handleDelete = (key: string) => {
const newRows = rows.filter(row => row.key !== key);
setRows(newRows);
onChange?.(newRows);
};
const columns = typeOptions?.length > 0 ? [
{
title: t('workflow.config.name'),
dataIndex: 'name',
width: '45%',
render: (text: string, record: TableRow) => (
<Editor
options={options}
value={text}
height={32}
variant="outlined"
onChange={(value) => handleChange(record.key, 'name', value)}
/>
),
},
{
title: t('workflow.config.type'),
dataIndex: 'type',
width: '20%',
render: (text: string, record: TableRow) => (
<Select
value={text}
options={typeOptions}
onChange={(value) => {
console.log('value record', value)
handleChange(record.key, 'type', value)
}}
/>
),
},
{
title: t('workflow.config.value'),
const columns = useMemo(() => {
const baseColumns = [
{
title: typeOptions.length > 0 ? t('workflow.config.name') : '键',
dataIndex: 'name',
width: typeOptions.length > 0 ? '35%' : '45%',
render: (text: string, record: TableRow) => (
<Editor
options={options}
value={text}
height={32}
variant="outlined"
onChange={(value) => handleChange(record.key, 'name', value || '')}
/>
),
}
];
if (typeOptions.length > 0) {
baseColumns.push({
title: t('workflow.config.type'),
dataIndex: 'type',
width: '20%',
render: (text: string, record: TableRow) => (
<Select
value={text}
options={typeOptions}
onChange={(value) => handleChange(record.key, 'type', value)}
/>
),
});
}
baseColumns.push({
title: typeOptions.length > 0 ? t('workflow.config.value') : '值',
dataIndex: 'value',
width: '45%',
width: typeOptions.length > 0 ? '35%' : '45%',
render: (text: string, record: TableRow) => {
if (record.type === 'file') {
return (
<VariableSelect
options={options}
value={text}
onChange={(value) => {
console.log('value record', value)
handleChange(record.key, 'value', value)
}}
onChange={(value) => handleChange(record.key, 'value', value || '')}
/>
)
}
@@ -140,78 +126,41 @@ const EditableTable: React.FC<EditableTableProps> = ({
value={text}
height={32}
variant="outlined"
onChange={(value) => {
console.log('value record', value)
handleChange(record.key, 'value', value)
}}
onChange={(value) => handleChange(record.key, 'value', value || '')}
/>
)
},
},
{
});
baseColumns.push({
title: '',
dataIndex: 'actions',
width: '10%',
render: (_: any, record: TableRow, index: number) => (
render: (_: any, record: TableRow) => (
<Button
type="text"
icon={<DeleteOutlined />}
onClick={() => handleDelete(record.key, index)}
onClick={() => handleDelete(record.key)}
/>
),
},
] : [
{
title: '键',
dataIndex: 'name',
width: '45%',
render: (text: string, record: TableRow) => (
<Editor
options={options}
value={text}
height={32}
variant="outlined"
onChange={(value) => handleChange(record.key, 'name', value)}
/>
),
},
{
title: '值',
dataIndex: 'value',
width: '45%',
render: (text: string, record: TableRow) => (
<Editor
options={options}
value={text}
height={32}
variant="outlined"
onChange={(value) => handleChange(record.key, 'value', value)}
/>
),
},
{
title: '',
width: '10%',
render: (_: any, record: TableRow, index: number) => (
<Button
type="text"
icon={<DeleteOutlined />}
onClick={() => handleDelete(record.key, index)}
/>
),
},
];
});
return baseColumns;
}, [typeOptions, options, t]);
return (
<div className="rb:mb-4">
{title && <div className="rb:flex rb:items-center rb:mb-2 rb:justify-between">
<div className="rb:font-medium">{title}</div>
<Button
type="text"
icon={<PlusOutlined />}
onClick={handleAdd}
size="small"
/>
</div>}
{title && (
<div className="rb:flex rb:items-center rb:mb-2 rb:justify-between">
<div className="rb:font-medium">{title}</div>
<Button
type="text"
icon={<PlusOutlined />}
onClick={handleAdd}
size="small"
/>
</div>
)}
<Table
columns={columns}
dataSource={rows}
@@ -220,11 +169,11 @@ const EditableTable: React.FC<EditableTableProps> = ({
locale={{ emptyText: <Empty size={88} /> }}
scroll={{ x: 'max-content' }}
/>
{!title &&
{!title && (
<Button type="dashed" onClick={handleAdd} block className='rb:mt-1'>
+{t('common.add')}
</Button>
}
)}
</div>
);
};

View File

@@ -1,6 +1,6 @@
import { type FC, useEffect, useRef } from "react";
import { type FC, useRef } from "react";
import { useTranslation } from 'react-i18next'
import { Form, Row, Col, Select, Button, Divider, InputNumber, Switch, Input, Slider } from 'antd'
import { Form, Row, Col, Select, Button, Divider, InputNumber, Switch, Input } from 'antd'
import Editor from '../../Editor'
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
import AuthConfigModal from './AuthConfigModal'

View File

@@ -128,29 +128,32 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
<List
grid={{ gutter: 12, column: 1 }}
dataSource={knowledgeList}
renderItem={(item) => (
<List.Item>
<div key={item.id} className="rb:flex rb:items-center rb:justify-between rb:p-[12px_16px] rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg">
<div className="rb:font-medium rb:leading-4">
{item.name}
<Tag color={item.status === 1 ? 'success' : item.status === 0 ? 'default' : 'error'} className="rb:ml-2">
{item.status === 1 ? t('common.enable') : item.status === 0 ? t('common.disabled') : t('common.deleted')}
</Tag>
<div className="rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-5">{t('application.contains', {include_count: item.doc_num})}</div>
renderItem={(item) => {
if (!item.id) return null
return (
<List.Item>
<div key={item.id} className="rb:flex rb:items-center rb:justify-between rb:p-[12px_16px] rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg">
<div className="rb:font-medium rb:leading-4">
{item.name}
<Tag color={item.status === 1 ? 'success' : item.status === 0 ? 'default' : 'error'} className="rb:ml-2">
{item.status === 1 ? t('common.enable') : item.status === 0 ? t('common.disabled') : t('common.deleted')}
</Tag>
<div className="rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-5">{t('application.contains', {include_count: item.doc_num})}</div>
</div>
<Space size={12}>
<div
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/editBorder.svg')] rb:hover:bg-[url('@/assets/images/editBg.svg')]"
onClick={() => handleEditKnowledge(item)}
></div>
<div
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/deleteBorder.svg')] rb:hover:bg-[url('@/assets/images/deleteBg.svg')]"
onClick={() => handleDeleteKnowledge(item.id)}
></div>
</Space>
</div>
<Space size={12}>
<div
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/editBorder.svg')] rb:hover:bg-[url('@/assets/images/editBg.svg')]"
onClick={() => handleEditKnowledge(item)}
></div>
<div
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/deleteBorder.svg')] rb:hover:bg-[url('@/assets/images/deleteBg.svg')]"
onClick={() => handleDeleteKnowledge(item.id)}
></div>
</Space>
</div>
</List.Item>
)}
</List.Item>
)
}}
/>
}
{/* 全局设置 */}

View File

@@ -1,12 +1,15 @@
import React from 'react';
import { useTranslation } from 'react-i18next'
import { MinusCircleOutlined } from '@ant-design/icons';
import { Button, Form, Input, Space } from 'antd';
import { Button, Form, Input, Space, Row, Col } from 'antd';
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
import VariableSelect from '../VariableSelect'
interface MappingListProps {
name: string;
options: Suggestion[];
}
const MappingList: React.FC<MappingListProps> = ({ name }) => {
const MappingList: React.FC<MappingListProps> = ({ name, options }) => {
const { t } = useTranslation()
return (
<>
@@ -14,23 +17,33 @@ const MappingList: React.FC<MappingListProps> = ({ name }) => {
{(fields, { add, remove }) => (
<>
{fields.map(({ key, name, ...restField }) => (
<Space key={key} style={{ display: 'flex', marginBottom: 8 }} align="baseline">
<Form.Item
{...restField}
name={[name, 'name']}
noStyle
>
<Input placeholder={t('common.pleaseEnter')} />
</Form.Item>
<Form.Item
{...restField}
name={[name, 'value']}
noStyle
>
<Input placeholder={t('common.pleaseEnter')} />
</Form.Item>
<MinusCircleOutlined onClick={() => remove(name)} />
</Space>
<Row gutter={12} className="rb:mb-2">
<Col span={10}>
<Form.Item
{...restField}
name={[name, 'name']}
noStyle
>
<Input placeholder={t('common.pleaseEnter')} />
</Form.Item>
</Col>
<Col span={12}>
<Form.Item
{...restField}
name={[name, 'value']}
noStyle
>
<VariableSelect
placeholder={t('common.pleaseSelect')}
options={options}
popupMatchSelectWidth={false}
/>
</Form.Item>
</Col>
<Col span={2}>
<MinusCircleOutlined onClick={() => remove(name)} />
</Col>
</Row>
))}
<Form.Item>
<Button type="dashed" onClick={() => add()} block>

View File

@@ -29,15 +29,15 @@ const MessageEditor: FC<TextareaProps> = ({
}) => {
const { t } = useTranslation()
const form = Form.useFormInstance();
const values = form.getFieldsValue()
const values = Form.useWatch([], form);
// 检查是否已经使用了context变量将已使用的context设置为disabled
const processedOptions = useMemo(() => {
if (!isArray || !values[parentName]) return options;
if (!isArray || !values?.[parentName]) return options;
// 获取所有消息内容
const allContents = values[parentName]
.map((msg: any) => msg.content || '')
.map((msg: any) => msg?.content || '')
.join(' ');
// 将已使用的context变量标记为disabled
@@ -50,83 +50,74 @@ const MessageEditor: FC<TextareaProps> = ({
}, [options, values, parentName, isArray]);
const handleAdd = (add: FormListOperation['add']) => {
const list = values[parentName];
const lastRole = list[list.length - 1].role
const list = values?.[parentName] || [];
const lastRole = list.length > 0 ? list[list.length - 1]?.role : 'ASSISTANT';
add({
role: lastRole === 'USER' ? 'ASSISTANT' : 'USER',
content: undefined
})
content: ''
});
};
if (!isArray) {
return (
<Space size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white">
<Row>
<Col span={12}>
{title ?? t('workflow.answerDesc')}
</Col>
</Row>
<Form.Item name={parentName} noStyle>
<Editor placeholder={placeholder} options={processedOptions} />
</Form.Item>
</Space>
);
}
return (
<div>
{isArray
? <Form.List name={parentName}>
{(fields, { add, remove }) => (
<Space size={12} direction="vertical" className="rb:w-full">
{fields.map(({ key, name, ...restField }) => {
const currentRole = (values[parentName]?.[key].role || 'USER').toUpperCase()
return (
<Space key={key} size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white">
<Row>
<Col span={12}>
<Form.Item
{...restField}
name={[name, 'role']}
noStyle
>
{currentRole === 'SYSTEM'
? <Input disabled />
:
<Select
options={roleOptions}
disabled={currentRole === 'SYSTEM'}
/>
}
</Form.Item>
</Col>
{currentRole !== 'SYSTEM' && <Col span={12}>
<div className="rb:h-full rb:flex rb:justify-end rb:items-center">
<MinusCircleOutlined onClick={() => remove(name)} />
</div>
</Col>}
</Row>
<Form.Item
{...restField}
name={[name, 'content']}
noStyle
>
<Editor placeholder={placeholder} options={processedOptions} />
<Form.List name={parentName}>
{(fields, { add, remove }) => (
<Space size={12} direction="vertical" className="rb:w-full">
{fields.map(({ key, name, ...restField }) => {
const currentRole = (values?.[parentName]?.[name]?.role || 'USER').toUpperCase();
return (
<Space key={key} size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white">
<Row>
<Col span={12}>
<Form.Item {...restField} name={[name, 'role']} noStyle>
{currentRole === 'SYSTEM' ? (
<Input disabled />
) : (
<Select
options={roleOptions}
disabled={currentRole === 'SYSTEM'}
/>
)}
</Form.Item>
</Space>
)
})}
<Form.Item>
<Button type="dashed" onClick={() => handleAdd(add)} block>
+{t('workflow.addMessage')}
</Button>
</Form.Item>
</Space >
)}
</Form.List>
:
<Space size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white">
<Row>
<Col span={12}>
{title ?? t('workflow.answerDesc')}
</Col>
</Row>
<Form.Item
name={parentName}
noStyle
>
<Editor placeholder={placeholder} options={processedOptions} />
</Col>
{currentRole !== 'SYSTEM' && (
<Col span={12}>
<div className="rb:h-full rb:flex rb:justify-end rb:items-center">
<MinusCircleOutlined onClick={() => remove(name)} />
</div>
</Col>
)}
</Row>
<Form.Item {...restField} name={[name, 'content']} noStyle>
<Editor placeholder={placeholder} options={processedOptions} />
</Form.Item>
</Space>
);
})}
<Form.Item>
<Button type="dashed" onClick={() => handleAdd(add)} block>
+{t('workflow.addMessage')}
</Button>
</Form.Item>
</Space>
}
</div>
)}
</Form.List>
);
};

View File

@@ -85,15 +85,6 @@ const Properties: FC<PropertiesProps> = ({
const { id, knowledge_retrieval, group, group_names, ...rest } = values
const { knowledge_bases = [], ...restKnowledgeConfig } = (knowledge_retrieval as any) || {}
let groupNames: Record<string, string[]> | string[] = {}
if (group && group_names?.length) {
group_names.forEach(vo => {
(groupNames as Record<string, string[]>)[vo.key] = vo.value
})
} else if (!group) {
groupNames = group_names?.[0]?.value || []
}
let allRest = {
...rest,
...restKnowledgeConfig,
@@ -107,7 +98,14 @@ const Properties: FC<PropertiesProps> = ({
Object.keys(values).forEach(key => {
if (selectedNode.data?.config?.[key]) {
selectedNode.data.config[key].defaultValue = values[key]
// Create a deep copy to avoid reference sharing between nodes
if (!selectedNode.data.config[key]) {
selectedNode.data.config[key] = {};
}
selectedNode.data.config[key] = {
...selectedNode.data.config[key],
defaultValue: values[key]
};
}
})
@@ -194,7 +192,7 @@ const Properties: FC<PropertiesProps> = ({
const allPreviousNodeIds = getAllPreviousNodes(selectedNode.id);
const childNodeIds = getChildNodes(selectedNode.id);
console.log('childNodeIds', childNodeIds)
console.log('childNodeIds', selectedNode, childNodeIds)
const allRelevantNodeIds = [...allPreviousNodeIds, ...childNodeIds];
allRelevantNodeIds.forEach(nodeId => {
@@ -219,7 +217,7 @@ const Properties: FC<PropertiesProps> = ({
label: variable.name,
type: 'variable',
dataType: variable.type,
value: `{{${nodeId}.${variable.name}}}`,
value: `${node.getData().id}.${variable.name}`,
nodeData: nodeData,
});
}
@@ -249,7 +247,7 @@ const Properties: FC<PropertiesProps> = ({
label: 'output',
type: 'variable',
dataType: 'String',
value: `${nodeId}.output`,
value: `${node.getData().id}.output`,
nodeData: nodeData,
});
}
@@ -263,7 +261,104 @@ const Properties: FC<PropertiesProps> = ({
label: 'message',
type: 'variable',
dataType: 'array[object]',
value: `${nodeId}.message`,
value: `${node.getData().id}.message`,
nodeData: nodeData,
});
}
break
case 'parameter-extractor':
const successKey = `${nodeId}___is_success`;
const reasonKey = `${nodeId}___reason`;
if (!addedKeys.has(successKey)) {
addedKeys.add(successKey);
variableList.push({
key: successKey,
label: '__is_success',
type: 'variable',
dataType: 'number',
value: `${node.getData().id}.__is_success`,
nodeData: nodeData,
});
}
if (!addedKeys.has(reasonKey)) {
addedKeys.add(reasonKey);
variableList.push({
key: reasonKey,
label: '__reason',
type: 'variable',
dataType: 'string',
value: `${node.getData().id}.__reason`,
nodeData: nodeData,
});
}
// Add params variables
const paramsList = nodeData.config?.params?.defaultValue || [];
paramsList.forEach((param: any) => {
if (!param || !param?.name) return;
const paramKey = `${nodeId}_${param.name}`;
if (!addedKeys.has(paramKey)) {
addedKeys.add(paramKey);
variableList.push({
key: paramKey,
label: param.name,
type: 'variable',
dataType: param.type || 'string',
value: `${node.getData().id}.${param.name}`,
nodeData: nodeData,
});
}
});
break
case 'var-aggregator':
const varAggregatorKey = `${nodeId}_output`;
if (!addedKeys.has(varAggregatorKey)) {
addedKeys.add(varAggregatorKey);
variableList.push({
key: varAggregatorKey,
label: 'output',
type: 'variable',
dataType: 'string',
value: `${node.getData().id}.output`,
nodeData: nodeData,
});
}
break
case 'http-request':
const httpBodyKey = `${nodeId}_body`;
const httpStatusKey = `${nodeId}_status_code`;
if (!addedKeys.has(httpBodyKey)) {
addedKeys.add(httpBodyKey);
variableList.push({
key: httpBodyKey,
label: 'body',
type: 'variable',
dataType: 'string',
value: `${node.getData().id}.body`,
nodeData: nodeData,
});
}
if (!addedKeys.has(httpStatusKey)) {
addedKeys.add(httpStatusKey);
variableList.push({
key: httpStatusKey,
label: 'status_code',
type: 'variable',
dataType: 'number',
value: `${node.getData().id}.status_code`,
nodeData: nodeData,
});
}
break
case 'jinja-render':
const jinjaOutputKey = `${nodeId}_output`;
if (!addedKeys.has(jinjaOutputKey)) {
addedKeys.add(jinjaOutputKey);
variableList.push({
key: jinjaOutputKey,
label: 'output',
type: 'variable',
dataType: 'string',
value: `${node.getData().id}.output`,
nodeData: nodeData,
});
}
@@ -283,7 +378,7 @@ const Properties: FC<PropertiesProps> = ({
label: variable.name,
type: 'variable',
dataType: variable.type,
value: `conversation.${variable.name}`,
value: `conv.${variable.name}`,
nodeData: { type: 'CONVERSATION', name: 'CONVERSATION', icon: '' },
group: 'CONVERSATION'
});
@@ -387,7 +482,7 @@ const Properties: FC<PropertiesProps> = ({
label: 'context',
type: 'variable',
dataType: 'String',
value: `{{context}}`,
value: `context`,
nodeData: selectedNode.getData(),
isContext: true,
});
@@ -476,7 +571,7 @@ const Properties: FC<PropertiesProps> = ({
<Form.Item key={key} name={key}
label={t(`workflow.config.${selectedNode?.data?.type}.${key}`)}
>
<MappingList name={key} />
<MappingList name={key} options={variableList} />
</Form.Item>
)
@@ -583,7 +678,7 @@ const Properties: FC<PropertiesProps> = ({
? <Input.TextArea placeholder={t('common.pleaseEnter')} />
: config.type === 'select'
? <Select
options={config.needTranslation ? config.options?.map(vo => ({ ...vo, label: t(vo.label) })) : config.options}
options={config.needTranslation ? config.options?.map(vo => ({ ...vo, label: t(vo.label) })) : config.options}
placeholder={t('common.pleaseSelect')}
/>
: config.type === 'inputNumber'
@@ -635,7 +730,7 @@ const Properties: FC<PropertiesProps> = ({
}
/>
: config.type === 'switch'
? <Switch />
? <Switch onChange={key === 'group' ? () => { form.setFieldValue('group_names', []) } : undefined} />
: config.type === 'categoryList'
? <CategoryList parentName={key} selectedNode={selectedNode} graphRef={graphRef} />
: config.type === 'conditionList'

View File

@@ -39,6 +39,9 @@ import processEvolutionIcon from '@/assets/images/workflow/process_evolution.png
import questionClassifierIcon from '@/assets/images/workflow/question-classifier.png'
import breakIcon from '@/assets/images/workflow/break.png'
import assignerIcon from '@/assets/images/workflow/assigner.png'
import memoryReadIcon from '@/assets/images/workflow/memory-read.png'
import memoryWriteIcon from '@/assets/images/workflow/memory-write.png'
import { memoryConfigListUrl } from '@/api/memory'
import { getModelListUrl } from '@/api/models'
@@ -159,6 +162,7 @@ export const nodeLibrary: NodeLibrary[] = [
},
text: {
type: 'variableList',
filterLoopIterationVars: true
},
params: {
type: 'paramList',
@@ -174,8 +178,7 @@ export const nodeLibrary: NodeLibrary[] = [
{
category: "cognitiveUpgrading",
nodes: [
{
type: "memory-read", icon: memoryEnhancementIcon,
{ type: "memory-read", icon: memoryReadIcon,
config: {
message: {
type: 'messageEditor',
@@ -198,7 +201,7 @@ export const nodeLibrary: NodeLibrary[] = [
}
}
},
{ type: "memory-write", icon: memoryEnhancementIcon,
{ type: "memory-write", icon: memoryWriteIcon,
config: {
message: {
type: 'messageEditor',
@@ -272,6 +275,7 @@ export const nodeLibrary: NodeLibrary[] = [
},
parallel: {
type: 'switch',
defaultValue: false
},
parallel_count: {
type: 'slider',
@@ -284,6 +288,7 @@ export const nodeLibrary: NodeLibrary[] = [
},
flatten: { // 扁平化输出
type: 'switch',
defaultValue: false
},
output: {
type: 'variableList',
@@ -304,6 +309,13 @@ export const nodeLibrary: NodeLibrary[] = [
expressions: []
}
},
max_loop: {
type: 'slider',
min: 1,
max: 100,
step: 1,
defaultValue: 10
},
}
},
{ type: "cycle-start", icon: loopIcon },
@@ -317,7 +329,7 @@ export const nodeLibrary: NodeLibrary[] = [
},
group_names: {
type: 'groupVariableList',
defaultValue: [{ key: 'Group1', value: []}]
defaultValue: [],
}
}
},
@@ -382,7 +394,8 @@ export const nodeLibrary: NodeLibrary[] = [
defaultValue: {}
},
retry: {
type: 'define',
type: 'switch',
defaultValue: false
},
error_handle: {
type: 'define',

View File

@@ -94,9 +94,7 @@ export const useWorkflowGraph = ({
const { group_names, group } = config
nodeLibraryConfig.config[key].defaultValue = group
? Object.entries(group_names as Record<string, any>).map(([key, value]) => ({ key, value }))
: [{ key: 'Group1', value: group_names }]
console.log('group_names', nodeLibraryConfig.config)
: group_names
} else if (nodeLibraryConfig.config && nodeLibraryConfig.config[key] && config[key]) {
nodeLibraryConfig.config[key].defaultValue = config[key]
}
@@ -832,7 +830,7 @@ export const useWorkflowGraph = ({
// 创建干净的节点数据,只保留必要的字段
const cleanNodeData = {
id: `${dragData.type}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
id: `${dragData.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
name: t(`workflow.${dragData.type}`),
...nodeLibraryConfig
};
@@ -842,6 +840,7 @@ export const useWorkflowGraph = ({
...graphNodeLibrary[dragData.type],
x: point.x - 150,
y: point.y - 100,
id: cleanNodeData.id,
data: { ...cleanNodeData, isGroup: true },
});
} else if (dragData.type === 'if-else') {
@@ -850,6 +849,7 @@ export const useWorkflowGraph = ({
...graphNodeLibrary[dragData.type],
x: point.x - 100,
y: point.y - 60,
id: cleanNodeData.id,
data: { ...cleanNodeData },
});
} else {
@@ -858,6 +858,7 @@ export const useWorkflowGraph = ({
...(graphNodeLibrary[dragData.type] || graphNodeLibrary.default),
x: point.x - 60,
y: point.y - 20,
id: cleanNodeData.id,
data: { ...cleanNodeData },
});
}
@@ -881,7 +882,15 @@ export const useWorkflowGraph = ({
if (data.config) {
Object.keys(data.config).forEach(key => {
if (data.config[key] && 'defaultValue' in data.config[key] && key !== 'knowledge_retrieval') {
if (data.config[key] && 'defaultValue' in data.config[key] && key === 'group_names') {
let group_names = data.config.group.defaultValue ? {} : data.config[key].defaultValue
if (data.config.group.defaultValue) {
data.config[key].defaultValue.map((vo: any) => {
group_names[vo.key] = vo.value
})
}
itemConfig[key] = group_names
} else if (data.config[key] && 'defaultValue' in data.config[key] && key !== 'knowledge_retrieval') {
itemConfig[key] = data.config[key].defaultValue
} else if (key === 'knowledge_retrieval' && data.config[key] && 'defaultValue' in data.config[key]) {
const { knowledge_bases } = data.config[key].defaultValue
@@ -910,7 +919,7 @@ export const useWorkflowGraph = ({
const sourceCell = graphRef.current?.getCellById(edge.getSourceCellId());
const targetCell = graphRef.current?.getCellById(edge.getTargetCellId());
const sourcePortId = edge.getSourcePortId();
// 过滤无效连线源节点或目标节点不存在或者是add-node类型
if (!sourceCell?.getData()?.id || !targetCell?.getData()?.id ||
sourceCell?.getData()?.type === 'add-node' || targetCell?.getData()?.type === 'add-node') {