Merge branch 'release/v0.2.7' into feature/tool_yjp

This commit is contained in:
yujiangping
2026-03-13 10:36:10 +08:00
74 changed files with 3799 additions and 1526 deletions

View File

@@ -45,7 +45,8 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
apt install -y libjemalloc-dev && \
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
apt install -y ghostscript
apt install -y ghostscript && \
apt install -y libmagic1
RUN if [ "$NEED_MIRROR" == "1" ]; then \
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \

View File

@@ -65,7 +65,7 @@ celery_app.conf.update(
# 时区
timezone='Asia/Shanghai',
enable_utc=True,
enable_utc=False,
# 任务追踪
task_track_started=True,
@@ -114,6 +114,7 @@ celery_app.conf.update(
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
},
)

View File

@@ -1,10 +1,12 @@
import uuid
import io
from typing import Optional, Annotated
import yaml
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from urllib.parse import quote
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
@@ -25,6 +27,7 @@ from app.services.app_service import AppService
from app.services.app_statistics_service import AppStatisticsService
from app.services.workflow_import_service import WorkflowImportService
from app.services.workflow_service import WorkflowService, get_workflow_service
from app.services.app_dsl_service import AppDslService
router = APIRouter(prefix="/apps", tags=["Apps"])
logger = get_business_logger()
@@ -1010,3 +1013,57 @@ def get_workspace_api_statistics(
)
return success(data=result)
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
@cur_workspace_access_guard()
async def export_app(
app_id: uuid.UUID,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
release_id: Optional[uuid.UUID] = None
):
"""导出 agent / multi_agent / workflow 应用配置为 YAML 文件流。
release_id: 指定发布版本id不传则导出当前草稿配置。
"""
yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)
encoded = quote(filename, safe=".")
yaml_bytes = yaml_str.encode("utf-8")
file_stream = io.BytesIO(yaml_bytes)
file_stream.seek(0)
return StreamingResponse(
file_stream,
media_type="application/octet-stream; charset=utf-8",
headers={"Content-Disposition": f"attachment; filename={encoded}",
"Content-Length": str(len(yaml_bytes))}
)
@router.post("/import", summary="从 YAML 文件导入应用")
@cur_workspace_access_guard()
async def import_app(
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
"""
if not file.filename.lower().endswith((".yaml", ".yml")):
return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)
raw = (await file.read()).decode("utf-8")
dsl = yaml.safe_load(raw)
if not dsl or "app" not in dsl:
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
new_app, warnings = AppDslService(db).import_dsl(
dsl=dsl,
workspace_id=current_user.current_workspace_id,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
)
return success(
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
)

View File

@@ -65,13 +65,18 @@ async def get_mcp_servers(
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
status_code=status.HTTP_400_BAD_REQUEST,
detail="The mcp market config does not exist or access is denied"
)
# 3. Execute paged query
api = MCPApi()
token = db_mcp_market_config.token
if not token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MCP market config token is not configured"
)
api = MCPApi()
api.login(token)
body = {
@@ -141,13 +146,18 @@ async def get_operational_mcp_servers(
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
status_code=status.HTTP_400_BAD_REQUEST,
detail="The mcp market config does not exist or access is denied"
)
# 2. Execute paged query
api = MCPApi()
token = db_mcp_market_config.token
if not token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MCP market config token is not configured"
)
api = MCPApi()
api.login(token)
url = f'{api.mcp_base_url}/operational'
@@ -199,13 +209,18 @@ async def get_mcp_server(
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
status_code=status.HTTP_400_BAD_REQUEST,
detail="The mcp market config does not exist or access is denied"
)
# 2. Get detailed information for a specific MCP Server
api = MCPApi()
token = db_mcp_market_config.token
if not token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MCP market config token is not configured"
)
api = MCPApi()
api.login(token)
result = api.get_mcp_server(server_id=server_id)
@@ -263,7 +278,7 @@ async def get_mcp_market_config(
if not db_mcp_market_config:
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
status_code=status.HTTP_400_BAD_REQUEST,
detail="The mcp market config does not exist or access is denied"
)
@@ -296,7 +311,7 @@ async def get_mcp_market_config_by_mcp_market_id(
if not db_mcp_market_config:
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
status_code=status.HTTP_400_BAD_REQUEST,
detail="The mcp market config does not exist or access is denied"
)
@@ -325,7 +340,7 @@ async def update_mcp_market_config(
api_logger.warning(
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
status_code=status.HTTP_400_BAD_REQUEST,
detail="The mcp market config does not exist or you do not have permission to access it"
)
@@ -382,7 +397,7 @@ async def delete_mcp_market_config(
api_logger.warning(
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
status_code=status.HTTP_400_BAD_REQUEST,
detail="The mcp market config does not exist or you do not have permission to access it"
)

View File

@@ -1,4 +1,5 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from typing import Optional
from app.core.response_utils import success
@@ -156,9 +157,13 @@ async def get_workspace_end_users(
"app.tasks.init_implicit_emotions_for_users",
kwargs={"end_user_ids": end_user_ids},
)
api_logger.info(f"已触发隐性记忆按需初始化任务,候选用户数: {len(end_user_ids)}")
_celery_app.send_task(
"app.tasks.init_interest_distribution_for_users",
kwargs={"end_user_ids": end_user_ids},
)
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
except Exception as e:
api_logger.warning(f"触发隐性记忆按需初始化任务失败(不影响主流程): {e}")
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
# 并发执行配置查询和记忆数量查询
memory_configs_map, memory_nums_map = await asyncio.gather(
@@ -398,14 +403,15 @@ def get_current_user_rag_total_num(
@router.get("/rag_content", response_model=ApiResponse)
def get_rag_content(
end_user_id: str = Query(..., description="宿主ID"),
limit: int = Query(15, description="返回记录数"),
page: int = Query(1, gt=0, description="页码从1开始"),
pagesize: int = Query(15, gt=0, le=100, description="每页返回记录数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取当前宿主知识库中的chunk内容
获取当前宿主知识库中的chunk内容(分页)
"""
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
data = memory_dashboard_service.get_rag_content(end_user_id, page, pagesize, db, current_user)
return success(data=data, msg="宿主RAGchunk数据获取成功")
@@ -418,26 +424,18 @@ async def get_chunk_summary_tag(
current_user: User = Depends(get_current_user),
):
"""
获取chunk总结、提取的标签和人物形象
读取RAG摘要、标签和人物形象纯读库不触发生成
返回格式:
{
"summary": "chunk内容的总结",
"tags": [
{"tag": "标签1", "frequency": 5},
{"tag": "标签2", "frequency": 3},
...
],
"personas": [
"产品设计师",
"旅行爱好者",
"摄影发烧友",
...
]
"summary": "用户摘要",
"tags": [{"tag": "标签1", "frequency": 5}, ...],
"personas": ["产品设计师", ...],
"generated": true/false // false表示尚未生产请调用 /generate_rag_profile
}
"""
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id}chunk摘要标签人物形象")
api_logger.info(f"用户 {current_user.username} 取宿主 {end_user_id}RAG摘要/标签/人物形象")
data = await memory_dashboard_service.get_chunk_summary_and_tags(
end_user_id=end_user_id,
limit=limit,
@@ -445,9 +443,8 @@ async def get_chunk_summary_tag(
db=db,
current_user=current_user
)
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
return success(data=data, msg="获取成功")
@router.get("/chunk_insight", response_model=ApiResponse)
@@ -458,24 +455,57 @@ async def get_chunk_insight(
current_user: User = Depends(get_current_user),
):
"""
获取chunk的洞察内容
读取RAG洞察报告纯读库不触发生成
返回格式:
{
"insight": "对chunk内容的深度洞察分析"
"insight": "总体概述",
"behavior_pattern": "行为模式",
"key_findings": "关键发现",
"growth_trajectory": "成长轨迹",
"generated": true/false // false表示尚未生产请调用 /generate_rag_profile
}
"""
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id}chunk洞察")
api_logger.info(f"用户 {current_user.username} 取宿主 {end_user_id}RAG洞察")
data = await memory_dashboard_service.get_chunk_insight(
end_user_id=end_user_id,
limit=limit,
db=db,
current_user=current_user
)
api_logger.info("成功获取chunk洞察")
return success(data=data, msg="chunk洞察获取成功")
return success(data=data, msg="获取成功")
class GenerateRagProfileRequest(BaseModel):
end_user_id: str = Field(..., description="宿主ID")
limit: int = Field(15, description="参与生成的chunk数量上限")
max_tags: int = Field(10, description="最大标签数量")
@router.post("/generate_rag_profile", response_model=ApiResponse)
async def generate_rag_profile(
body: GenerateRagProfileRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
生产接口为RAG存储模式的宿主全量重新生成完整画像并持久化到end_user表。
每次请求都会重新生成,覆盖已有数据。
"""
api_logger.info(f"用户 {current_user.username} 触发RAG画像生产: end_user_id={body.end_user_id}")
data = await memory_dashboard_service.generate_rag_profile(
end_user_id=body.end_user_id,
limit=body.limit,
max_tags=body.max_tags,
db=db,
current_user=current_user,
)
api_logger.info(f"RAG画像生产完成: {data}")
return success(data=data, msg="RAG画像生产完成")
@router.get("/dashboard_data", response_model=ApiResponse)

View File

@@ -14,6 +14,7 @@ from app.models import User
from app.models.tool_model import ToolType, ToolStatus, AuthType
from app.services.tool_service import ToolService
from app.schemas.response_schema import ApiResponse
from app.core.exceptions import BusinessException
router = APIRouter(prefix="/tools", tags=["Tool System"])
@@ -103,7 +104,7 @@ async def create_tool(
val = getattr(request, key, None)
if val is not None:
request.config[key] = val
tool_id = service.create_tool(
tool_id = await service.create_tool(
name=request.name,
tool_type=request.tool_type,
tenant_id=current_user.tenant_id,
@@ -113,6 +114,8 @@ async def create_tool(
tags=request.tags
)
return success(data={"tool_id": tool_id}, msg="工具创建成功")
except BusinessException as e:
raise HTTPException(status_code=400, detail=e.message)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:

View File

@@ -1,7 +1,6 @@
import json
import os
from pathlib import Path
from typing import Annotated, Any, Dict, Optional
from typing import Annotated, Optional
from dotenv import load_dotenv
from pydantic import Field, TypeAdapter
@@ -115,6 +114,7 @@ class Settings:
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
S3_ENDPOINT_URL: str = os.getenv("S3_ENDPOINT_URL", "")
# VOLC ASR settings
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")

View File

@@ -33,6 +33,7 @@ class DialogExtractionResponse(BaseModel):
- is_related对话与场景的相关性判定。
- times / ids / amounts / contacts / addresses / keywords重要信息片段用来在不相关对话中保留关键消息。
- preserve_keywords情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。
"""
is_related: bool = Field(...)
times: List[str] = Field(default_factory=list)
@@ -41,6 +42,7 @@ class DialogExtractionResponse(BaseModel):
contacts: List[str] = Field(default_factory=list)
addresses: List[str] = Field(default_factory=list)
keywords: List[str] = Field(default_factory=list)
preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留")
class MessageImportanceResponse(BaseModel):
@@ -86,26 +88,17 @@ class SemanticPruner:
self._detailed_prune_logging = True # 是否启用详细日志
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
# 加载场景特定配置(内置场景走专门规则,自定义场景 fallback 到通用规则)
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
self.config.pruning_scene,
fallback_to_generic=True
)
# 加载统一填充词库
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene)
# 判断是否为内置专门场景
self._is_builtin_scene = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
# 自定义场景的本体类型列表(用于注入提示词)
# 本体类型列表(用于注入提示词,所有场景均支持)
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
if self._is_builtin_scene:
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用内置专门配置")
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}")
if self._ontology_classes:
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
else:
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 为自定义场景,使用通用规则 + 本体类型提示词注入")
if self._ontology_classes:
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
else:
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
# Load Jinja2 template
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
@@ -117,107 +110,27 @@ class SemanticPruner:
# 运行日志:收集关键终端输出,便于写入 JSON
self.run_logs: List[str] = []
def _is_important_message(self, message: ConversationMessage) -> bool:
"""基于启发式规则识别重要信息消息,优先保留。
改进版:使用场景特定的模式进行识别
- 根据 pruning_scene 动态加载对应的识别规则
- 支持教育、在线服务、外呼三个场景的特定模式
"""
text = message.msg.strip()
if not text:
return False
# 使用场景特定的模式
all_patterns = (
self.scene_config.high_priority_patterns +
self.scene_config.medium_priority_patterns +
self.scene_config.low_priority_patterns
)
for pattern, _ in all_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
return True
# 检查是否为问句(以问号结尾或包含疑问词)
if text.endswith("") or text.endswith("?"):
return True
# 检查是否包含问句关键词
if any(keyword in text for keyword in self.scene_config.question_keywords):
return True
# 检查是否包含决策性关键词
if any(keyword in text for keyword in self.scene_config.decision_keywords):
return True
return False
def _importance_score(self, message: ConversationMessage) -> int:
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
改进版使用场景特定的权重体系0-10分
- 根据场景动态调整不同信息类型的权重
- 高优先级模式4-6分
- 中优先级模式2-3分
- 低优先级模式1分
"""
text = message.msg.strip()
score = 0
# 使用场景特定的权重
for pattern, weight in self.scene_config.high_priority_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
score += weight
for pattern, weight in self.scene_config.medium_priority_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
score += weight
for pattern, weight in self.scene_config.low_priority_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
score += weight
# 问句加分
if text.endswith("") or text.endswith("?"):
score += 2
# 包含问句关键词加分
if any(keyword in text for keyword in self.scene_config.question_keywords):
score += 1
# 包含决策性关键词加分
if any(keyword in text for keyword in self.scene_config.decision_keywords):
score += 2
# 长度加分(较长的消息通常包含更多信息)
if len(text) > 50:
score += 1
if len(text) > 100:
score += 1
return min(score, 10) # 最高10分
# _is_important_message 和 _importance_score 已移除:
# 重要性判断完全由 extracat_Pruning.jinja2 提示词 + LLM 的 preserve_tokens 机制承担。
# LLM 根据注入的本体工程类型语义识别需要保护的内容,无需硬编码正则规则。
def _is_filler_message(self, message: ConversationMessage) -> bool:
"""检测典型寒暄/口头禅/确认类短消息。
改进版:更严格的填充消息判断,避免误删场景相关内容
满足以下之一视为填充消息
- 纯标点或空白
- 在场景特定填充词库中(精确匹配
- 纯表情符号
- 常见寒暄(精确匹配短语)
注意:不再使用长度判断,避免误删短但重要的消息
判断顺序:
1. 空消息
2. 场景特定填充词库精确匹配
3. 常见寒暄精确匹配
4. 纯表情/标点
"""
t = message.msg.strip()
if not t:
return True
# 检查是否在场景特定填充词库中(精确匹配)
if t in self.scene_config.filler_phrases:
return True
# 常见寒暄和问候(精确匹配,避免误删)
common_greetings = {
"在吗", "在不在", "在呢", "在的",
@@ -229,25 +142,11 @@ class SemanticPruner:
}
if t in common_greetings:
return True
# 检查是否为纯表情符号(方括号包裹)
if re.fullmatch(r"(\[[^\]]+\])+", t):
return True
# 检查是否为纯emojiUnicode表情
emoji_pattern = re.compile(
"["
"\U0001F600-\U0001F64F" # 表情符号
"\U0001F300-\U0001F5FF" # 符号和象形文字
"\U0001F680-\U0001F6FF" # 交通和地图符号
"\U0001F1E0-\U0001F1FF" # 旗帜
"\U00002702-\U000027B0"
"\U000024C2-\U0001F251"
"]+", flags=re.UNICODE
)
if emoji_pattern.fullmatch(t):
return True
# 纯标点符号
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
return True
@@ -432,14 +331,12 @@ class SemanticPruner:
rendered = self.template.render(
pruning_scene=self.config.pruning_scene,
is_builtin_scene=self._is_builtin_scene,
ontology_classes=self._ontology_classes,
dialog_text=dialog_text,
language=self.language
)
log_template_rendering("extracat_Pruning.jinja2", {
"pruning_scene": self.config.pruning_scene,
"is_builtin_scene": self._is_builtin_scene,
"ontology_classes_count": len(self._ontology_classes),
"language": self.language
})
@@ -504,62 +401,56 @@ class SemanticPruner:
# 相关对话不剪枝
return dialog
# 在不相关对话中,识别重要/不重要消息
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
# 在不相关对话中,LLM 已通过 preserve_tokens 标记需要保护的内容
preserve_tokens = (
extraction.times + extraction.ids + extraction.amounts +
extraction.contacts + extraction.addresses + extraction.keywords +
extraction.preserve_keywords
)
msgs = dialog.context.msgs
imp_unrel_msgs: List[ConversationMessage] = []
unimp_unrel_msgs: List[ConversationMessage] = []
# 分类:填充 / 其他可删LLM保护消息通过不加入任何桶来隐式保护
filler_ids: set = set()
deletable: List[ConversationMessage] = []
for m in msgs:
if self._msg_matches_tokens(m, tokens) or self._is_important_message(m):
imp_unrel_msgs.append(m)
if self._msg_matches_tokens(m, preserve_tokens):
pass # 保护消息:不加入任何桶,不会被删除
elif self._is_filler_message(m):
filler_ids.add(id(m))
else:
unimp_unrel_msgs.append(m)
# 计算总删除目标数量
deletable.append(m)
# 计算删除目标
total_unrel = len(msgs)
delete_target = int(total_unrel * proportion)
if proportion > 0 and total_unrel > 0 and delete_target == 0:
delete_target = 1
imp_del_cap = min(int(len(imp_unrel_msgs) * proportion), len(imp_unrel_msgs))
unimp_del_cap = len(unimp_unrel_msgs)
max_capacity = max(0, len(msgs) - 1)
max_deletable = min(imp_del_cap + unimp_del_cap, max_capacity)
max_deletable = min(len(filler_ids) + len(deletable), max(0, total_unrel - 1))
delete_target = min(delete_target, max_deletable)
# 删除配额分配
del_unimp = min(delete_target, unimp_del_cap)
rem = delete_target - del_unimp
del_imp = min(rem, imp_del_cap)
# 选取删除集合
unimp_delete_ids = []
imp_delete_ids = []
if del_unimp > 0:
# 按出现顺序选取前 del_unimp 条不重要消息进行删除(确定性、可复现)
unimp_delete_ids = [id(m) for m in unimp_unrel_msgs[:del_unimp]]
if del_imp > 0:
imp_sorted = sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))
imp_delete_ids = [id(m) for m in imp_sorted[:del_imp]]
# 统计实际删除数量(重要/不重要)
actual_unimp_deleted = 0
actual_imp_deleted = 0
kept_msgs = []
delete_targets = set(unimp_delete_ids) | set(imp_delete_ids)
# 优先删填充,再删其他可删消息(按出现顺序)
to_delete_ids: set = set()
for m in msgs:
mid = id(m)
if mid in delete_targets:
if mid in set(unimp_delete_ids) and actual_unimp_deleted < del_unimp:
actual_unimp_deleted += 1
continue
if mid in set(imp_delete_ids) and actual_imp_deleted < del_imp:
actual_imp_deleted += 1
continue
kept_msgs.append(m)
if len(to_delete_ids) >= delete_target:
break
if id(m) in filler_ids:
to_delete_ids.add(id(m))
for m in deletable:
if len(to_delete_ids) >= delete_target:
break
to_delete_ids.add(id(m))
kept_msgs = [m for m in msgs if id(m) not in to_delete_ids]
if not kept_msgs and msgs:
kept_msgs = [msgs[0]]
deleted_total = actual_unimp_deleted + actual_imp_deleted
deleted_total = len(msgs) - len(kept_msgs)
protected_count = len(msgs) - len(filler_ids) - len(deletable)
self._log(
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} 删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} "
f"(保护={protected_count} 填充={len(filler_ids)} 可删={len(deletable)}) "
f"删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
)
dialog.context = ConversationContext(msgs=kept_msgs)
@@ -594,51 +485,64 @@ class SemanticPruner:
result: List[DialogData] = []
total_original_msgs = 0
total_deleted_msgs = 0
for d_idx, dd in enumerate(dialogs):
# 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息)
semaphore = asyncio.Semaphore(self.max_concurrent)
async def extract_with_semaphore(dd: DialogData) -> DialogExtractionResponse:
async with semaphore:
try:
return await self._extract_dialog_important(dd.content)
except Exception as e:
self._log(f"[剪枝-LLM] 对话抽取失败,使用降级策略: {str(e)[:100]}")
return DialogExtractionResponse(is_related=True)
extraction_tasks = [extract_with_semaphore(dd) for dd in dialogs]
extraction_results: List[DialogExtractionResponse] = await asyncio.gather(*extraction_tasks)
for d_idx, (dd, extraction) in enumerate(zip(dialogs, extraction_results)):
msgs = dd.context.msgs
original_count = len(msgs)
total_original_msgs += original_count
# ========== 问答对保护(已注释,暂不启用,留作观察) ==========
# qa_pairs = self._identify_qa_pairs(msgs)
# protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0)
# ========================================================
# 消息级分类:每条消息独立判断
important_msgs = [] # 重要消息(保留)
unimportant_msgs = [] # 不重要消息(可删除)
filler_msgs = [] # 填充消息(优先删除)
# 判断是否需要详细日志仅对前N条消息记录
# 从 LLM 抽取结果中获取所有需要保留的 token
preserve_tokens = (
extraction.times + extraction.ids + extraction.amounts +
extraction.contacts + extraction.addresses + extraction.keywords +
extraction.preserve_keywords # 情绪/兴趣/爱好关键词
)
# 判断是否需要详细日志
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
if extraction.preserve_keywords:
self._log(f" 对话[{d_idx}] LLM抽取到情绪/兴趣保护词: {extraction.preserve_keywords}")
# 消息级分类LLM保护 / 填充 / 其他可删
llm_protected_msgs = [] # LLM 保护消息preserve_tokens 命中):绝对不可删除
filler_msgs = [] # 填充消息(优先删除)
deletable_msgs = [] # 其余消息(按比例删除)
for idx, m in enumerate(msgs):
msg_text = m.msg.strip()
# ========== 问答对保护判断(已注释) ==========
# if idx in protected_indices:
# important_msgs.append((idx, m))
# self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)")
# ==========================================
# 填充消息(寒暄、表情等)
if self._is_filler_message(m):
if self._msg_matches_tokens(m, preserve_tokens):
llm_protected_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 保护LLM不可删")
elif self._is_filler_message(m):
filler_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
# 重要信息(学号、成绩、时间、金额等)
elif self._is_important_message(m):
important_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)")
# 其他消息
else:
unimportant_msgs.append((idx, m))
deletable_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...'不重要")
self._log(f" [{idx}] '{msg_text[:30]}...'可删")
# important_msgs 仅用于日志统计
important_msgs = llm_protected_msgs
# 计算删除配额
delete_target = int(original_count * proportion)
@@ -649,37 +553,23 @@ class SemanticPruner:
max_deletable = max(0, original_count - 1)
delete_target = min(delete_target, max_deletable)
# 删除策略:优先删填充消息,再删除不重要消息
# 删除策略:优先删填充消息,再按出现顺序删其余可删消息
to_delete_indices = set()
deleted_details = [] # 记录删除的消息详情
deleted_details = []
# 第一步:删除填充消息
filler_to_delete = min(len(filler_msgs), delete_target)
for i in range(filler_to_delete):
idx, msg = filler_msgs[i]
for idx, msg in filler_msgs:
if len(to_delete_indices) >= delete_target:
break
to_delete_indices.add(idx)
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
# 第二步:如果还需要删除,删除不重要消息
remaining_quota = delete_target - len(to_delete_indices)
if remaining_quota > 0:
unimp_to_delete = min(len(unimportant_msgs), remaining_quota)
for i in range(unimp_to_delete):
idx, msg = unimportant_msgs[i]
to_delete_indices.add(idx)
deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'")
# 第三步:如果还需要删除,按重要性分数删除重要消息
remaining_quota = delete_target - len(to_delete_indices)
if remaining_quota > 0 and important_msgs:
# 按重要性分数排序(分数低的优先删除)
imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1]))
imp_to_delete = min(len(imp_sorted), remaining_quota)
for i in range(imp_to_delete):
idx, msg = imp_sorted[i]
to_delete_indices.add(idx)
score = self._importance_score(msg)
deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'")
# 第二步:如果还需要删除,按出现顺序删可删消息
for idx, msg in deletable_msgs:
if len(to_delete_indices) >= delete_target:
break
to_delete_indices.add(idx)
deleted_details.append(f"[{idx}] 可删: '{msg.msg[:50]}'")
# 执行删除
kept_msgs = []
@@ -707,7 +597,7 @@ class SemanticPruner:
self._log(
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) "
f"(保护={len(important_msgs)} 填充={len(filler_msgs)} 可删={len(deletable_msgs)}) "
f"删除={deleted_count} 保留={len(kept_msgs)}"
)

View File

@@ -1,66 +1,25 @@
"""
场景特定配置 - 为不同场景提供定制化的剪枝规则
场景特定配置 - 统一填充词库
功能:
- 场景特定的重要信息识别模式
- 场景特定的重要性评分权重
- 场景特定的填充词库
- 场景特定的问答对识别规则
重要性判断已完全交由 extracat_Pruning.jinja2 提示词 + LLM preserve_tokens 机制承担。
本模块仅保留统一填充词库filler_phrases用于识别无意义寒暄/表情/口头禅。
所有场景共用同一份词库,场景差异由 LLM 语义判断处理。
"""
from typing import Dict, List, Set, Tuple
from typing import List, Set
from dataclasses import dataclass, field
@dataclass
class ScenePatterns:
"""场景特定的识别模式"""
# 重要信息的正则模式(优先级从高到低)
high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight)
medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
# 填充词库(无意义对话)
"""场景特定的识别模式(仅保留填充词库)"""
filler_phrases: Set[str] = field(default_factory=set)
# 问句关键词(用于识别问答对)
question_keywords: Set[str] = field(default_factory=set)
# 决策性/承诺性关键词
decision_keywords: Set[str] = field(default_factory=set)
class SceneConfigRegistry:
"""场景配置注册表 - 管理所有场景的特定配置"""
# 基础通用模式(所有场景共享)
BASE_HIGH_PRIORITY = [
(r"订单号|工单|申请号|编号|ID|账号|账户", 5),
(r"金额|费用|价格|¥|¥|\d+元", 5),
(r"\d{11}", 4), # 手机号
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱
]
BASE_MEDIUM_PRIORITY = [
(r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期
(r"\d{4}\d{1,2}月\d{1,2}日", 3),
(r"电话|手机号|微信|QQ|联系方式", 3),
(r"地址|地点|位置", 2),
(r"时间|日期|有效期|截止", 2),
(r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重)
(r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3),
(r"今年|去年|明年", 3),
]
BASE_LOW_PRIORITY = [
(r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM
(r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点
(r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充)
(r"AM|PM|am|pm", 1),
]
BASE_FILLERS = {
"""场景配置注册表 - 所有场景共用统一填充词库"""
BASE_FILLERS: Set[str] = {
# 基础寒暄
"你好", "您好", "在吗", "在的", "在呢", "", "嗯嗯", "", "哦哦",
"好的", "", "", "可以", "不可以", "谢谢", "多谢", "感谢",
@@ -69,7 +28,26 @@ class SceneConfigRegistry:
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
"", "", "", "", "", "", "嗯哼",
# 确认词
"是的", "", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了",
"是的", "", "对的", "没错", "好嘞", "收到", "明白", "了解", "知道了",
# 服务类套话
"请问", "请稍等", "稍等", "马上", "立即",
"正在查询", "正在处理", "正在为您", "帮您查一下",
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
"感谢您的耐心等待", "抱歉让您久等了",
"已记录", "已反馈", "已转接", "已升级",
"祝您生活愉快", "欢迎下次咨询",
# 外呼套话
"", "hello", "打扰了", "不好意思",
"方便接电话吗", "现在方便吗", "占用您一点时间",
"我是", "我们是", "我们公司", "我们这边",
"了解一下", "介绍一下", "简单说一下",
"考虑考虑", "想一想", "再说", "再看看",
"不需要", "不感兴趣", "没兴趣", "不用了",
"没问题", "那就这样", "再联系", "回头聊", "有需要再说",
# 教育场景套话
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
"举手", "请坐", "很好", "不错", "继续",
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
# 标点和符号
"。。。", "...", "???", "", "!!!", "",
# 表情符号
@@ -81,246 +59,8 @@ class SceneConfigRegistry:
"hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
"emmm", "emm", "em", "mmp", "wtf", "omg",
}
BASE_QUESTION_KEYWORDS = {
"什么", "为什么", "怎么", "如何", "哪里", "哪个", "", "多少", "几点", "何时", ""
}
BASE_DECISION_KEYWORDS = {
"必须", "一定", "务必", "需要", "要求", "规定", "应该",
"承诺", "保证", "确保", "负责", "同意", "答应"
}
@classmethod
def get_education_config(cls) -> ScenePatterns:
"""教育场景配置"""
return ScenePatterns(
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
# 成绩相关(最高优先级)
(r"成绩|分数|得分|满分|及格|不及格", 6),
(r"GPA|绩点|学分|平均分", 6),
(r"\d+分|\d+\.?\d*分", 5), # 具体分数
(r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等
# 学籍信息
(r"学号|学生证|教师工号|工号", 5),
(r"班级|年级|专业|院系", 4),
# 课程相关
(r"课程|科目|学科|必修|选修", 4),
(r"教材|课本|教科书|参考书", 4),
(r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等
# 学科内容(新增)
(r"微积分|导数|积分|函数|极限|微分", 4),
(r"代数|几何|三角|概率|统计", 4),
(r"物理|化学|生物|历史|地理", 4),
(r"英语|语文|数学|政治|哲学", 4),
(r"定义|定理|公式|概念|原理|法则", 3),
(r"例题|解题|证明|推导|计算", 3),
],
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
# 教学活动
(r"作业|练习|习题|题目", 3),
(r"考试|测验|测试|考核|期中|期末", 3),
(r"上课|下课|课堂|讲课", 2),
(r"提问|回答|发言|讨论", 2),
(r"问一下|请教|咨询|询问", 2), # 新增:问询相关
(r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态
# 时间安排
(r"课表|课程表|时间表", 3),
(r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等
],
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
(r"老师|教师|同学|学生", 1),
(r"教室|实验室|图书馆", 1),
],
filler_phrases=cls.BASE_FILLERS | {
# 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义)
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
"举手", "请坐", "很好", "不错", "继续",
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
},
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
"为啥", "", "咋办", "怎样", "如何做",
"能不能", "可不可以", "行不行", "对不对", "是不是",
},
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
"必考", "重点", "考点", "难点", "关键",
"记住", "背诵", "掌握", "理解", "复习",
}
)
@classmethod
def get_online_service_config(cls) -> ScenePatterns:
"""在线服务场景配置"""
return ScenePatterns(
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
# 工单相关(最高优先级)
(r"工单号|工单编号|ticket|TK\d+", 6),
(r"工单状态|处理中|已解决|已关闭|待处理", 5),
(r"优先级|紧急|高优先级|P0|P1|P2", 5),
# 产品信息
(r"产品型号|型号|SKU|产品编号", 5),
(r"序列号|SN|设备号", 5),
(r"版本号|软件版本|固件版本", 4),
# 问题描述
(r"故障|错误|异常|bug|问题", 4),
(r"错误代码|故障代码|error code", 5),
(r"无法|不能|失败|报错", 3),
],
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
# 服务相关
(r"退款|退货|换货|补发", 4),
(r"发票|收据|凭证", 3),
(r"物流|快递|运单号", 3),
(r"保修|质保|售后", 3),
# 时效相关
(r"SLA|响应时间|处理时长", 4),
(r"超时|延迟|等待", 2),
],
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
(r"客服|工程师|技术支持", 1),
(r"用户|客户|会员", 1),
],
filler_phrases=cls.BASE_FILLERS | {
# 在线服务特有填充词
"您好", "请问", "请稍等", "稍等", "马上", "立即",
"正在查询", "正在处理", "正在为您", "帮您查一下",
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
"感谢您的耐心等待", "抱歉让您久等了",
"已记录", "已反馈", "已转接", "已升级",
"祝您生活愉快", "再见", "欢迎下次咨询",
},
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
"能否", "可否", "是否", "有没有", "能不能",
"怎么办", "如何处理", "怎么解决",
},
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
"立即处理", "马上解决", "尽快", "优先",
"升级", "转接", "派单", "跟进",
"补偿", "赔偿", "退款", "换货",
}
)
@classmethod
def get_outbound_config(cls) -> ScenePatterns:
"""外呼场景配置"""
return ScenePatterns(
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
# 意向相关(最高优先级)
(r"意向|意愿|兴趣|感兴趣", 6),
(r"A类|B类|C类|D类|高意向|低意向", 6),
(r"成交|签约|下单|购买|确认", 6),
# 联系信息(外呼场景中更重要)
(r"预约|约定|安排|确定时间", 5),
(r"下次联系|回访|跟进", 5),
(r"方便|有空|可以|时间", 4),
# 通话状态
(r"接通|未接通|占线|关机|停机", 4),
(r"通话时长|通话时间", 3),
],
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
# 客户信息
(r"姓名|称呼|先生|女士", 3),
(r"公司|单位|职位|职务", 3),
(r"需求|要求|期望", 3),
# 跟进状态
(r"跟进状态|进展|进度", 3),
(r"已联系|待联系|联系中", 2),
(r"拒绝|不感兴趣|考虑|再说", 3),
],
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
(r"销售|客户经理|业务员", 1),
(r"产品|服务|方案", 1),
],
filler_phrases=cls.BASE_FILLERS | {
# 外呼场景特有填充词
"您好", "", "hello", "打扰了", "不好意思",
"方便接电话吗", "现在方便吗", "占用您一点时间",
"我是", "我们是", "我们公司", "我们这边",
"了解一下", "介绍一下", "简单说一下",
"考虑考虑", "想一想", "再说", "再看看",
"不需要", "不感兴趣", "没兴趣", "不用了",
"好的", "", "可以", "没问题", "那就这样",
"再联系", "回头聊", "有需要再说",
},
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
"有没有", "需不需要", "要不要", "考虑不考虑",
"了解吗", "知道吗", "听说过吗",
"方便吗", "有空吗", "在吗",
},
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
"确定", "决定", "选择", "购买", "下单",
"预约", "安排", "约定", "确认",
"跟进", "回访", "联系", "沟通",
}
)
@classmethod
def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns:
"""根据场景名称获取配置
Args:
scene: 场景名称 ('education', 'online_service', 'outbound' 或其他)
fallback_to_generic: 如果场景不存在,是否降级到通用配置
Returns:
对应场景的配置,如果场景不存在:
- fallback_to_generic=True: 返回通用配置(仅基础规则)
- fallback_to_generic=False: 抛出异常
"""
scene_map = {
'education': cls.get_education_config,
'online_service': cls.get_online_service_config,
'outbound': cls.get_outbound_config,
}
if scene in scene_map:
return scene_map[scene]()
if fallback_to_generic:
# 返回通用配置(仅包含基础规则,不包含场景特定规则)
return cls.get_generic_config()
else:
raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}")
@classmethod
def get_generic_config(cls) -> ScenePatterns:
"""通用场景配置 - 仅包含基础规则,适用于未定义的场景
这是一个保守的配置,只使用最通用的规则,避免误删重要信息
"""
return ScenePatterns(
high_priority_patterns=cls.BASE_HIGH_PRIORITY,
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY,
low_priority_patterns=cls.BASE_LOW_PRIORITY,
filler_phrases=cls.BASE_FILLERS,
question_keywords=cls.BASE_QUESTION_KEYWORDS,
decision_keywords=cls.BASE_DECISION_KEYWORDS
)
@classmethod
def get_all_scenes(cls) -> List[str]:
"""获取所有预定义场景的列表"""
return ['education', 'online_service', 'outbound']
@classmethod
def is_scene_supported(cls, scene: str) -> bool:
"""检查场景是否有专门的配置支持
Args:
scene: 场景名称
Returns:
True: 有专门配置
False: 将使用通用配置
"""
return scene in cls.get_all_scenes()
def get_config(cls, scene: str = "") -> ScenePatterns:
"""所有场景统一返回同一份填充词库"""
return ScenePatterns(filler_phrases=cls.BASE_FILLERS)

View File

@@ -1,6 +1,6 @@
{#
对话级抽取与相关性判定模板(用于剪枝加速)
输入pruning_scene, is_builtin_scene, ontology_classes, dialog_text, language
输入pruning_scene, ontology_classes, dialog_text, language
输出:严格 JSON不要包含任何多余文本字段
- is_related: bool是否与所选场景相关
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
@@ -9,64 +9,71 @@
- contacts: [string],联系方式(电话/手机号/邮箱/微信/QQ等
- addresses: [string],地址/地点相关文本
- keywords: [string],其它有助于保留的重要关键词(与场景强相关的术语)
- preserve_keywords: [string],必须保留的情绪/兴趣/爱好/个人偏好相关词或短语片段
要求:
- 必须只输出上述 JSON且键名一致不得输出解释、前后缀不得包含注释。
- times/ids/amounts/contacts/addresses/keywords 仅抽取原文片段或规范化后的简单字符串。
- times/ids/amounts/contacts/addresses/keywords/preserve_keywords 仅抽取原文片段或规范化后的简单字符串。
- 仅输出上述键;避免多余解释或字段。
#}
{# ── 内置场景的固定说明 ── #}
{% set builtin_scene_instructions = {
'education': {
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
},
'online_service': {
'zh': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
'en': 'Online Service Scenario: Customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.'
},
'outbound': {
'zh': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。',
'en': 'Outbound Scenario: Outbound calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up records, etc.'
}
} %}
{# ── 确定最终使用的场景说明 ── #}
{% if is_builtin_scene %}
{# 内置专门场景:使用固定说明 #}
{% set scene_key = pruning_scene %}
{% if scene_key not in builtin_scene_instructions %}{% set scene_key = 'education' %}{% endif %}
{% set instruction = builtin_scene_instructions[scene_key][language] if language in ['zh', 'en'] else builtin_scene_instructions[scene_key]['zh'] %}
{% set custom_types_str = '' %}
{% else %}
{# 自定义场景:使用场景名称 + 本体类型列表构建说明 #}
{% if ontology_classes and ontology_classes | length > 0 %}
{% if language == 'en' %}
{% set custom_types_str = ontology_classes | join(', ') %}
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
{% else %}
{% set custom_types_str = ontology_classes | join('、') %}
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
{% endif %}
{# ── 确定场景说明 ── #}
{% if ontology_classes and ontology_classes | length > 0 %}
{% if language == 'en' %}
{% set custom_types_str = ontology_classes | join(', ') %}
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
{% else %}
{# 无本体类型时退化为通用说明 #}
{% if language == 'en' %}
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
{% else %}
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
{% endif %}
{% set custom_types_str = ontology_classes | join('、') %}
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
{% endif %}
{% else %}
{% if language == 'en' %}
{% set custom_types_str = '' %}
{% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
{% else %}
{% set custom_types_str = '' %}
{% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
{% endif %}
{% endif %}
{% if language == "zh" %}
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性
你是一个对话内容分析助手。请对下方对话全文进行一次性分析,完成两项任务
1. 判断对话是否与指定场景相关;
2. 从对话中抽取所有需要保留的重要信息片段。
场景说明:{{ instruction }}
{% if not is_builtin_scene and custom_types_str %}
{% if custom_types_str %}
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }}相关的内容即判定为相关is_related=true
{% endif %}
---
【必须保留的内容(不可删除)】
以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段:
- 时间信息:日期、时间点、时间段、有效期 → times 字段
- 编号信息学号、工号、订单号、申请号、账号、ID → ids 字段
- 金额信息:价格、费用、金额(含货币符号或单位) → amounts 字段
- 联系方式电话、手机号、邮箱、微信、QQ → contacts 字段
- 地址信息:地点、地址、位置 → addresses 字段
- 场景关键词:与场景强相关的专业术语、事件名称 → keywords 字段
- **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段
- **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段
- **个人观点与态度**:对某事物的明确看法、评价、立场 → preserve_keywords 字段
【可以删除的内容】
以下类型的内容属于低价值信息,可以在剪枝时删除:
- 纯寒暄问候:如"你好"、"在吗"、"拜拜"、"嗯"、"好的"、"哦"等无实质内容的短语
- 纯表情/符号:如"[微笑]"、"😊"、"哈哈"等
- 重复确认:如"对对对"、"是的是的"、"嗯嗯嗯"等无新增信息的重复
- 无意义填充:如"啊"、"呢"、"嘛"等语气词单独成句
**注意:即使消息很短,只要包含情绪、兴趣、爱好、个人观点等有价值信息,就必须保留,不得删除。**
例如:
- "我好开心呀" → 包含情绪开心必须保留preserve_keywords 中加入"开心"
- "好喜欢打羽毛球呀" → 包含兴趣爱好喜欢打羽毛球必须保留preserve_keywords 中加入"喜欢打羽毛球"
- "我好难过" → 包含情绪难过必须保留preserve_keywords 中加入"难过"
- "太好啦!看到你开心,我也跟着心情亮起来" → 包含情绪必须保留preserve_keywords 中加入"开心"
---
对话全文:
"""
{{ dialog_text }}
@@ -80,15 +87,46 @@
"amounts": [<string>...],
"contacts": [<string>...],
"addresses": [<string>...],
"keywords": [<string>...]
"keywords": [<string>...],
"preserve_keywords": [<string>...]
}
{% else %}
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks:
1. Determine whether the dialogue is relevant to the specified scene;
2. Extract all important information fragments that must be preserved.
Scenario Description: {{ instruction }}
{% if not is_builtin_scene and custom_types_str %}
{% if custom_types_str %}
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
{% endif %}
---
[MUST PRESERVE (cannot be deleted)]
The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields:
- Time information: dates, time points, durations, expiry dates → times field
- ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field
- Amount information: prices, fees, amounts (with currency symbols or units) → amounts field
- Contact information: phone numbers, emails, WeChat, QQ → contacts field
- Address information: locations, addresses, places → addresses field
- Scene keywords: professional terms and event names strongly related to the scene → keywords field
- **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, sadness, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field
- **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field
- **Personal opinions and attitudes**: clear views, evaluations, or stances on something → preserve_keywords field
[CAN BE DELETED]
The following types of content are low-value and can be removed during pruning:
- Pure greetings: e.g., "hello", "are you there", "bye", "ok", "yeah" — short phrases with no substantive content
- Pure emojis/symbols: e.g., "[smile]", "😊", "haha"
- Repetitive confirmations: e.g., "yes yes yes", "right right", "uh huh" — repetitions with no new information
- Meaningless fillers: standalone interjections like "ah", "well", "hmm"
**Note: Even if a message is short, if it contains emotions, interests, hobbies, or personal opinions, it MUST be preserved.**
Examples:
- "I'm so happy!" → contains emotion (happy), must preserve; add "happy" to preserve_keywords
- "I love playing badminton!" → contains interest (love playing badminton), must preserve; add "love playing badminton" to preserve_keywords
- "I feel so sad" → contains emotion (sad), must preserve; add "sad" to preserve_keywords
---
Full Dialogue:
"""
{{ dialog_text }}
@@ -102,6 +140,7 @@ Output strict JSON only (fixed keys, order doesn't matter):
"amounts": [<string>...],
"contacts": [<string>...],
"addresses": [<string>...],
"keywords": [<string>...]
"keywords": [<string>...],
"preserve_keywords": [<string>...]
}
{% endif %}

View File

@@ -4,11 +4,12 @@ RAG chunk analysis utilities.
from .chunk_summary import generate_chunk_summary
from .chunk_tags import extract_chunk_tags, extract_chunk_persona
from .chunk_insight import generate_chunk_insight
from .chunk_insight import generate_chunk_insight, generate_chunk_insight_sections
__all__ = [
"generate_chunk_summary",
"extract_chunk_tags",
"extract_chunk_persona",
"generate_chunk_insight",
"generate_chunk_insight_sections",
]

View File

@@ -1,213 +1,207 @@
"""
Generate insights from RAG chunks.
Generate memory insight report for RAG chunks using memory_insight.jinja2 prompt template.
This module provides functionality to analyze chunk content and generate insights using LLM.
The memory_insight.jinja2 template produces a four-section report:
【总体概述】 → memory_insight
【行为模式】 → behavior_pattern
【关键发现】 → key_findings
【成长轨迹】 → growth_trajectory
generate_chunk_insight() returns the full raw text (stored in end_user.memory_insight).
generate_chunk_insight_sections() returns a dict with all four fields for richer storage.
"""
import asyncio
import os
import re
from collections import Counter
from typing import Any, Dict, List
from typing import Dict, List, Optional
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from pydantic import BaseModel, Field
business_logger = get_business_logger()
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
def _get_llm_client():
"""Get LLM client using db context."""
# ── LLM client helper ────────────────────────────────────────────────────────
def _get_llm_client(end_user_id: Optional[str] = None):
"""Get LLM client, preferring user-connected config with fallback to default."""
with get_db_context() as db:
try:
if end_user_id:
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.memory_config_service import MemoryConfigService
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if config_id or workspace_id:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
factory = MemoryClientFactory(db)
return factory.get_llm_client(memory_config.llm_model_id)
except Exception as e:
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
return factory.get_llm_client(DEFAULT_LLM_ID)
class ChunkInsight(BaseModel):
"""Pydantic model for chunk insight."""
insight: str = Field(..., description="对chunk内容的深度洞察分析")
# ── Domain analysis helpers (kept for building prompt inputs) ─────────────────
async def _classify_domain(chunk: str, llm_client) -> str:
"""Classify a single chunk into a domain category."""
from pydantic import BaseModel, Field
class DomainClassification(BaseModel):
"""Pydantic model for domain classification."""
domain: str = Field(
...,
description="内容所属的领域分类",
examples=["技术", "商业", "教育", "生活", "娱乐", "健康", "其他"]
)
class _Domain(BaseModel):
domain: str = Field(..., description="领域分类")
async def classify_chunk_domain(chunk: str) -> str:
"""
Classify a chunk into a specific domain.
Args:
chunk: Chunk content string
Returns:
Domain name
"""
try:
llm_client = _get_llm_client()
prompt = f"""请将以下文本内容归类到最合适的领域中。
可选领域及其关键词:
- 技术:编程、软件、硬件、算法、数据、网络、系统、开发、工程等
- 商业:市场、销售、管理、财务、投资、创业、营销、战略等
- 教育:学习、课程、培训、教学、知识、技能、考试、研究等
- 生活:日常、家庭、饮食、购物、旅行、休闲、娱乐等
- 娱乐:游戏、电影、音乐、体育、艺术、文化等
- 健康:医疗、养生、运动、心理、保健、疾病等
- 其他:无法归入以上类别的内容
文本内容: {chunk[:500]}...
请直接返回最合适的领域名称。"""
messages = [
{"role": "system", "content": "你是一个专业的文本分类助手。请仔细分析文本内容,选择最合适的领域分类。"},
{"role": "user", "content": prompt}
]
classification = await llm_client.response_structured(
messages=messages,
response_model=DomainClassification
prompt = (
"请将以下文本归类到最合适的领域(技术/商业/教育/生活/娱乐/健康/其他)。\n\n"
f"文本: {chunk[:500]}\n\n直接返回领域名称。"
)
return classification.domain if classification else "其他"
except Exception as e:
business_logger.error(f"分类chunk领域失败: {str(e)}")
result = await llm_client.response_structured(
messages=[{"role": "user", "content": prompt}],
response_model=_Domain,
)
return result.domain if result else "其他"
except Exception:
return "其他"
async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -> Dict[str, float]:
async def _build_insight_inputs(
chunks: List[str],
max_chunks: int,
end_user_id: Optional[str],
) -> Dict[str, Optional[str]]:
"""
Analyze the domain distribution of chunks.
Args:
chunks: List of chunk content strings
max_chunks: Maximum number of chunks to analyze
Returns:
Dictionary of domain -> percentage
Derive domain_distribution, active_periods, social_connections strings
to feed into the memory_insight.jinja2 template.
"""
if not chunks:
return {}
try:
# 限制分析的chunk数量
chunks_to_analyze = chunks[:max_chunks]
# 为每个chunk分类
domain_counts = Counter()
for chunk in chunks_to_analyze:
domain = await classify_chunk_domain(chunk)
domain_counts[domain] += 1
# 计算百分比
total = sum(domain_counts.values())
domain_distribution = {
domain: count / total
for domain, count in domain_counts.items()
}
# 按百分比降序排序
return dict(sorted(domain_distribution.items(), key=lambda x: x[1], reverse=True))
except Exception as e:
business_logger.error(f"分析领域分布失败: {str(e)}")
return {}
llm_client = _get_llm_client(end_user_id)
chunks_sample = chunks[:max_chunks]
# Domain distribution
domain_counts: Counter = Counter()
for chunk in chunks_sample:
domain = await _classify_domain(chunk, llm_client)
domain_counts[domain] += 1
total = sum(domain_counts.values()) or 1
domain_distribution = ", ".join(
f"{d}({c / total:.0%})" for d, c in domain_counts.most_common(3)
)
return {
"domain_distribution": domain_distribution,
"active_periods": None, # RAG模式暂无时间维度数据
"social_connections": None, # RAG模式暂无社交关联数据
}
async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str:
# ── Section parser ────────────────────────────────────────────────────────────
_ZH_SECTIONS = {
"memory_insight": r"【总体概述】(.*?)(?=【|$)",
"behavior_pattern": r"【行为模式】(.*?)(?=【|$)",
"key_findings": r"【关键发现】(.*?)(?=【|$)",
"growth_trajectory": r"【成长轨迹】(.*?)(?=【|$)",
}
_EN_SECTIONS = {
"memory_insight": r"【Overview】(.*?)(?=【|$)",
"behavior_pattern": r"【Behavior Pattern】(.*?)(?=【|$)",
"key_findings": r"【Key Findings】(.*?)(?=【|$)",
"growth_trajectory": r"【Growth Trajectory】(.*?)(?=【|$)",
}
def _parse_sections(text: str, language: str = "zh") -> Dict[str, str]:
"""Extract the four sections from the LLM output."""
patterns = _ZH_SECTIONS if language == "zh" else _EN_SECTIONS
result = {}
for key, pattern in patterns.items():
match = re.search(pattern, text, re.DOTALL)
result[key] = match.group(1).strip() if match else ""
return result
# ── Public API ────────────────────────────────────────────────────────────────
async def generate_chunk_insight(
chunks: List[str],
max_chunks: int = 15,
end_user_id: Optional[str] = None,
language: str = "zh",
) -> str:
"""
Generate insights from the given chunks.
Args:
chunks: List of chunk content strings
max_chunks: Maximum number of chunks to analyze
Returns:
A comprehensive insight report
Generate a memory insight report from RAG chunks.
Returns the full raw report text (suitable for end_user.memory_insight).
Use generate_chunk_insight_sections() when you need all four dimensions.
"""
sections = await generate_chunk_insight_sections(
chunks=chunks,
max_chunks=max_chunks,
end_user_id=end_user_id,
language=language,
)
return sections.get("memory_insight") or sections.get("_raw", "洞察生成失败")
async def generate_chunk_insight_sections(
chunks: List[str],
max_chunks: int = 15,
end_user_id: Optional[str] = None,
language: str = "zh",
) -> Dict[str, str]:
"""
Generate a four-section memory insight report from RAG chunks.
Returns a dict with keys:
memory_insight, behavior_pattern, key_findings, growth_trajectory
(plus '_raw' containing the full LLM output for debugging)
"""
if not chunks:
business_logger.warning("没有提供chunk内容用于生成洞察")
return "暂无足够数据生成洞察报告"
empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
empty["_raw"] = "暂无足够数据生成洞察报告"
return empty
try:
# 1. 分析领域分布
domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks)
# 2. 统计基本信息
total_chunks = len(chunks)
avg_length = sum(len(chunk) for chunk in chunks) / total_chunks if total_chunks > 0 else 0
# 3. 构建洞察prompt
prompt_parts = []
if domain_dist:
top_domains = ", ".join([f"{k}({v:.0%})" for k, v in list(domain_dist.items())[:3]])
prompt_parts.append(f"- 内容领域分布: {top_domains}")
prompt_parts.append(f"- 内容规模: 共{total_chunks}个知识片段,平均长度{avg_length:.0f}")
# 添加部分chunk内容作为参考
sample_chunks = chunks[:5]
sample_content = "\n".join([f"示例{i+1}: {chunk[:200]}..." for i, chunk in enumerate(sample_chunks)])
prompt_parts.append(f"\n内容示例:\n{sample_content}")
system_prompt = """你是一位专业的知识内容分析师。你的任务是根据提供的信息,生成一段简洁、有洞察力的分析报告。
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
重要规则:
1. 报告需要将所有要点流畅地串联成一个段落
2. 语言风格要专业、客观,同时易于理解
3. 不要添加任何额外的解释或标题,直接输出报告内容
4. 基于提供的数据和示例内容进行分析,不要编造信息
5. 重点关注内容的主题、特点和价值
6. 报告长度控制在150-200字
# Build template inputs from chunk analysis
inputs = await _build_insight_inputs(chunks, max_chunks, end_user_id)
例如,如果输入是:
- 内容领域分布: 技术(60%), 商业(25%), 教育(15%)
- 内容规模: 共50个知识片段平均长度320字
内容示例: [示例内容...]
rendered_prompt = await render_memory_insight_prompt(
domain_distribution=inputs["domain_distribution"],
active_periods=inputs["active_periods"],
social_connections=inputs["social_connections"],
language=language,
)
你的输出应该类似:
"该知识库主要聚焦于技术领域(60%),涵盖商业(25%)和教育(15%)相关内容。共包含50个知识片段平均每个片段约320字内容详实。从示例来看内容涉及[具体主题],体现了[特点],对[目标用户]具有较高的参考价值。"
"""
user_prompt = "\n".join(prompt_parts)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# 调用LLM生成洞察
llm_client = _get_llm_client()
messages = [{"role": "user", "content": rendered_prompt}]
llm_client = _get_llm_client(end_user_id)
response = await llm_client.chat(messages=messages)
insight = response.content.strip()
business_logger.info(f"成功生成chunk洞察分析了 {min(len(chunks), max_chunks)} 个片段")
return insight
raw_text = response.content.strip() if response and response.content else ""
sections = _parse_sections(raw_text, language=language)
sections["_raw"] = raw_text
business_logger.info(
f"成功生成chunk洞察四维度分析了 {min(len(chunks), max_chunks)} 个片段"
)
return sections
except Exception as e:
business_logger.error(f"生成chunk洞察失败: {str(e)}")
return "洞察生成失败"
if __name__ == "__main__":
# 测试代码
test_chunks = [
"Python是一种高级编程语言以其简洁的语法和强大的功能而闻名。它广泛应用于Web开发、数据分析、人工智能等领域。",
"机器学习算法可以从数据中自动学习模式,无需显式编程。常见的算法包括决策树、随机森林、神经网络等。",
"深度学习是机器学习的一个分支,使用多层神经网络来学习数据的层次化表示。它在图像识别、语音识别等任务中表现出色。",
"自然语言处理技术使计算机能够理解和生成人类语言。应用包括机器翻译、情感分析、文本摘要等。",
"数据科学结合了统计学、计算机科学和领域知识,用于从数据中提取有价值的洞察。"
]
print("开始生成chunk洞察...")
insight = asyncio.run(generate_chunk_insight(test_chunks))
print(f"\n生成的洞察:\n{insight}")
empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
empty["_raw"] = "洞察生成失败"
return empty

View File

@@ -1,11 +1,10 @@
"""
Generate summary for RAG chunks.
This module provides functionality to summarize chunk content using LLM.
Generate summary for RAG chunks using memory_summary.jinja2 prompt template.
"""
import asyncio
from typing import Any, Dict, List
import os
from typing import List, Optional
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -14,94 +13,135 @@ from pydantic import BaseModel, Field
business_logger = get_business_logger()
def _get_llm_client():
"""Get LLM client using db context."""
with get_db_context() as db:
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
class ChunkSummary(BaseModel):
"""Pydantic model for chunk summary."""
summary: str = Field(..., description="简洁的chunk内容摘要")
# ── Schema ──────────────────────────────────────────────────────────────────
class MemorySummaryStatement(BaseModel):
"""Single labelled statement extracted by memory_summary.jinja2."""
statement: str = Field(..., description="提取的陈述内容")
label: Optional[str] = Field(None, description="陈述标签")
async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str:
class MemorySummaryResponse(BaseModel):
"""
Generate a summary for the given chunks.
Structured output expected from memory_summary.jinja2.
The template asks for a JSON array of labelled statements;
we wrap it in an object so response_structured can parse it.
"""
statements: List[MemorySummaryStatement] = Field(
default_factory=list,
description="从chunk中提取的陈述列表"
)
summary: Optional[str] = Field(None, description="整体摘要文本(可选)")
# ── LLM client helper ────────────────────────────────────────────────────────
def _get_llm_client(end_user_id: Optional[str] = None):
"""Get LLM client, preferring user-connected config with fallback to default."""
with get_db_context() as db:
try:
if end_user_id:
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.memory_config_service import MemoryConfigService
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if config_id or workspace_id:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
factory = MemoryClientFactory(db)
return factory.get_llm_client(memory_config.llm_model_id)
except Exception as e:
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
factory = MemoryClientFactory(db)
return factory.get_llm_client(DEFAULT_LLM_ID)
# ── Core function ─────────────────────────────────────────────────────────────
async def generate_chunk_summary(
chunks: List[str],
max_chunks: int = 10,
end_user_id: Optional[str] = None,
language: str = "zh",
) -> str:
"""
Generate a user summary from RAG chunks using the memory_summary.jinja2 template.
The template extracts labelled statements from the chunks; we then join them
into a coherent summary string that can be stored in end_user.user_summary.
Args:
chunks: List of chunk content strings
max_chunks: Maximum number of chunks to process (default: 10)
max_chunks: Maximum number of chunks to process
end_user_id: Optional end-user ID for model selection
language: Output language ("zh" or "en")
Returns:
A concise summary of the chunks
Summary string (joined statements or fallback text)
"""
if not chunks:
business_logger.warning("没有提供chunk内容用于生成摘要")
return "暂无内容"
try:
# 限制处理的chunk数量避免token过多
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
chunks_to_process = chunks[:max_chunks]
# 合并chunk内容
combined_content = "\n\n".join([f"片段{i+1}: {chunk}" for i, chunk in enumerate(chunks_to_process)])
# 构建prompt
system_prompt = (
"你是一位专业的文本摘要助手。请基于提供的文本片段,生成简洁的摘要。要求:\n"
"- 摘要长度控制在100-150字\n"
"- 提取核心信息和关键要点;\n"
"- 使用客观、清晰的语言;\n"
"- 避免冗余和重复;\n"
"- 如果内容涉及多个主题,按重要性排序呈现。"
chunk_texts = "\n\n".join(
[f"片段{i + 1}: {chunk}" for i, chunk in enumerate(chunks_to_process)]
)
json_schema = MemorySummaryResponse.model_json_schema()
rendered_prompt = await render_memory_summary_prompt(
chunk_texts=chunk_texts,
json_schema=json_schema,
max_words=200,
language=language,
)
messages = [{"role": "user", "content": rendered_prompt}]
llm_client = _get_llm_client(end_user_id)
# Try structured output; fall back to plain chat only for LLMClientException
# (indicates the model/provider doesn't support structured output).
# All other exceptions are re-raised so config/schema errors stay visible.
try:
response: MemorySummaryResponse = await llm_client.response_structured(
messages=messages,
response_model=MemorySummaryResponse,
)
if response.summary:
summary = response.summary.strip()
elif response.statements:
summary = "".join(s.statement for s in response.statements)
else:
summary = "暂无内容"
except Exception as e:
from app.core.memory.llm_tools.llm_client import LLMClientException
if isinstance(e, LLMClientException):
business_logger.warning(
f"结构化输出不可用,降级为普通对话: end_user_id={end_user_id}, reason={e}"
)
raw = await llm_client.chat(messages=messages)
summary = raw.content.strip() if raw and raw.content else "暂无内容"
else:
business_logger.error(f"生成摘要时发生非预期异常: {e}")
raise
business_logger.info(
f"成功生成chunk摘要处理了 {len(chunks_to_process)} 个片段"
)
user_prompt = f"请为以下文本片段生成摘要:\n\n{combined_content}"
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
# 调用LLM生成摘要
llm_client = _get_llm_client()
response = await llm_client.chat(messages=messages)
summary = response.content.strip()
business_logger.info(f"成功生成chunk摘要处理了 {len(chunks_to_process)} 个片段")
return summary
except Exception as e:
business_logger.error(f"生成chunk摘要失败: {str(e)}")
return "摘要生成失败"
async def generate_chunk_summary_batch(chunks_list: List[List[str]]) -> List[str]:
"""
Generate summaries for multiple chunk lists in batch.
Args:
chunks_list: List of chunk lists
Returns:
List of summaries
"""
tasks = [generate_chunk_summary(chunks) for chunks in chunks_list]
return await asyncio.gather(*tasks)
if __name__ == "__main__":
# 测试代码
test_chunks = [
"这是第一段测试内容,讲述了关于机器学习的基础知识。",
"第二段内容介绍了深度学习的应用场景和发展历史。",
"第三段讨论了自然语言处理技术的最新进展。"
]
print("开始生成chunk摘要...")
summary = asyncio.run(generate_chunk_summary(test_chunks))
print(f"\n生成的摘要:\n{summary}")

View File

@@ -5,8 +5,9 @@ This module provides functionality to extract meaningful tags from chunk content
"""
import asyncio
import os
from collections import Counter
from typing import List, Tuple
from typing import List, Optional, Tuple
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field
business_logger = get_business_logger()
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
def _get_llm_client():
"""Get LLM client using db context."""
def _get_llm_client(end_user_id: Optional[str] = None):
"""Get LLM client, preferring user-connected config with fallback to default."""
with get_db_context() as db:
try:
if end_user_id:
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.memory_config_service import MemoryConfigService
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if config_id or workspace_id:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
factory = MemoryClientFactory(db)
return factory.get_llm_client(memory_config.llm_model_id)
except Exception as e:
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
factory = MemoryClientFactory(db)
return factory.get_llm_client(None) # Uses default LLM
return factory.get_llm_client(DEFAULT_LLM_ID)
class ExtractedTags(BaseModel):
@@ -33,7 +53,7 @@ class ExtractedPersona(BaseModel):
personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师''旅行爱好者'")
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10) -> List[Tuple[str, int]]:
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10, end_user_id: Optional[str] = None) -> List[Tuple[str, int]]:
"""
Extract meaningful tags from the given chunks.
@@ -64,7 +84,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
)
llm_client = _get_llm_client()
llm_client = _get_llm_client(end_user_id)
# 为每个chunk单独提取标签然后统计频率
all_tags = []
@@ -116,7 +136,7 @@ async def extract_chunk_tags_with_frequency(chunks: List[str], max_tags: int = 1
return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks))
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20) -> List[str]:
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20, end_user_id: Optional[str] = None) -> List[str]:
"""
Extract persona (人物形象) from the given chunks.
@@ -159,7 +179,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
]
# 调用LLM提取人物形象
llm_client = _get_llm_client()
llm_client = _get_llm_client(end_user_id)
structured_response = await llm_client.response_structured(
messages=messages,
response_model=ExtractedPersona

View File

@@ -85,6 +85,7 @@ class StorageFactory:
access_key_id=settings.S3_ACCESS_KEY_ID,
secret_access_key=settings.S3_SECRET_ACCESS_KEY,
bucket_name=settings.S3_BUCKET_NAME,
endpoint_url=settings.S3_ENDPOINT_URL,
)
else:

View File

@@ -35,6 +35,19 @@ class S3Storage(StorageBackend):
bucket_name: The name of the S3 bucket.
region: The AWS region.
"""
AMAZON_S3_ENDPOINT_MAP = {
"us-east-1": "https://s3.us-east-1.amazonaws.com", # 特殊:无地域后缀
"us-east-2": "https://s3.us-east-2.amazonaws.com",
"us-west-1": "https://s3.us-west-1.amazonaws.com",
"us-west-2": "https://s3.us-west-2.amazonaws.com",
"ap-east-1": "https://s3.ap-east-1.amazonaws.com", # 香港
"ap-southeast-1": "https://s3.ap-southeast-1.amazonaws.com", # 新加坡
"ap-southeast-2": "https://s3.ap-southeast-2.amazonaws.com", # 悉尼
"ap-northeast-1": "https://s3.ap-northeast-1.amazonaws.com", # 东京
"eu-central-1": "https://s3.eu-central-1.amazonaws.com", # 法兰克福
"eu-west-1": "https://s3.eu-west-1.amazonaws.com", # 爱尔兰
# 可根据需要扩展其他地域
}
def __init__(
self,
@@ -42,6 +55,7 @@ class S3Storage(StorageBackend):
access_key_id: str,
secret_access_key: str,
bucket_name: str,
endpoint_url: Optional[str] = None
):
"""
Initialize the S3Storage backend.
@@ -51,6 +65,7 @@ class S3Storage(StorageBackend):
access_key_id: The AWS access key ID.
secret_access_key: The AWS secret access key.
bucket_name: The name of the S3 bucket.
endpoint_url: The complete URL to use for the constructed client.
Raises:
StorageConfigError: If any required configuration is missing.
@@ -69,10 +84,19 @@ class S3Storage(StorageBackend):
self.region = region
self.bucket_name = bucket_name
if not endpoint_url:
# 优先匹配内置映射表(解决特殊地域)
if region in self.AMAZON_S3_ENDPOINT_MAP:
endpoint_url = self.AMAZON_S3_ENDPOINT_MAP[region]
# 兜底:通用拼接(适配未配置的新地域)
else:
endpoint_url = f"https://s3.{region}.amazonaws.com"
try:
self.client = boto3.client(
"s3",
region_name=region,
endpoint_url=endpoint_url,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
)

View File

@@ -4,65 +4,145 @@
# @Time : 2026/2/25 14:11
from typing import Any
from app.core.logging_config import get_logger
from app.core.workflow.adapters.base_adapter import (
PlatformMetadata,
PlatformType,
BasePlatformAdapter,
WorkflowParserResult
)
from app.schemas.workflow_schema import ExecutionConfig
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
logger = get_logger()
VALID_NODE_TYPES = frozenset(t.value for t in NodeType if t != NodeType.UNKNOWN)
class MemoryBearAdapter(BasePlatformAdapter):
NODE_TYPE_MAPPING = {}
class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
NODE_TYPE_MAPPING = {t.value: t for t in NodeType}
def __init__(self, config: dict[str, Any]):
MemoryBearConverter.__init__(self)
BasePlatformAdapter.__init__(self, config)
@property
def origin_nodes(self):
return self.config.get("workflow").get("nodes")
return self.config.get("workflow").get("nodes") or []
@property
def origin_edges(self):
return self.config.get("workflow").get("edges")
return self.config.get("workflow").get("edges") or []
@property
def origin_variables(self):
return self.config.get("workflow").get("variables")
return self.config.get("workflow").get("variables") or []
def get_metadata(self) -> PlatformMetadata:
return PlatformMetadata(
platform_name=PlatformType.MEMORY_BEAR,
version="0.2.5",
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
support_node_types=list(VALID_NODE_TYPES)
)
def map_node_type(self, platform_node_type) -> str:
return platform_node_type
def map_node_type(self, platform_node_type: str) -> NodeType:
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
@staticmethod
def _valid_nodes(node: dict[str, Any]):
if "type" not in node["data"]:
return False
def _valid_node(node: dict[str, Any]) -> bool:
if "id" not in node or "type" not in node:
return False
if not isinstance(node.get("config"), dict):
return False
return True
def validate_config(self) -> bool:
require_fields = frozenset({'app', 'workflow'})
if not all(field in self.config for field in require_fields):
return False
for node in self.origin_nodes:
if not self._valid_nodes(node):
if not self._valid_node(node):
return False
return True
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
node_id = node.get("id")
node_name = node.get("name")
try:
node_type = self.map_node_type(node["type"])
if node_type == NodeType.UNKNOWN:
self.errors.append(UnsupportNodeType(
node_id=node_id,
node_type=node["type"]
))
return None
config = node.get("config") or {}
converter = self.get_node_convert(node_type)
converter(node_id, node_name, config) # validates and appends errors if invalid
return NodeDefinition(**node)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE,
node_id=node_id,
node_name=node_name,
detail=f"convert node error - {e}"
))
logger.debug(f"MemoryBear convert node error - {e}", exc_info=True)
return None
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
try:
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
self.warnings.append(ExceptionDefineition(
type=ExceptionType.EDGE,
detail=f"edge {edge.get('id')} skipped: source or target node not found"
))
return None
return EdgeDefinition(**edge)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.EDGE,
detail=f"convert edge error - {e}"
))
logger.debug(f"MemoryBear convert edge error - {e}", exc_info=True)
return None
def _convert_variable(self, variable: dict[str, Any]) -> VariableDefinition | None:
try:
return VariableDefinition(**variable)
except Exception as e:
self.warnings.append(ExceptionDefineition(
type=ExceptionType.VARIABLE,
name=variable.get("name"),
detail=f"convert variable error - {e}"
))
logger.debug(f"MemoryBear convert variable error - {e}", exc_info=True)
return None
def parse_workflow(self) -> WorkflowParserResult:
self.nodes = self.origin_nodes
self.edges = self.origin_edges
self.conv_variables = self.origin_variables
for node in self.origin_nodes:
converted = self._convert_node(node)
if converted:
self.nodes.append(converted)
valid_node_ids = {n.id for n in self.nodes}
for edge in self.origin_edges:
converted = self._convert_edge(edge, valid_node_ids)
if converted:
self.edges.append(converted)
for variable in self.origin_variables:
converted = self._convert_variable(variable)
if converted:
self.conv_variables.append(converted)
return WorkflowParserResult(
success=True,
success=not self.errors and not self.warnings,
platform=self.get_metadata(),
execution_config=ExecutionConfig(),
origin_config=self.config,
@@ -72,5 +152,4 @@ class MemoryBearAdapter(BasePlatformAdapter):
variables=self.conv_variables,
warnings=self.warnings,
errors=self.errors,
)

View File

@@ -0,0 +1,85 @@
# -*- coding: UTF-8 -*-
from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.configs import (
StartNodeConfig,
EndNodeConfig,
LLMNodeConfig,
AgentNodeConfig,
IfElseNodeConfig,
KnowledgeRetrievalNodeConfig,
AssignerNodeConfig,
CodeNodeConfig,
HttpRequestNodeConfig,
JinjaRenderNodeConfig,
VariableAggregatorNodeConfig,
ParameterExtractorNodeConfig,
LoopNodeConfig,
IterationNodeConfig,
QuestionClassifierNodeConfig,
ToolNodeConfig,
MemoryReadNodeConfig,
MemoryWriteNodeConfig,
NoteNodeConfig,
)
from app.core.workflow.nodes.enums import NodeType
class MemoryBearConverter(BaseConverter):
errors: list
warnings: list
CONFIG_CLASS_MAP: dict[NodeType, type[BaseNodeConfig]] = {
NodeType.START: StartNodeConfig,
NodeType.END: EndNodeConfig,
NodeType.ANSWER: EndNodeConfig,
NodeType.LLM: LLMNodeConfig,
NodeType.AGENT: AgentNodeConfig,
NodeType.IF_ELSE: IfElseNodeConfig,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNodeConfig,
NodeType.ASSIGNER: AssignerNodeConfig,
NodeType.CODE: CodeNodeConfig,
NodeType.HTTP_REQUEST: HttpRequestNodeConfig,
NodeType.JINJARENDER: JinjaRenderNodeConfig,
NodeType.VAR_AGGREGATOR: VariableAggregatorNodeConfig,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNodeConfig,
NodeType.LOOP: LoopNodeConfig,
NodeType.ITERATION: IterationNodeConfig,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNodeConfig,
NodeType.TOOL: ToolNodeConfig,
NodeType.MEMORY_READ: MemoryReadNodeConfig,
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
NodeType.NOTES: NoteNodeConfig,
}
@staticmethod
def _convert_file(var):
return None
@staticmethod
def _convert_array_file(var):
return []
def config_validate(self, node_id: str, node_name: str, config_cls: type[BaseNodeConfig], value: dict):
try:
return config_cls.model_validate(value)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.CONFIG,
node_id=node_id,
node_name=node_name,
detail=str(e)
))
return None
def get_node_convert(self, node_type: NodeType):
config_cls = self.CONFIG_CLASS_MAP.get(node_type)
if not config_cls:
return lambda node_id, node_name, config: config
def validate(node_id: str, node_name: str, config: dict):
self.config_validate(node_id, node_name, config_cls, config)
return config
return validate

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from datetime import datetime
from functools import cached_property
@@ -643,15 +644,18 @@ class BaseNode(ABC):
return content.content_cache[provider]
with get_db_read() as db:
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
message = await multimodel_service.process_files(
[FileInput.model_construct(
type=content.type,
url=content.url,
transfer_method=content.transfer_method,
file_type=content.origin_file_type,
upload_file_id=content.file_id
)]
file_obj = FileInput(
type=content.type,
url=content.url,
transfer_method=content.transfer_method,
origin_file_type=content.origin_file_type,
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
)
file_obj.set_content(content.get_content())
message = await multimodel_service.process_files(
[file_obj]
)
content.set_content(file_obj.get_content())
if message:
content.content_cache[provider] = message
return message

View File

@@ -4,6 +4,7 @@ from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle
from app.core.workflow.variable.base_variable import FileObject
class HttpAuthConfig(BaseModel):
@@ -260,6 +261,11 @@ class HttpRequestNodeOutput(BaseModel):
description="Http response headers"
)
files: list[FileObject] = Field(
default_factory=list,
description="List of files",
)
output: str = Field(
default="SUCCESS",
description="HTTP response body",

View File

@@ -1,24 +1,146 @@
import asyncio
import json
import logging
import mimetypes
import uuid
import imghdr
from email.message import Message
from typing import Any, Callable, Coroutine
import httpx
# import filetypes # TODO: File support (Feature)
from httpx import AsyncClient, Response, Timeout
import magic
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.utils.file_processer import mime_to_file_type
from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.schemas import FileType, TransferMethod
logger = logging.getLogger(__file__)
class HttpResponse:
def __init__(self, response: httpx.Response):
self.response = response
self.headers = dict(response.headers)
self._is_file: bool | None = None
@property
def content_type(self) -> str:
return self.headers.get("content-type", "")
@property
def content_disposition(self) -> Message | None:
content_disposition = self.headers.get("content-disposition", "")
if content_disposition:
msg = Message()
msg["content-disposition"] = content_disposition
return msg
return None
@property
def is_file(self) -> bool:
if self._is_file is not None:
return self._is_file
content_type = self.content_type.split(";")[0].strip().lower()
parsed_content_disposition = self.content_disposition
if parsed_content_disposition:
disp_type = parsed_content_disposition.get_content_disposition()
filename = parsed_content_disposition.get_filename()
if disp_type == "attachment" or filename:
self._is_file = True
return True
if content_type.startswith("text/") and "csv" not in content_type:
return False
if content_type.startswith("application/"):
if any(
text_type in content_type
for text_type in {"json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql"}
):
self._is_file = False
return False
try:
content_sample = self.response.content[:1024]
content_sample.decode("utf-8")
text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ")
if any(marker in content_sample for marker in text_markers):
return False
except UnicodeDecodeError:
self._is_file = True
return True
main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or ""))
if main_type:
self._is_file = main_type.split("/")[0] in ("application", "image", "audio", "video")
return self._is_file
self._is_file = any(media_type in content_type for media_type in ("image/", "audio/", "video/"))
return self._is_file
@property
def is_image(self):
if self.is_file:
kind = imghdr.what(None, h=self.response.content)
return kind is not None
return False
@property
def url(self) -> str:
return str(self.response.url)
@property
def body(self) -> str:
if self.is_file:
return f"{'!' if self.is_image else ''}[file]({self.url})"
return self.response.text
@staticmethod
def get_file_type(file_bytes) -> tuple[FileType | None, str | None]:
mime = magic.from_buffer(file_bytes, mime=True)
if mime.startswith("image"):
return FileType.IMAGE, mime
elif mime.startswith("video"):
return FileType.VIDEO, mime
elif mime.startswith("audio"):
return FileType.AUDIO, mime
elif mime in ["application/pdf",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"text/plain"]:
return FileType.DOCUMENT, mime
return None, None
@property
def files(self) -> list[FileObject]:
file_type, mime_type = self.get_file_type(self.response.content)
origin_file_type = mime_to_file_type(mime_type)
if self.is_file and file_type and origin_file_type:
file_obj = FileObject(
type=file_type,
url=self.url,
transfer_method=TransferMethod.REMOTE_URL.value,
origin_file_type=origin_file_type,
file_id=None,
is_file=True
)
file_obj.set_content(self.response.content)
return [
file_obj
]
return []
class HttpRequestNode(BaseNode):
"""
HTTP Request Workflow Node.
@@ -44,6 +166,7 @@ class HttpRequestNode(BaseNode):
"body": VariableType.STRING,
"status_code": VariableType.NUMBER,
"headers": VariableType.OBJECT,
"files": VariableType.ARRAY_FILE,
"output": VariableType.STRING
}
@@ -232,10 +355,12 @@ class HttpRequestNode(BaseNode):
)
resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded")
response = HttpResponse(resp)
return HttpRequestNodeOutput(
body=resp.text,
body=response.body,
status_code=resp.status_code,
headers=resp.headers,
files=response.files
).model_dump()
except (httpx.HTTPStatusError, httpx.RequestError) as e:
logger.error(f"HTTP request node exception: {e}")

View File

@@ -0,0 +1,56 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/3/10 13:36
TRANSFORM_FILE_TYPE = {
'text/plain': 'document/text',
'text/markdown': 'document/markdown',
'text/x-markdown': 'document/x-markdown',
'application/pdf': 'document/pdf',
'application/msword': 'document/doc',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx',
'application/vnd.ms-powerpoint': 'document/ppt',
'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx',
}
ALLOWED_FILE_TYPES = [
'text/plain',
'text/markdown',
'text/x-markdown',
'application/pdf',
'application/msword',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'application/vnd.ms-powerpoint',
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
'image/jpg',
'image/jpeg',
'image/png',
'image/gif',
'image/bmp',
'image/webp',
'image/svg+xml',
'video/mp4',
'video/quicktime',
'video/x-msvideo',
'video/x-matroska',
'video/webm',
'video/x-flv',
'video/x-ms-wmv',
'audio/mpeg',
'audio/wav',
'audio/ogg',
'audio/aac',
'audio/flac',
'audio/mp4',
'audio/x-ms-wma',
'audio/x-m4a',
]
def mime_to_file_type(mime_type):
if mime_type not in ALLOWED_FILE_TYPES:
return None
return TRANSFORM_FILE_TYPE.get(mime_type, mime_type)

View File

@@ -114,9 +114,16 @@ class FileObject(BaseModel):
file_id: str | None
content_cache: dict = Field(default_factory=dict)
is_file: bool
_byte_content: bytes | None = None
def get_content(self):
return self._byte_content
def set_content(self, byte_content):
self._byte_content = byte_content
class BaseVariable(ABC):
"""Abstract base class for all workflow variables.

View File

@@ -51,6 +51,12 @@ class EndUser(Base):
growth_trajectory = Column(Text, nullable=True, comment="成长轨迹")
memory_insight_updated_at = Column(DateTime, nullable=True, comment="洞察报告最后更新时间")
# RAG存储模式专用字段 - RAG Storage Mode Fields
# storage_type = Column(String, nullable=True, default="neo4j", comment="存储模式类型: neo4j / rag")
rag_tags = Column(Text, nullable=True, comment="RAG模式下提取的标签列表JSON格式")
rag_personas = Column(Text, nullable=True, comment="RAG模式下提取的人物形象列表JSON格式")
rag_summary_updated_at = Column(DateTime, nullable=True, comment="RAG摘要/标签/人物形象最后更新时间")
# 与 App 的反向关系
app = relationship(
"App",

View File

@@ -1,10 +1,11 @@
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from typing import List
from app.models.app_model import App
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
from app.models.app_model import App
# 获取数据库专用日志器
db_logger = get_db_logger()
@@ -35,11 +36,27 @@ class AppRepository:
except Exception as e:
raise
def get_apps_by_name(self, app_name: str, app_type: str, workspace_id: uuid.UUID) -> List[App]:
try:
stmt = select(App).where(
App.name == app_name,
App.workspace_id == workspace_id,
App.type == app_type,
App.is_active.is_(True),
)
apps = self.db.execute(stmt).scalars().all()
return list(apps)
except Exception as e:
db_logger.error(f"查询名称 {app_name} 应用异常: {str(e)}")
raise
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
"""根据工作空间ID查询应用"""
repo = AppRepository(db)
return repo.get_apps_by_workspace_id(workspace_id)
def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App:
"""根据工作空间ID查询应用"""
repo = AppRepository(db)

View File

@@ -220,6 +220,88 @@ class EndUserRepository:
db_logger.error(f"更新终端用户 {end_user_id} 的用户摘要缓存时出错: {str(e)}")
raise
def update_rag_summary_tags(
self,
end_user_id: uuid.UUID,
user_summary: str,
rag_tags: str,
rag_personas: str,
) -> bool:
"""更新RAG模式下的用户摘要、标签和人物形象缓存
Args:
end_user_id: 终端用户ID
user_summary: 用户摘要文本
rag_tags: 标签列表JSON字符串
rag_personas: 人物形象列表JSON字符串
Returns:
bool: 更新成功返回True否则返回False
"""
try:
updated_count = (
self.db.query(EndUser)
.filter(EndUser.id == end_user_id)
.update(
{
EndUser.user_summary: user_summary,
EndUser.rag_tags: rag_tags,
EndUser.rag_personas: rag_personas,
EndUser.rag_summary_updated_at: datetime.datetime.now(),
},
synchronize_session=False
)
)
self.db.commit()
if updated_count > 0:
db_logger.info(f"成功更新终端用户 {end_user_id} 的RAG摘要/标签/人物形象缓存")
return True
else:
db_logger.warning(f"未找到终端用户 {end_user_id}无法更新RAG摘要缓存")
return False
except Exception as e:
self.db.rollback()
db_logger.error(f"更新终端用户 {end_user_id} 的RAG摘要缓存时出错: {str(e)}")
raise
def update_rag_insight(
self,
end_user_id: uuid.UUID,
memory_insight: str,
) -> bool:
"""更新RAG模式下的记忆洞察缓存
Args:
end_user_id: 终端用户ID
memory_insight: 洞察文本
Returns:
bool: 更新成功返回True否则返回False
"""
try:
updated_count = (
self.db.query(EndUser)
.filter(EndUser.id == end_user_id)
.update(
{
EndUser.memory_insight: memory_insight,
EndUser.memory_insight_updated_at: datetime.datetime.now(),
},
synchronize_session=False
)
)
self.db.commit()
if updated_count > 0:
db_logger.info(f"成功更新终端用户 {end_user_id} 的RAG洞察缓存")
return True
else:
db_logger.warning(f"未找到终端用户 {end_user_id}无法更新RAG洞察缓存")
return False
except Exception as e:
self.db.rollback()
db_logger.error(f"更新终端用户 {end_user_id} 的RAG洞察缓存时出错: {str(e)}")
raise
def get_all_by_workspace(self, workspace_id: uuid.UUID) -> List[EndUser]:
"""获取工作空间的所有终端用户

View File

@@ -8,6 +8,13 @@ import logging
from datetime import date, datetime, timedelta, timezone
from typing import Generator, Optional
class TimeFilterUnavailableError(Exception):
"""redis_client 不可用,无法执行时间轴筛选。
调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。
"""
import redis
from sqlalchemy import exists, not_, select
from sqlalchemy.orm import Session
@@ -113,7 +120,7 @@ class ImplicitEmotionsStorageRepository:
logger.error(f"分批获取用户ID失败: offset={offset}, error={e}")
break
def get_users_needing_refresh(self, redis_client: Optional[redis.StrictRedis], batch_size: int = 100) -> Generator[str, None, None]:
def get_users_needing_refresh(self, redis_client: redis.StrictRedis, batch_size: int = 100) -> Generator[str, None, None]:
"""分批次获取需要刷新隐性记忆/情绪数据的存量用户ID。
筛选逻辑:
@@ -123,27 +130,21 @@ class ImplicitEmotionsStorageRepository:
- 若 last_done > updated_at说明上次刷新后又有新记忆写入需要刷新
- 若 last_done <= updated_at说明已是最新跳过
如果 redis_client 为 None则降级为返回所有用户禁用时间过滤
Args:
redis_client: 同步 redis.StrictRedis 实例(连接 CELERY_BACKEND DB,如果为 None 则禁用时间过滤
redis_client: 同步 redis.StrictRedis 实例(连接 CELERY_BACKEND DB
batch_size: 每批次加载的数量
Raises:
TimeFilterUnavailableError: redis_client 为 None 时抛出,调用方可捕获并回退到 get_all_user_ids
Yields:
需要刷新的用户ID字符串
"""
from datetime import timezone
if redis_client is None:
raise TimeFilterUnavailableError("redis_client 不可用,无法执行时间轴筛选")
from redis.exceptions import RedisError
# 如果 Redis 不可用,降级为处理所有用户
if redis_client is None:
logger.warning(
"Redis 客户端不可用,时间过滤已禁用,将处理所有存量用户"
)
yield from self.get_all_user_ids(batch_size)
return
offset = 0
while True:
try:
@@ -178,16 +179,14 @@ class ImplicitEmotionsStorageRepository:
try:
CST = timezone(timedelta(hours=8))
last_done = datetime.fromisoformat(raw)
# 统一转为 CST naive 时间做比较
if last_done.tzinfo is None:
last_done = last_done.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None)
else:
# last_done 写入时已是 CST naive直接使用无需转换
if last_done.tzinfo is not None:
last_done = last_done.astimezone(CST).replace(tzinfo=None)
if updated_at is None:
yield end_user_id
continue
# updated_at 同样转为 CST naive
# updated_at 数据库存的是 UTC naive转为 CST naive 再比较
if updated_at.tzinfo is None:
updated_at_cst = updated_at.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None)
else:

View File

@@ -1,10 +1,13 @@
from sqlalchemy.orm import Session, joinedload
from app.models.user_model import User
from typing import List, Optional
import uuid
from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole
from app.schemas.workspace_schema import WorkspaceCreate, WorkspaceUpdate
from typing import List, Optional
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import select
from app.core.logging_config import get_db_logger
from app.models.user_model import User
from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole
from app.schemas.workspace_schema import WorkspaceCreate
# 获取数据库专用日志器
db_logger = get_db_logger()
@@ -19,7 +22,7 @@ class WorkspaceRepository:
def create_workspace(self, workspace_data: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace:
"""创建工作空间"""
db_logger.debug(f"创建工作空间记录: name={workspace_data.name}, tenant_id={tenant_id}")
try:
db_workspace = Workspace(
name=workspace_data.name,
@@ -34,7 +37,8 @@ class WorkspaceRepository:
)
self.db.add(db_workspace)
self.db.flush()
db_logger.info(f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}")
db_logger.info(
f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}")
return db_workspace
except Exception as e:
db_logger.error(f"创建工作空间记录失败: name={workspace_data.name} - {str(e)}")
@@ -43,7 +47,7 @@ class WorkspaceRepository:
def get_workspace_by_id(self, workspace_id: uuid.UUID) -> Optional[Workspace]:
"""根据ID获取工作空间"""
db_logger.debug(f"根据ID查询工作空间: workspace_id={workspace_id}")
try:
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
if workspace:
@@ -65,7 +69,7 @@ class WorkspaceRepository:
包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None
"""
db_logger.debug(f"查询工作空间模型配置: workspace_id={workspace_id}")
try:
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
if workspace:
@@ -89,7 +93,7 @@ class WorkspaceRepository:
def get_workspaces_by_user(self, user_id: uuid.UUID) -> List[Workspace]:
"""获取用户参与的所有工作空间(包括用户创建的和作为成员的)"""
db_logger.debug(f"查询用户参与的工作空间: user_id={user_id}")
try:
# 首先获取用户信息以获取 tenant_id
from app.models.user_model import User
@@ -97,7 +101,7 @@ class WorkspaceRepository:
if not user:
db_logger.warning(f"用户不存在: user_id={user_id}")
return []
if user.is_superuser:
# 超级用户获取对应tenantid所有工作空间
workspaces = (
@@ -109,7 +113,7 @@ class WorkspaceRepository:
)
db_logger.debug(f"超用户查询所有工作空间: user_id={user_id}, 数量={len(workspaces)}")
return workspaces
# 获取用户作为成员的工作空间
member_workspaces = (
self.db.query(Workspace)
@@ -120,7 +124,7 @@ class WorkspaceRepository:
.order_by(Workspace.updated_at.desc())
.all()
)
db_logger.debug(f"用户工作空间查询成功: user_id={user_id}, 数量={len(member_workspaces)}")
return member_workspaces
except Exception as e:
@@ -130,7 +134,7 @@ class WorkspaceRepository:
def get_workspaces_by_tenant(self, tenant_id: uuid.UUID) -> List[Workspace]:
"""获取租户的所有工作空间"""
db_logger.debug(f"查询租户的工作空间: tenant_id={tenant_id}")
try:
workspaces = (
self.db.query(Workspace)
@@ -144,14 +148,32 @@ class WorkspaceRepository:
db_logger.error(f"查询租户工作空间失败: tenant_id={tenant_id} - {str(e)}")
raise
def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember:
def get_workspaces_by_name(self, tenant_id: uuid.UUID, workspace_name: str) -> List[Workspace]:
try:
stmt = (
select(Workspace)
.where(
Workspace.tenant_id == tenant_id,
Workspace.name == workspace_name,
Workspace.is_active.is_(True)
)
)
workspaces = self.db.execute(stmt).scalars().all()
return list(workspaces)
except Exception as e:
db_logger.error(f"查询工作空间失败: workspace_name={workspace_name} - {str(e)}")
raise
def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID,
role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember:
"""添加工作空间成员"""
db_logger.debug(f"添加工作空间成员: user_id={user_id}, workspace_id={workspace_id}, role={role}")
try:
db_member = WorkspaceMember(
user_id=user_id,
workspace_id=workspace_id,
user_id=user_id,
workspace_id=workspace_id,
role=role
)
self.db.add(db_member)
@@ -165,7 +187,7 @@ class WorkspaceRepository:
def get_member(self, user_id: uuid.UUID, workspace_id: uuid.UUID) -> Optional[WorkspaceMember]:
"""获取工作空间成员"""
db_logger.debug(f"查询工作空间成员: user_id={user_id}, workspace_id={workspace_id}")
try:
member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.user_id == user_id,
@@ -173,7 +195,8 @@ class WorkspaceRepository:
WorkspaceMember.is_active.is_(True),
).first()
if member:
db_logger.debug(f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}")
db_logger.debug(
f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}")
else:
db_logger.debug(f"工作空间成员不存在: user_id={user_id}, workspace_id={workspace_id}")
return member
@@ -199,7 +222,7 @@ class WorkspaceRepository:
except Exception as e:
db_logger.error(f"查询成员列表失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_member_by_id(self, member_id: uuid.UUID) -> WorkspaceMember:
"""按成员ID获取工作空间成员并预加载 user 与 workspace 关系"""
db_logger.debug(f"查询成员的工作空间: member_id={member_id}")
@@ -214,7 +237,8 @@ class WorkspaceRepository:
.first()
)
if member:
db_logger.debug(f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}")
db_logger.debug(
f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}")
else:
db_logger.debug(f"成员不存在: member_id={member_id}")
return member
@@ -222,7 +246,8 @@ class WorkspaceRepository:
db_logger.error(f"查询成员列表失败: member_id={member_id} - {str(e)}")
raise
def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[
WorkspaceMember]:
try:
member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.workspace_id == workspace_id,
@@ -255,7 +280,7 @@ class WorkspaceRepository:
except Exception as e:
db_logger.error(f"删除成员失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}")
raise
def delete_member_by_id(self, member_id: uuid.UUID) -> Optional[WorkspaceMember]:
try:
member = self.db.query(WorkspaceMember).filter(
@@ -271,7 +296,7 @@ class WorkspaceRepository:
except Exception as e:
db_logger.error(f"删除成员失败: id={member_id} - {str(e)}")
raise
def update_member_role_by_id(self, id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
try:
member = self.db.query(WorkspaceMember).filter(
@@ -288,12 +313,18 @@ class WorkspaceRepository:
db_logger.error(f"更新成员角色失败: id={id} - {str(e)}")
raise
# 保持向后兼容的函数
def get_workspace_by_id(db: Session, workspace_id: uuid.UUID) -> Workspace | None:
repo = WorkspaceRepository(db)
return repo.get_workspace_by_id(workspace_id)
def get_workspaces_by_name(db: Session, tenant_id: uuid.UUID, name: str) -> List[Workspace]:
repo = WorkspaceRepository(db)
return repo.get_workspaces_by_name(tenant_id, name)
def get_workspaces_by_user(db: Session, user_id: uuid.UUID) -> List[Workspace]:
repo = WorkspaceRepository(db)
return repo.get_workspaces_by_user(user_id)
@@ -315,7 +346,7 @@ def create_workspace(db: Session, workspace: WorkspaceCreate, tenant_id: uuid.UU
def add_member_to_workspace(
db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole
db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole
) -> WorkspaceMember:
repo = WorkspaceRepository(db)
return repo.add_member(workspace_id, user_id, role)
@@ -325,39 +356,43 @@ def get_members_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[Works
repo = WorkspaceRepository(db)
return repo.get_members_by_workspace(workspace_id)
def get_member_by_id(db: Session, member_id: uuid.UUID) -> WorkspaceMember | None:
repo = WorkspaceRepository(db)
return repo.get_member_by_id(member_id)
def update_member_role_in_workspace(
db: Session,
user_id: uuid.UUID,
workspace_id: uuid.UUID,
role: WorkspaceRole,
db: Session,
user_id: uuid.UUID,
workspace_id: uuid.UUID,
role: WorkspaceRole,
) -> Optional[WorkspaceMember]:
repo = WorkspaceRepository(db)
return repo.update_member_role(workspace_id, user_id, role)
def remove_member_from_workspace(
db: Session,
user_id: uuid.UUID,
workspace_id: uuid.UUID,
db: Session,
user_id: uuid.UUID,
workspace_id: uuid.UUID,
) -> Optional[WorkspaceMember]:
repo = WorkspaceRepository(db)
return repo.deactivate_member(workspace_id, user_id)
def remove_member_from_workspace_by_id(
db: Session,
member_id: uuid.UUID,
db: Session,
member_id: uuid.UUID,
) -> Optional[WorkspaceMember]:
repo = WorkspaceRepository(db)
return repo.delete_member_by_id(member_id)
def update_member_role_by_id(
db: Session,
id: uuid.UUID,
role: WorkspaceRole,
db: Session,
id: uuid.UUID,
role: WorkspaceRole,
) -> Optional[WorkspaceMember]:
repo = WorkspaceRepository(db)
return repo.update_member_role_by_id(id, role)

View File

@@ -45,11 +45,19 @@ class FileInput(BaseModel):
url: Optional[str] = Field(None, description="远程URLremote_url时必填")
file_type: Optional[str] = Field(None, description="具体文件格式如image/jpg、audio/wav、document/docx、video/mp4")
_content = None
def __init__(self, **data):
if "type" in data:
data['file_type'] = data['type']
super().__init__(**data)
def set_content(self, content: bytes):
self._content = content
def get_content(self) -> bytes | None:
return self._content
@field_validator("type", mode="before")
@classmethod
def validate_type(cls, v):

View File

@@ -0,0 +1,445 @@
"""应用 DSL 导入导出服务"""
import uuid
import datetime
from typing import Optional
import yaml
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.models import AgentConfig, MultiAgentConfig
from app.models.app_model import App, AppType
from app.models.app_release_model import AppRelease
from app.models.knowledge_model import Knowledge
from app.models.models_model import ModelConfig
from app.models.tool_model import ToolConfig as ToolConfigModel
from app.models.workflow_model import WorkflowConfig
from app.services.workflow_service import WorkflowService
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
class AppDslService:
def __init__(self, db: Session):
self.db = db
# ==================== 导出 ====================
def export_dsl(self, app_id: uuid.UUID, release_id: Optional[uuid.UUID] = None) -> tuple[str, str]:
"""构建应用 DSL yaml 字符串,返回 (yaml_str, filename)"""
app = self.db.query(App).filter(App.id == app_id, App.is_active.is_(True)).first()
if not app:
raise ResourceNotFoundException("应用", str(app_id))
meta = {
"version": settings.SYSTEM_VERSION,
"platform": "MemoryBear",
"exported_at": datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S"),
}
app_meta = {
"name": app.name,
"description": app.description,
"icon": app.icon,
"icon_type": app.icon_type,
"type": app.type,
"tags": app.tags or [],
}
if release_id is not None:
return self._export_release(app, release_id, meta, app_meta)
return self._export_draft(app, meta, app_meta)
def _export_release(self, app: App, release_id: uuid.UUID, meta: dict, app_meta: dict) -> tuple[str, str]:
release = self.db.query(AppRelease).filter(
AppRelease.app_id == app.id,
AppRelease.id == release_id,
AppRelease.is_active.is_(True)
).first()
if not release:
raise ResourceNotFoundException("版本", str(release_id))
meta["release_version"] = release.version
meta["release_name"] = release.version_name
app_meta["name"] = release.name
app_meta["description"] = release.description
config_key = {
AppType.AGENT: "agent_config",
AppType.MULTI_AGENT: "multi_agent_config",
AppType.WORKFLOW: "workflow"
}.get(app.type, "config")
config_data = self._enrich_release_config(app.type, release.config or {})
dsl = {**meta, "app": app_meta, config_key: config_data}
return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{release.name}_v{release.version_name}.yaml"
def _enrich_release_config(self, app_type: str, cfg: dict) -> dict:
if app_type == AppType.AGENT:
enriched = {**cfg}
if "default_model_config_id" in cfg:
enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"])
if "knowledge_retrieval" in cfg:
enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"])
if "tools" in cfg:
enriched["tools"] = self._enrich_tools(cfg["tools"])
return enriched
if app_type == AppType.MULTI_AGENT:
enriched = {**cfg}
if "default_model_config_id" in cfg:
enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"])
if "master_agent_id" in cfg:
enriched["master_agent_ref"] = self._release_ref(cfg["master_agent_id"])
if "sub_agents" in cfg:
enriched["sub_agents"] = self._enrich_sub_agents(cfg["sub_agents"])
if "routing_rules" in cfg:
enriched["routing_rules"] = [
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
]
return enriched
return cfg
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
if app.type == AppType.WORKFLOW:
config = self.db.query(WorkflowConfig).filter(WorkflowConfig.app_id == app.id).first()
config_data = {
"variables": config.variables if config else [],
"edges": config.edges if config else [],
"nodes": config.nodes if config else [],
"execution_config": config.execution_config if config else {},
"triggers": config.triggers if config else [],
} if config else {}
dsl = {**meta, "app": app_meta, "workflow": config_data}
elif app.type == AppType.AGENT:
config = self.db.query(AgentConfig).filter(AgentConfig.app_id == app.id).first()
config_data = {
"system_prompt": config.system_prompt if config else None,
"model_parameters": self._to_dict(config.model_parameters) if config else None,
"default_model_config_ref": self._model_ref(config.default_model_config_id) if config else None,
"knowledge_retrieval": self._enrich_knowledge_retrieval(config.knowledge_retrieval) if config else None,
"memory": config.memory if config else None,
"variables": config.variables if config else [],
"tools": self._enrich_tools(config.tools) if config else [],
"skills": config.skills if config else {},
} if config else {}
dsl = {**meta, "app": app_meta, "agent_config": config_data}
elif app.type == AppType.MULTI_AGENT:
config = self.db.query(MultiAgentConfig).filter(MultiAgentConfig.app_id == app.id).first()
config_data = {
"orchestration_mode": config.orchestration_mode if config else None,
"master_agent_name": config.master_agent_name if config else None,
"model_parameters": self._to_dict(config.model_parameters) if config else None,
"default_model_config_ref": self._model_ref(config.default_model_config_id) if config else None,
"master_agent_ref": self._release_ref(config.master_agent_id) if config else None,
"sub_agents": self._enrich_sub_agents(config.sub_agents) if config else [],
"routing_rules": [
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (config.routing_rules or [])
] if config else [],
"execution_config": config.execution_config if config else {},
"aggregation_strategy": config.aggregation_strategy if config else "merge",
} if config else {}
dsl = {**meta, "app": app_meta, "multi_agent_config": config_data}
else:
raise BusinessException(f"不支持的应用类型: {app.type}", BizCode.BAD_REQUEST)
return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{app.name}.yaml"
def _to_dict(self, value):
"""将 Pydantic 对象转为普通 dict供 yaml.dump 安全序列化"""
if value is None:
return None
if hasattr(value, "model_dump"):
return value.model_dump()
return value
def _model_ref(self, model_config_id) -> Optional[dict]:
if not model_config_id:
return None
m = self.db.query(ModelConfig).filter(ModelConfig.id == model_config_id).first()
return {"id": str(model_config_id), "name": m.name, "provider": m.provider, "type": m.type} if m else {"id": str(model_config_id)}
def _kb_ref(self, kb_id) -> Optional[dict]:
if not kb_id:
return None
kb = self.db.query(Knowledge).filter(Knowledge.id == kb_id).first()
return {"id": str(kb_id), "name": kb.name} if kb else {"id": str(kb_id)}
def _tool_ref(self, tool_id) -> Optional[dict]:
if not tool_id:
return None
t = self.db.query(ToolConfigModel).filter(ToolConfigModel.id == tool_id).first()
return {"id": str(tool_id), "name": t.name, "tool_type": t.tool_type} if t else {"id": str(tool_id)}
def _enrich_knowledge_retrieval(self, kr: Optional[dict]) -> Optional[dict]:
if not kr:
return kr
kbs = [{**kb, "_ref": self._kb_ref(kb.get("kb_id"))} for kb in kr.get("knowledge_bases", [])]
return {**kr, "knowledge_bases": kbs}
def _enrich_tools(self, tools: list) -> list:
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
def _agent_ref(self, agent_id) -> Optional[dict]:
if not agent_id:
return None
a = self.db.query(App).filter(App.id == agent_id).first()
return {"id": str(agent_id), "name": a.name} if a else {"id": str(agent_id)}
def _release_ref(self, release_id) -> Optional[dict]:
if not release_id:
return None
r = self.db.query(AppRelease).filter(AppRelease.id == release_id).first()
return {"id": str(release_id), "name": r.name, "version": r.version, "app_id": str(r.app_id)} if r else {"id": str(release_id)}
def _enrich_sub_agents(self, sub_agents: list) -> list:
return [{**s, "_ref": self._agent_ref(s.get("agent_id"))} for s in (sub_agents or [])]
# ==================== 导入 ====================
def import_dsl(
self,
dsl: dict,
workspace_id: uuid.UUID,
tenant_id: uuid.UUID,
user_id: uuid.UUID,
) -> tuple[App, list[str]]:
"""解析 DSL创建应用及配置返回 (new_app, warnings)"""
app_meta = dsl.get("app", {})
app_type = app_meta.get("type")
if app_type not in (AppType.AGENT, AppType.MULTI_AGENT, AppType.WORKFLOW):
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.BAD_REQUEST)
warnings: list[str] = []
now = datetime.datetime.now()
new_app = App(
id=uuid.uuid4(),
workspace_id=workspace_id,
created_by=user_id,
name=self._unique_app_name(app_meta.get("name", "导入应用"), workspace_id, app_type),
description=app_meta.get("description"),
icon=app_meta.get("icon"),
icon_type=app_meta.get("icon_type"),
type=app_type,
visibility="private",
status="draft",
tags=app_meta.get("tags", []),
is_active=True,
created_at=now,
updated_at=now,
)
self.db.add(new_app)
self.db.flush()
if app_type == AppType.AGENT:
cfg = dsl.get("agent_config") or {}
self.db.add(AgentConfig(
id=uuid.uuid4(),
app_id=new_app.id,
system_prompt=cfg.get("system_prompt"),
model_parameters=cfg.get("model_parameters"),
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
knowledge_retrieval=self._resolve_knowledge_retrieval(cfg.get("knowledge_retrieval"), workspace_id, warnings),
memory=self._resolve_memory(cfg.get("memory"), workspace_id, warnings),
variables=cfg.get("variables", []),
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
skills=cfg.get("skills", {}),
is_active=True,
created_at=now,
updated_at=now,
))
elif app_type == AppType.MULTI_AGENT:
cfg = dsl.get("multi_agent_config") or {}
self.db.add(MultiAgentConfig(
id=uuid.uuid4(),
app_id=new_app.id,
orchestration_mode=cfg.get("orchestration_mode", "collaboration"),
master_agent_name=cfg.get("master_agent_name"),
model_parameters=cfg.get("model_parameters"),
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
master_agent_id=self._resolve_release(cfg.get("master_agent_ref"), warnings),
sub_agents=self._resolve_sub_agents(cfg.get("sub_agents", []), warnings),
routing_rules=self._resolve_routing_rules(cfg.get("routing_rules"), warnings),
execution_config=cfg.get("execution_config", {}),
aggregation_strategy=cfg.get("aggregation_strategy", "merge"),
is_active=True,
created_at=now,
updated_at=now,
))
elif app_type == AppType.WORKFLOW:
adapter = MemoryBearAdapter(dsl)
if not adapter.validate_config():
raise BusinessException("工作流配置格式无效", BizCode.BAD_REQUEST)
result = adapter.parse_workflow()
for e in result.errors:
warnings.append(f"[节点错误] {e.node_name or e.node_id}: {e.detail}")
for w in result.warnings:
warnings.append(f"[节点警告] {w.node_name or w.node_id}: {w.detail}")
wf = dsl.get("workflow") or {}
WorkflowService(self.db).create_workflow_config(
app_id=new_app.id,
nodes=[n.model_dump() for n in result.nodes],
edges=[e.model_dump() for e in result.edges],
variables=[v.model_dump() for v in result.variables],
execution_config=wf.get("execution_config", {}),
triggers=wf.get("triggers", []),
validate=False,
)
self.db.commit()
self.db.refresh(new_app)
return new_app, warnings
def _unique_app_name(self, name: str, workspace_id: uuid.UUID, app_type: AppType) -> str:
existing = {r[0] for r in self.db.query(App.name).filter(
App.workspace_id == workspace_id,
App.type == app_type,
App.is_active.is_(True)
).all()}
if name not in existing:
return name
counter = 1
while f"{name}({counter})" in existing:
counter += 1
return f"{name}({counter})"
def _resolve_model(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[uuid.UUID]:
if not ref:
return None
q = self.db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.name == ref.get("name"),
ModelConfig.is_active.is_(True)
)
if ref.get("provider"):
q = q.filter(ModelConfig.provider == ref["provider"])
if ref.get("type"):
q = q.filter(ModelConfig.type == ref["type"])
m = q.first()
if not m:
warnings.append(f"模型 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
return m.id if m else None
def _resolve_kb(self, ref: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[str]:
if not ref:
return None
kb = self.db.query(Knowledge).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.name == ref.get("name")
).first()
if not kb:
warnings.append(f"知识库 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
return str(kb.id) if kb else None
def _resolve_tool(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]:
if not ref:
return None
q = self.db.query(ToolConfigModel).filter(
ToolConfigModel.tenant_id == tenant_id,
ToolConfigModel.name == ref.get("name")
)
if ref.get("tool_type"):
q = q.filter(ToolConfigModel.tool_type == ref["tool_type"])
t = q.first()
if not t:
warnings.append(f"工具 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
return str(t.id) if t else None
def _resolve_release(self, ref: Optional[dict], warnings: list) -> Optional[uuid.UUID]:
if not ref:
return None
r = self.db.query(AppRelease).filter(
AppRelease.app_id == ref.get("app_id"),
AppRelease.version == ref.get("version"),
AppRelease.is_active.is_(True)
).first()
if not r:
warnings.append(f"主 Agent 发布版本 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
return r.id if r else None
def _resolve_sub_agents(self, sub_agents: list, warnings: list) -> list:
result = []
for s in (sub_agents or []):
ref = s.get("_ref")
entry = {k: v for k, v in s.items() if k != "_ref"}
if ref:
a = self.db.query(App).filter(App.name == ref.get("name"), App.is_active.is_(True)).first()
if not a:
warnings.append(f"子 Agent '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
entry["agent_id"] = str(a.id) if a else None
result.append(entry)
return result
def _resolve_routing_rules(self, rules: Optional[list], warnings: list) -> Optional[list]:
if rules is None:
return None
result = []
for r in rules:
ref = r.get("_ref")
entry = {k: v for k, v in r.items() if k != "_ref"}
if ref:
a = self.db.query(App).filter(App.name == ref.get("name"), App.is_active.is_(True)).first()
if not a:
warnings.append(f"路由目标 Agent '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
entry["target_agent_id"] = str(a.id) if a else None
result.append(entry)
return result
def _resolve_knowledge_retrieval(self, kr: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]:
if not kr:
return kr
resolved_kbs = []
for kb in kr.get("knowledge_bases", []):
ref = kb.get("_ref") or ({"name": kb.get("kb_id")} if kb.get("kb_id") else None)
entry = {k: v for k, v in kb.items() if k != "_ref"}
resolved_id = self._resolve_kb(ref, workspace_id, warnings)
if resolved_id is None:
continue
entry["kb_id"] = resolved_id
resolved_kbs.append(entry)
return {k: v for k, v in kr.items() if k != "knowledge_bases"} | {"knowledge_bases": resolved_kbs}
def _resolve_memory(self, memory: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]:
if not memory:
return memory
config_id = memory.get("memory_config_id") or memory.get("memory_content")
if not config_id:
return memory
try:
config_uuid = uuid.UUID(str(config_id))
except (ValueError, AttributeError):
exists = self.db.query(MemoryConfigModel).filter(
MemoryConfigModel.config_id_old == int(config_id),
MemoryConfigModel.workspace_id == workspace_id
).first()
if not exists:
warnings.append(f"记忆配置 '{config_id}' 未匹配,已置空,请导入后手动配置")
return {**memory, "memory_config_id": None, "enabled": False}
return memory
exists = self.db.query(MemoryConfigModel).filter(
MemoryConfigModel.config_id == config_uuid,
MemoryConfigModel.workspace_id == workspace_id
).first()
if not exists:
warnings.append(f"记忆配置 '{config_id}' 未匹配,已置空,请导入后手动配置")
return {**memory, "memory_config_id": None, "enabled": False}
return memory
def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list:
result = []
for t in (tools or []):
ref = t.get("_ref") or ({"name": t.get("tool_id")} if t.get("tool_id") else None)
entry = {k: v for k, v in t.items() if k != "_ref"}
resolved_id = self._resolve_tool(ref, tenant_id, warnings)
if resolved_id is None:
continue
entry["tool_id"] = resolved_id
result.append(entry)
return result

View File

@@ -33,7 +33,7 @@ from app.models import (
Workspace,
)
from app.models.app_model import AppStatus, AppType
from app.repositories.app_repository import get_apps_by_id
from app.repositories.app_repository import get_apps_by_id, AppRepository
from app.repositories.workflow_repository import WorkflowConfigRepository
from app.schemas import app_schema
from app.schemas.workflow_schema import WorkflowConfigUpdate
@@ -59,6 +59,7 @@ class AppService:
db: 数据库会话
"""
self.db = db
self.app_repo = AppRepository(self.db)
# ==================== 私有辅助方法 ====================
@@ -521,6 +522,9 @@ class AppService:
"创建应用",
extra={"app_name": data.name, "type": data.type, "workspace_id": str(workspace_id)}
)
apps = self.app_repo.get_apps_by_name(data.name, data.type, workspace_id)
if apps:
raise BusinessException(message="已存在同名应用", code=BizCode.RESOURCE_ALREADY_EXISTS)
try:
now = datetime.datetime.now()
@@ -1368,6 +1372,15 @@ class AppService:
if not agent_cfg:
raise BusinessException("Agent 应用缺少配置,无法发布", BizCode.AGENT_CONFIG_MISSING)
miss_params = []
if agent_cfg.default_model_config_id is None:
miss_params.append("model config")
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
miss_params.append("memory config")
if miss_params:
raise BusinessException(f"{', '.join(miss_params)} is required")
config = {
"system_prompt": agent_cfg.system_prompt,
"model_parameters": model_parameters_to_dict(agent_cfg.model_parameters),

View File

@@ -1165,6 +1165,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
logger.info(f"Getting connected config for end_user: {end_user_id}")
# TODO: check sources for enduserid, should be one of these three: chat, draft, apikey
# 1. 获取 end_user 及其 app_id
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user:
@@ -1179,10 +1180,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
if not app:
logger.warning(f"App not found: {app_id}")
raise ValueError(f"应用不存在: {app_id}")
if not app.current_release_id:
logger.warning(f"No current release for app: {app_id}")
raise ValueError(f"应用未发布: {app_id}")
# TODO: temp fix for draft run
# if not app.current_release_id:
# logger.warning(f"No current release for app: {app_id}")
# raise ValueError(f"应用未发布: {app_id}")
# 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填
memory_config_id_to_use = end_user.memory_config_id
@@ -1223,7 +1224,9 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
if legacy_config_id:
# 验证提取的 config_id 是否存在于数据库中
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
from app.models.memory_config_model import (
MemoryConfig as MemoryConfigModel,
)
existing_config = db.get(MemoryConfigModel, legacy_config_id)
if existing_config:
@@ -1257,7 +1260,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
result = {
"end_user_id": str(end_user_id),
"app_id": str(app_id),
"release_id": str(app.current_release_id),
"release_id": str(app.current_release_id) if app.current_release_id else None,
"memory_config_id": memory_config_id,
"workspace_id": str(app.workspace_id)
}

View File

@@ -107,29 +107,19 @@ def _validate_config_id(config_id, db: Session = None):
)
# 专门场景的内置 key 集合,直接从 SceneConfigRegistry 派生,避免重复维护
# 使用懒加载函数避免模块级循环导入
def _get_builtin_pruning_scenes() -> set:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import SceneConfigRegistry
return set(SceneConfigRegistry.get_all_scenes())
def _load_ontology_classes(db: Session, scene_id, pruning_scene: Optional[str]) -> Optional[list]:
"""当 pruning_scene 不是内置场景时,从 ontology_class 表加载类型名称列表。
"""从 ontology_class 表加载场景类型名称列表,用于注入提示词
Args:
db: 数据库会话
scene_id: 本体场景 UUID
pruning_scene: 语义剪枝场景名称
pruning_scene: 语义剪枝场景名称(保留参数,暂未使用)
Returns:
class_name 字符串列表,或 None内置场景 / 无数据时)
class_name 字符串列表,或 None无数据时
"""
if not scene_id:
return None
# 内置场景走 SceneConfigRegistry不需要注入类型列表
if pruning_scene in _get_builtin_pruning_scenes():
return None
try:
from app.repositories.ontology_class_repository import OntologyClassRepository
repo = OntologyClassRepository(db)

View File

@@ -535,7 +535,8 @@ def get_users_total_chunk_batch(
def get_rag_content(
end_user_id: str,
limit: int,
page: int,
pagesize: int,
db: Session,
current_user: User
) -> dict:
@@ -543,9 +544,9 @@ def get_rag_content(
先在documents表中查询file_name=='end_user_id'+'.txt'的id和kb_id,
然后调用/chunks/{kb_id}/{document_id}/chunks接口的相关代码获取所有内容
接着对获取的内容进行提取只要page_content的内容
最后返回数据
最后返回分页数据
"""
business_logger.info(f"获取RAG内容: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
business_logger.info(f"获取RAG内容: end_user_id={end_user_id}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}")
try:
from app.models.document_model import Document
@@ -562,63 +563,76 @@ def get_rag_content(
if not documents:
business_logger.warning(f"未找到文件: {file_name}")
return {
"total": 0,
"contents": []
"page": {
"page": page,
"pagesize": pagesize,
"total": 0,
"hasnext": False,
},
"items": []
}
business_logger.info(f"找到 {len(documents)} 个文档记录")
# 3. 获取所有chunks的page_content
all_contents = []
total_chunks = 0
# 3. 按全局偏移量计算当前页数据
# 全局偏移范围:[offset_start, offset_end)
offset_start = (page - 1) * pagesize
offset_end = offset_start + pagesize
global_total = 0 # 所有文档的 chunk 总数
page_contents = [] # 当前页的内容
for document in documents:
try:
# 获取知识库信息
kb = knowledge_repository.get_knowledge_by_id(db, document.kb_id)
if not kb:
business_logger.warning(f"知识库不存在: kb_id={document.kb_id}")
continue
# 初始化向量服务
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=kb)
# 获取该文档的所有chunks分页获取
page = 1
pagesize = 100 # 每页100条
# 先用 pagesize=1 获取该文档的 chunk 总数
doc_total, _ = vector_service.search_by_segment(
document_id=str(document.id),
query=None,
pagesize=1,
page=1,
asc=True
)
while True:
total, items = vector_service.search_by_segment(
doc_offset_start = global_total # 该文档在全局中的起始偏移
doc_offset_end = global_total + doc_total # 该文档在全局中的结束偏移
global_total += doc_total
# 当前页与该文档无交集,跳过
if doc_offset_end <= offset_start or doc_offset_start >= offset_end:
continue
# 计算需要从该文档取的局部范围
local_start = max(offset_start - doc_offset_start, 0)
local_end = min(offset_end - doc_offset_start, doc_total)
need_count = local_end - local_start
# 换算成 ES 分页参数ES page 从1开始
es_page = (local_start // pagesize) + 1
es_offset_in_page = local_start % pagesize
fetched = []
while len(fetched) < es_offset_in_page + need_count:
_, items = vector_service.search_by_segment(
document_id=str(document.id),
query=None,
pagesize=pagesize,
page=page,
page=es_page,
asc=True
)
if not items:
break
# 提取page_content
for item in items:
all_contents.append(item.page_content)
total_chunks += 1
# # 如果达到limit限制直接返回
# if limit > 0 and total_chunks >= limit:
# business_logger.info(f"已达到limit限制: {limit}")
# return {
# "total": total_chunks,
# "contents": all_contents[:limit]
# }
# 检查是否还有下一页
if page * pagesize >= total:
break
page += 1
fetched.extend(items)
es_page += 1
business_logger.info(f"文档 {document.id} 获取了 {len(items)} 个chunks")
slice_items = fetched[es_offset_in_page: es_offset_in_page + need_count]
page_contents.extend([item.page_content for item in slice_items])
except Exception as e:
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
@@ -626,11 +640,16 @@ def get_rag_content(
# 4. 返回结果
result = {
"total": total_chunks,
"contents": all_contents[:limit] if limit > 0 else all_contents
"page": {
"page": page,
"pagesize": pagesize,
"total": global_total,
"hasnext": offset_end < global_total,
},
"items": page_contents
}
business_logger.info(f"成功获取RAG内容: total={total_chunks}, 返回={len(result['contents'])}")
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)}")
return result
except Exception as e:
@@ -646,59 +665,26 @@ async def get_chunk_summary_and_tags(
current_user: User
) -> dict:
"""
获取chunk的总结、标签和人物形象
Args:
end_user_id: 宿主ID
limit: 返回的chunk数量限制
max_tags: 最大标签数量
db: 数据库会话
current_user: 当前用户
Returns:
包含summary、tags和personas的字典
纯读库从end_user表返回RAG摘要、标签和人物形象缓存。
无数据时返回空结构不触发LLM生成。
"""
business_logger.info(f"获取chunk摘要、标签和人物形象: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
try:
# 1. 获取chunk内容
rag_content = get_rag_content(end_user_id, limit, db, current_user)
chunks = rag_content.get("contents", [])
if not chunks:
business_logger.warning(f"未找到chunk内容: end_user_id={end_user_id}")
return {
"summary": "暂无内容",
"tags": [],
"personas": []
}
# 2. 导入RAG工具函数
from app.core.rag_utils import generate_chunk_summary, extract_chunk_tags, extract_chunk_persona
# 3. 并发生成摘要、提取标签和人物形象
import asyncio
summary_task = generate_chunk_summary(chunks, max_chunks=limit)
tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit)
personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit)
summary, tags_with_freq, personas = await asyncio.gather(summary_task, tags_task, personas_task)
# 4. 格式化标签数据
tags = [{"tag": tag, "frequency": freq} for tag, freq in tags_with_freq]
result = {
"summary": summary,
"tags": tags,
"personas": personas
}
business_logger.info(f"成功获取chunk摘要、{len(tags)} 个标签和 {len(personas)} 个人物形象")
return result
except Exception as e:
business_logger.error(f"获取chunk摘要、标签和人物形象失败: end_user_id={end_user_id} - {str(e)}")
raise
import json
from app.repositories.end_user_repository import EndUserRepository
business_logger.info(f"读取chunk摘要/标签/人物形象缓存: end_user_id={end_user_id}")
repo = EndUserRepository(db)
end_user = repo.get_by_id(uuid.UUID(end_user_id))
if not end_user:
return {"summary": "", "tags": [], "personas": [], "generated": False}
return {
"summary": end_user.user_summary or "",
"tags": json.loads(end_user.rag_tags) if end_user.rag_tags else [],
"personas": json.loads(end_user.rag_personas) if end_user.rag_personas else [],
"generated": bool(end_user.user_summary),
}
async def get_chunk_insight(
@@ -708,43 +694,98 @@ async def get_chunk_insight(
current_user: User
) -> dict:
"""
获取chunk的洞察分析
Args:
end_user_id: 宿主ID
limit: 返回的chunk数量限制
db: 数据库会话
current_user: 当前用户
Returns:
包含insight的字典
纯读库从end_user表返回RAG洞察缓存。
无数据时返回空结构不触发LLM生成。
"""
business_logger.info(f"获取chunk洞察: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
try:
# 1. 获取chunk内容
rag_content = get_rag_content(end_user_id, limit, db, current_user)
chunks = rag_content.get("contents", [])
if not chunks:
business_logger.warning(f"未找到chunk内容: end_user_id={end_user_id}")
return {
"insight": "暂无足够数据生成洞察报告"
}
# 2. 导入RAG工具函数
from app.core.rag_utils import generate_chunk_insight
# 3. 生成洞察
insight = await generate_chunk_insight(chunks, max_chunks=limit)
result = {
"insight": insight
}
business_logger.info("成功获取chunk洞察")
return result
except Exception as e:
business_logger.error(f"获取chunk洞察失败: end_user_id={end_user_id} - {str(e)}")
raise
from app.repositories.end_user_repository import EndUserRepository
business_logger.info(f"读取chunk洞察缓存: end_user_id={end_user_id}")
repo = EndUserRepository(db)
end_user = repo.get_by_id(uuid.UUID(end_user_id))
if not end_user:
return {"insight": "", "behavior_pattern": "", "key_findings": "", "growth_trajectory": "", "generated": False}
return {
"insight": end_user.memory_insight or "",
"behavior_pattern": end_user.behavior_pattern or "",
"key_findings": end_user.key_findings or "",
"growth_trajectory": end_user.growth_trajectory or "",
"generated": bool(end_user.memory_insight),
}
async def generate_rag_profile(
end_user_id: str,
limit: int,
max_tags: int,
db: Session,
current_user: User,
) -> dict:
"""
生产接口为RAG存储模式的end_user全量重新生成并持久化完整画像数据。
每次调用都会重新生成,覆盖已有数据。
生成内容:
- user_summary / rag_tags / rag_personas
- memory_insight / behavior_pattern / key_findings / growth_trajectory
"""
import json
import asyncio
from app.repositories.end_user_repository import EndUserRepository
from app.core.rag_utils import (
generate_chunk_summary,
extract_chunk_tags,
extract_chunk_persona,
generate_chunk_insight_sections,
)
business_logger.info(f"开始生产RAG画像: end_user_id={end_user_id}, 操作者: {current_user.username}")
repo = EndUserRepository(db)
end_user = repo.get_by_id(uuid.UUID(end_user_id))
if not end_user:
raise ValueError(f"end_user {end_user_id} 不存在")
rag_content = get_rag_content(end_user_id, page=1, pagesize=limit, db=db, current_user=current_user)
chunks = rag_content.get("items", [])
if not chunks:
business_logger.warning(f"未找到chunk内容无法生产RAG画像: end_user_id={end_user_id}")
raise ValueError("暂无chunk内容无法生成画像")
summary, tags_with_freq, personas, insight_sections = await asyncio.gather(
generate_chunk_summary(chunks, max_chunks=limit, end_user_id=end_user_id),
extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit, end_user_id=end_user_id),
extract_chunk_persona(chunks, max_personas=5, max_chunks=limit, end_user_id=end_user_id),
generate_chunk_insight_sections(chunks, max_chunks=limit, end_user_id=end_user_id),
)
tags = [{"tag": tag, "frequency": freq} for tag, freq in tags_with_freq]
repo.update_rag_summary_tags(
end_user_id=end_user.id,
user_summary=summary,
rag_tags=json.dumps(tags, ensure_ascii=False),
rag_personas=json.dumps(personas, ensure_ascii=False),
)
repo.update_memory_insight(
end_user_id=end_user.id,
memory_insight=insight_sections.get("memory_insight", ""),
behavior_pattern=insight_sections.get("behavior_pattern", ""),
key_findings=insight_sections.get("key_findings", ""),
growth_trajectory=insight_sections.get("growth_trajectory", ""),
)
business_logger.info(f"RAG画像生产完成: end_user_id={end_user_id}, tags={len(tags)}, personas={len(personas)}")
return {
"end_user_id": end_user_id,
"summary_length": len(summary),
"tags_count": len(tags),
"personas_count": len(personas),
"insight_generated": bool(insight_sections.get("memory_insight")),
}

View File

@@ -8,32 +8,42 @@
- Bedrock/Anthropic: 仅支持 base64 格式
- OpenAI: 支持 URL 和 base64 格式
"""
import uuid
import httpx
import base64
from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod
from sqlalchemy.orm import Session
from docx import Document
import io
import PyPDF2
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
import PyPDF2
import httpx
import magic
from docx import Document
from sqlalchemy.orm import Session
from app.core.logging_config import get_business_logger
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.models.file_metadata_model import FileMetadata
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.models.file_metadata_model import FileMetadata
from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.services.audio_transcription_service import AudioTranscriptionService
logger = get_business_logger()
TEXT_MIME = ['text/plain', 'text/x-markdown']
PDF_MIME = ['application/pdf']
DOC_MIME = [
'application/msword',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
]
class MultimodalFormatStrategy(ABC):
"""多模态格式策略基类"""
def __init__(self, file: FileInput):
self.file = file
@abstractmethod
async def format_image(self, url: str) -> Dict[str, Any]:
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
"""格式化图片"""
pass
@@ -43,7 +53,7 @@ class MultimodalFormatStrategy(ABC):
pass
@abstractmethod
async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]:
async def format_audio(self, file_type: str, url: str, content: bytes | None = None) -> Dict[str, Any]:
"""格式化音频"""
pass
@@ -56,7 +66,7 @@ class MultimodalFormatStrategy(ABC):
class DashScopeFormatStrategy(MultimodalFormatStrategy):
"""通义千问策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
"""通义千问图片格式:{"type": "image", "image": "url"}"""
return {
"type": "image",
@@ -70,7 +80,13 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
async def format_audio(
self,
file_type: str,
url: str,
content: bytes | None = None,
transcription: Optional[str] = None
) -> Dict[str, Any]:
"""
通义千问音频格式
- 原生支持: qwen-audio 系列
@@ -98,44 +114,37 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
class BedrockFormatStrategy(MultimodalFormatStrategy):
"""Bedrock/Anthropic 策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
"""
Bedrock/Anthropic 格式: base64 编码
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
"""
from mimetypes import guess_type
logger.info(f"下载并编码图片: {url}")
# 下载图片
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
# 获取图片数据
image_data = response.content
if content is None:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
content = response.content
self.file.set_content(content)
# 确定 media type
content_type = response.headers.get("content-type")
if content_type and content_type.startswith("image/"):
media_type = content_type
else:
guessed_type, _ = guess_type(url)
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
content_type = magic.from_buffer(content, mime=True)
media_type = content_type if content_type.startswith("image/") else "image/jpeg"
base64_data = base64.b64encode(content).decode("utf-8")
# 转换为 base64
base64_data = base64.b64encode(image_data).decode("utf-8")
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": base64_data
}
return {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": base64_data
}
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
@@ -152,7 +161,12 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
}
}
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
async def format_audio(
self, file_type: str,
url: str,
content: bytes | None = None,
transcription: Optional[str] = None
) -> Dict[str, Any]:
"""
Bedrock/Anthropic 音频格式
不支持原生音频,必须转录为文本
@@ -178,7 +192,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
class OpenAIFormatStrategy(MultimodalFormatStrategy):
"""OpenAI 策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
return {
"type": "image_url",
@@ -194,7 +208,13 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
async def format_audio(
self,
file_type: str,
url: str,
content: bytes | None = None,
transcription: Optional[str] = None
) -> Dict[str, Any]:
"""
OpenAI 音频格式
- gpt-4o-audio 系列支持原生音频(需要 base64 编码)
@@ -208,31 +228,35 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
# OpenAI 音频需要 base64 编码
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
audio_data = response.content
base64_audio = base64.b64encode(audio_data).decode('utf-8')
# 1. 优先从 file_type (MIME) 取扩展名
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
# 2. 从响应头 content-type 取
if not file_ext:
ct = response.headers.get("content-type", "")
file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None
# 3. 从 URL 路径取扩展名
if not file_ext:
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
# 4. 默认 wav
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
file_ext = "wav" if not file_ext else file_ext
audio_data = content
if content is None:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
audio_data = response.content
self.file.set_content(audio_data)
base64_audio = base64.b64encode(audio_data).decode('utf-8')
return {
"type": "input_audio",
"input_audio": {
"data": f"data:;base64,{base64_audio}",
"format": file_ext
}
# 1. 优先从 file_type (MIME) 取扩展名
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
# 2. 从响应头 content-type 取
if not file_ext:
content_type = magic.from_buffer(audio_data, mime=True)
file_ext = content_type.split('/')[-1].split(';')[0].strip() if '/' in content_type else None
# 3. 从 URL 路径取扩展名
if not file_ext:
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
# 4. 默认 wav
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
file_ext = "wav" if not file_ext else file_ext
return {
"type": "input_audio",
"input_audio": {
"data": f"data:;base64,{base64_audio}",
"format": file_ext
}
}
except Exception as e:
logger.error(f"下载音频失败: {e}")
return {
@@ -262,7 +286,8 @@ PROVIDER_STRATEGIES = {
class MultimodalService:
"""多模态文件处理服务"""
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False):
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
enable_audio_transcription: bool = False, is_omni: bool = False):
"""
初始化多模态服务
@@ -305,10 +330,9 @@ class MultimodalService:
logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
strategy_class = DashScopeFormatStrategy
strategy = strategy_class()
result = []
for idx, file in enumerate(files):
strategy = strategy_class(file)
try:
if file.type == FileType.IMAGE:
content = await self._process_image(file, strategy)
@@ -355,7 +379,7 @@ class MultimodalService:
"""
try:
url = await self.get_file_url(file)
return await strategy.format_image(url)
return await strategy.format_image(url, content=file.get_content())
except Exception as e:
logger.error(f"处理图片失败: {e}", exc_info=True)
return {
@@ -415,11 +439,13 @@ class MultimodalService:
# 远程文档暂不支持提取
return {
"type": "text",
"text": f"<document url=\"{file.url}\">\n[远程文档,暂不支持内容提取]\n</document>"
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
}
else:
# 本地文件,提取文本内容
text = await self._extract_document_text(file.upload_file_id)
server_url = settings.FILE_LOCAL_SERVER_URL
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
text = await self._extract_document_text(file)
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id
).first()
@@ -454,7 +480,7 @@ class MultimodalService:
else:
logger.warning(f"Provider {self.provider} 不支持音频转文本")
return await strategy.format_audio(file.file_type, url, transcription)
return await strategy.format_audio(file.file_type, url, file.get_content(), transcription)
except Exception as e:
logger.error(f"处理音频失败: {e}", exc_info=True)
return {
@@ -500,8 +526,6 @@ class MultimodalService:
return file.url
else:
file_id = file.upload_file_id
print("="*50)
print("file_id",file_id)
# 查询 FileMetadata
file_metadata = self.db.query(FileMetadata).filter(
@@ -519,66 +543,44 @@ class MultimodalService:
server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}"
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
async def _extract_document_text(self, file: FileInput) -> str:
"""
提取文档文本内容
Args:
file_id: 文件ID
file: 文件输入
Returns:
str: 提取的文本内容
"""
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file_id,
FileMetadata.status == "completed"
).first()
if not file_metadata:
raise BusinessException(
f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND
)
file_ext = file_metadata.file_ext.lower()
server_url = settings.FILE_LOCAL_SERVER_URL
file_url = f"{server_url}/storage/permanent/{file_id}"
if file_ext in ['.txt', '.md', '.markdown']:
return await self._read_text_file(file_url)
elif file_ext == '.pdf':
return await self._extract_pdf_text(file_url)
elif file_ext in ['.doc', '.docx']:
return await self._extract_word_text(file_url)
else:
return f"[不支持的文档格式: {file_ext}]"
@staticmethod
async def _read_text_file(file_url: str) -> str:
"""读取纯文本文件"""
try:
# 下载文件
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file_url)
response.raise_for_status()
return response.text
file_content = file.get_content()
if not file_content:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file.url)
response.raise_for_status()
file_content = response.content
file.set_content(file_content)
file_mime_type = magic.from_buffer(file_content, mime=True)
if file_mime_type in TEXT_MIME:
return file_content.decode("utf-8")
elif file_mime_type in PDF_MIME:
return await self._extract_pdf_text(file_content)
elif file_mime_type in DOC_MIME:
return await self._extract_word_text(file_content)
else:
return f"[Unsupported file type: {file_mime_type}]"
except Exception as e:
logger.error(f"读取文本文件失败: {e}")
return f"[文件读取失败: {str(e)}]"
logger.error(f"Failed to load file. - {e}")
return "[Failed to load file.]"
@staticmethod
async def _extract_pdf_text(file_url: str) -> str:
async def _extract_pdf_text(file_content: bytes) -> str:
"""提取 PDF 文本"""
try:
# 下载 PDF 文件
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file_url)
response.raise_for_status()
pdf_data = response.content
# 使用 BytesIO 读取 PDF
text_parts = []
pdf_file = io.BytesIO(pdf_data)
pdf_file = io.BytesIO(file_content)
pdf_reader = PyPDF2.PdfReader(pdf_file)
for page in pdf_reader.pages:
text_parts.append(page.extract_text())
@@ -588,17 +590,11 @@ class MultimodalService:
return f"[PDF 提取失败: {str(e)}]"
@staticmethod
async def _extract_word_text(file_url: str) -> str:
async def _extract_word_text(file_content: bytes) -> str:
"""提取 Word 文档文本"""
try:
# 下载 Word 文件
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file_url)
response.raise_for_status()
word_data = response.content
# 使用 BytesIO 读取 Word 文档
word_file = io.BytesIO(word_data)
word_file = io.BytesIO(file_content)
doc = Document(word_file)
text_parts = [paragraph.text for paragraph in doc.paragraphs]
return '\n'.join(text_parts)

View File

@@ -120,7 +120,8 @@ async def run_pilot_extraction(
"pruning_switch": memory_config.pruning_enabled,
"pruning_scene": memory_config.pruning_scene,
"pruning_threshold": memory_config.pruning_threshold,
"llm_model_id": str(memory_config.llm_model_id),
"scene_id": str(memory_config.scene_id) if memory_config.scene_id else None,
"ontology_classes": memory_config.ontology_classes,
}
config = PruningConfig(**pruning_config_dict)

View File

@@ -93,7 +93,44 @@ class ToolService:
if query.first():
raise BusinessException(f"工具名称 '{name}' 已存在", BizCode.DUPLICATE_NAME)
def create_tool(
def _check_mcp_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, config: Dict[str, Any]):
"""检查MCP工具是否重复市场来源按market_id+market_config_id+mcp_service_id判断名称无关自建按name+tool_type判断"""
from app.models.tool_model import MCPSourceChannel
source_channel = config.get("source_channel")
is_market_source = (
source_channel is not None
and source_channel != MCPSourceChannel.SELF_HOSTED
)
if is_market_source:
exists = (
self.db.query(ToolConfig)
.join(MCPToolConfig, MCPToolConfig.id == ToolConfig.id)
.filter(
ToolConfig.tenant_id == tenant_id,
ToolConfig.tool_type == tool_type,
MCPToolConfig.source_channel == source_channel,
MCPToolConfig.market_id == config.get("market_id"),
MCPToolConfig.market_config_id == config.get("market_config_id"),
MCPToolConfig.mcp_service_id == config.get("mcp_service_id"),
)
.first()
)
if exists:
raise BusinessException(f"该MCP服务已添加", BizCode.DUPLICATE_NAME)
else:
exists = (
self.db.query(ToolConfig)
.filter(
ToolConfig.name == name,
ToolConfig.tool_type == tool_type,
ToolConfig.tenant_id == tenant_id,
)
.first()
)
if exists:
raise BusinessException(f"工具 '{name}' 已存在", BizCode.DUPLICATE_NAME)
async def create_tool(
self,
name: str,
tool_type: ToolType,
@@ -106,7 +143,19 @@ class ToolService:
"""创建工具"""
if tool_type == ToolType.BUILTIN:
raise ValueError("内置工具不允许创建")
self._check_name_duplicate(name, tool_type, tenant_id)
cfg = config or {}
if tool_type == ToolType.MCP:
self._check_mcp_duplicate(name, tool_type, tenant_id, cfg)
# 创建前测试连接
test_result = await self._test_mcp_connection_by_config(cfg)
if not test_result["success"]:
raise BusinessException(f"MCP连接测试失败: {test_result['message']}", BizCode.INVALID_PARAMETER)
# 将发现的工具列表写回 config
if "available_tools" in test_result:
cfg["available_tools"] = test_result["available_tools"]
else:
self._check_name_duplicate(name, tool_type, tenant_id)
try:
# 创建基础配置
@@ -117,19 +166,22 @@ class ToolService:
tool_type=tool_type.value,
tenant_id=tenant_id,
status=ToolStatus.AVAILABLE.value,
config_data=config or {},
config_data=cfg,
tags=tags
)
self.db.add(tool_config)
self.db.flush()
# 创建类型特定配置
self._create_type_config(tool_config, config or {})
self._create_type_config(tool_config, cfg)
self.db.commit()
logger.info(f"工具创建成功: {tool_config.id}")
return str(tool_config.id)
except BusinessException:
self.db.rollback()
raise
except Exception as e:
self.db.rollback()
logger.error(f"创建工具失败: {e}")
@@ -1165,6 +1217,27 @@ class ToolService:
logger.error(f"加载内置工具配置失败: {e}")
return {}
async def _test_mcp_connection_by_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
"""根据配置参数直接测试MCP连接创建前调用无需已存在的工具记录"""
server_url = config.get("server_url")
if not server_url:
return {"success": False, "message": "server_url不能为空"}
connection_config = config.get("connection_config") or {}
try:
test_result = await self.mcp_tool_manager.test_tool_connection(server_url, connection_config)
if not test_result["success"]:
return test_result
success_flag, tools, error = await self.mcp_tool_manager.discover_tools(server_url, connection_config)
if not success_flag:
return {"success": False, "message": f"获取工具列表失败: {error}"}
tool_list = [
{tool["name"]: {"description": tool.get("description", ""), "inputSchema": tool.get("inputSchema", {})}}
for tool in tools if tool.get("name")
]
return {"success": True, "message": "MCP连接测试成功", "available_tools": tool_list}
except Exception as e:
return {"success": False, "message": f"连接测试异常: {str(e)}"}
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
"""测试MCP连接并自动同步工具列表"""
try:

View File

@@ -458,7 +458,7 @@ class WorkflowService:
type=file.type,
url=await self.multimodal_service.get_file_url(file),
transfer_method=file.transfer_method,
file_id=str(file.upload_file_id),
file_id=str(file.upload_file_id) if file.upload_file_id else None,
origin_file_type=file.file_type,
is_file=True
).model_dump()

View File

@@ -2,11 +2,11 @@ import datetime
import hashlib
import secrets
import uuid
from os import getenv
from typing import List, Optional
from sqlalchemy.orm import Session
from app.config.default_ontology_initializer import DefaultOntologyInitializer
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, PermissionDeniedException
@@ -30,17 +30,15 @@ from app.schemas.workspace_schema import (
WorkspaceModelsUpdate,
WorkspaceUpdate,
)
from app.config.default_ontology_initializer import DefaultOntologyInitializer
# 获取业务逻辑专用日志器
business_logger = get_business_logger()
from dotenv import load_dotenv
load_dotenv()
def switch_workspace(
db: Session,
workspace_id: uuid.UUID,
user: User,
db: Session,
workspace_id: uuid.UUID,
user: User,
):
"""切换工作空间"""
business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}")
@@ -60,31 +58,32 @@ def switch_workspace(
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
def delete_workspace_member(
db: Session,
workspace_id: uuid.UUID,
member_id: uuid.UUID,
user: User,
):
"""删除工作空间成员"""
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
_check_workspace_admin_permission(db, workspace_id, user)
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
if not workspace_member:
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND)
def delete_workspace_member(
db: Session,
workspace_id: uuid.UUID,
member_id: uuid.UUID,
user: User,
):
"""删除工作空间成员"""
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
_check_workspace_admin_permission(db, workspace_id, user)
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
if not workspace_member:
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND)
if workspace_member.workspace_id != workspace_id:
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_NOT_FOUND)
if workspace_member.workspace_id != workspace_id:
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}",
BizCode.WORKSPACE_NOT_FOUND)
try:
workspace_member.is_active = False
workspace_member.user.current_workspace_id = None
db.commit()
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
except Exception as e:
db.rollback()
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
try:
workspace_member.is_active = False
workspace_member.user.current_workspace_id = None
db.commit()
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
except Exception as e:
db.rollback()
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
@@ -102,19 +101,19 @@ def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
"""
business_logger.debug(f"获取用户工作空间列表: {user.username} (ID: {user.id})")
workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
# Ensure each neo4j workspace has a default memory config
for workspace in workspaces:
if workspace.storage_type == 'neo4j':
_ensure_default_memory_config(db, workspace)
_ensure_default_ontology_scenes(db, workspace)
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
return workspaces
def _create_workspace_only(
db: Session, workspace: WorkspaceCreate, owner: User
db: Session, workspace: WorkspaceCreate, owner: User
) -> Workspace:
business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}")
@@ -138,9 +137,14 @@ def create_workspace(
f"创建工作空间: {workspace.name}, 创建者: {user.username}, "
f"storage_type: {workspace.storage_type}"
)
llm=workspace.llm
embedding=workspace.embedding
rerank=workspace.rerank
if workspace_repository.get_workspaces_by_name(db=db, name=workspace.name, tenant_id=user.tenant_id):
raise BusinessException(
message="同名工作空间已存在",
code=BizCode.RESOURCE_ALREADY_EXISTS
)
llm = workspace.llm
embedding = workspace.embedding
rerank = workspace.rerank
try:
# Create the workspace without adding any members
business_logger.debug(f"创建工作空间: {workspace.name}")
@@ -159,26 +163,26 @@ def create_workspace(
success, error_msg = initializer.initialize_default_scenes(
db_workspace.id, language=language
)
if success:
business_logger.info(
f"为工作空间 {db_workspace.id} 创建默认本体场景成功 (language={language})"
)
# 获取默认场景ID优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景
# 获取默认场景ID优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景
from app.repositories.ontology_scene_repository import OntologySceneRepository
from app.config.default_ontology_config import (
ONLINE_EDUCATION_SCENE,
ONLINE_EDUCATION_SCENE,
EMOTIONAL_COMPANION_SCENE,
get_scene_name
)
scene_repo = OntologySceneRepository(db)
# 优先尝试获取教育场景
education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language)
education_scene = scene_repo.get_by_name(education_scene_name, db_workspace.id)
if education_scene:
default_scene_id = education_scene.scene_id
default_scene_name = education_scene.scene_name
@@ -189,7 +193,7 @@ def create_workspace(
# 如果教育场景不存在,尝试获取情感陪伴场景
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
companion_scene = scene_repo.get_by_name(companion_scene_name, db_workspace.id)
if companion_scene:
default_scene_id = companion_scene.scene_id
default_scene_name = companion_scene.scene_name
@@ -256,10 +260,10 @@ def create_workspace(
avatar='',
type=KnowledgeType.General,
permission_id=PermissionType.Memory,
embedding_id=uuid.UUID(getenv('KB_embedding_id')) if None else embedding,
reranker_id=uuid.UUID(getenv('KB_reranker_id')) if None else rerank,
llm_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
image2text_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
embedding_id=embedding,
reranker_id=rerank,
llm_id=llm,
image2text_id=llm,
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 256,
@@ -294,7 +298,7 @@ def create_workspace(
business_logger.info(
f"工作空间 {db_workspace.id} 及相关资源创建完成并已提交"
)
return db_workspace
except Exception as e:
@@ -304,11 +308,11 @@ def create_workspace(
def update_workspace(
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
) -> Workspace:
business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}")
db_workspace = _check_workspace_admin_permission(db,workspace_id,user)
db_workspace = _check_workspace_admin_permission(db, workspace_id, user)
try:
# 更新工作空间
business_logger.debug(f"执行工作空间更新: {db_workspace.name} (ID: {workspace_id})")
@@ -328,7 +332,7 @@ def update_workspace(
def get_workspace_members(
db: Session, workspace_id: uuid.UUID, user: User
db: Session, workspace_id: uuid.UUID, user: User
) -> List[WorkspaceMember]:
"""获取某工作空间的成员列表(关系序列化由模型关系支持)"""
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
@@ -372,7 +376,6 @@ def get_workspace_members(
return members
# ==================== 邀请相关服务方法 ====================
def _generate_invite_token() -> tuple[str, str]:
@@ -465,13 +468,14 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
def create_workspace_invite(
db: Session,
workspace_id: uuid.UUID,
invite_data: WorkspaceInviteCreate,
user: User
db: Session,
workspace_id: uuid.UUID,
invite_data: WorkspaceInviteCreate,
user: User
) -> WorkspaceInviteResponse:
"""创建工作空间邀请"""
business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
business_logger.info(
f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
try:
# 检查权限
@@ -534,17 +538,18 @@ def create_workspace_invite(
except Exception as e:
db.rollback()
business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
business_logger.error(
f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
raise
def get_workspace_invites(
db: Session,
workspace_id: uuid.UUID,
user: User,
status: Optional[InviteStatus] = None,
limit: int = 50,
offset: int = 0
db: Session,
workspace_id: uuid.UUID,
user: User,
status: Optional[InviteStatus] = None,
limit: int = 50,
offset: int = 0
) -> List[WorkspaceInviteResponse]:
"""获取工作空间邀请列表"""
business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}")
@@ -605,9 +610,9 @@ def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
def accept_workspace_invite(
db: Session,
accept_request: InviteAcceptRequest,
user: User
db: Session,
accept_request: InviteAcceptRequest,
user: User
) -> dict:
"""接受工作空间邀请"""
business_logger.info(f"接受工作空间邀请: 用户 {user.username}")
@@ -695,7 +700,8 @@ def accept_workspace_invite(
# 获取工作空间信息
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
business_logger.info(
f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
return {
"message": "Successfully joined the workspace",
@@ -710,13 +716,14 @@ def accept_workspace_invite(
def revoke_workspace_invite(
db: Session,
workspace_id: uuid.UUID,
invite_id: uuid.UUID,
user: User
db: Session,
workspace_id: uuid.UUID,
invite_id: uuid.UUID,
user: User
) -> dict:
"""撤销工作空间邀请"""
business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
business_logger.info(
f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
try:
# 检查权限
@@ -745,13 +752,14 @@ def revoke_workspace_invite(
def update_workspace_member_roles(
db: Session,
workspace_id: uuid.UUID,
updates: List[WorkspaceMemberUpdate],
user: User,
db: Session,
workspace_id: uuid.UUID,
updates: List[WorkspaceMemberUpdate],
user: User,
) -> List[WorkspaceMember]:
"""更新工作空间成员角色"""
business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
business_logger.info(
f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
# 检查管理员权限
_check_workspace_admin_permission(db, workspace_id, user)
@@ -765,7 +773,8 @@ def update_workspace_member_roles(
for upd in updates:
# 检查成员是否存在
if upd.id not in member_map:
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}",
BizCode.WORKSPACE_MEMBER_NOT_FOUND)
member = member_map[upd.id]
@@ -917,10 +926,10 @@ def get_workspace_models_configs(
def update_workspace_models_configs(
db: Session,
workspace_id: uuid.UUID,
models_update: WorkspaceModelsUpdate,
user: User,
db: Session,
workspace_id: uuid.UUID,
models_update: WorkspaceModelsUpdate,
user: User,
) -> Workspace:
"""更新工作空间的模型配置llm, embedding, rerank
@@ -968,8 +977,8 @@ def update_workspace_models_configs(
def _fill_workspace_configs_model_defaults(
db: Session,
workspace: Workspace
db: Session,
workspace: Workspace
) -> None:
"""Fill empty model fields for all memory configs in a workspace.
@@ -981,43 +990,43 @@ def _fill_workspace_configs_model_defaults(
workspace: The workspace containing default model settings
"""
from app.models.memory_config_model import MemoryConfig
# Get all configs for this workspace
configs = db.query(MemoryConfig).filter(
MemoryConfig.workspace_id == workspace.id
).all()
if not configs:
return
# Map of memory_config field -> workspace field
model_field_mappings = [
("llm_id", "llm"),
("embedding_id", "embedding"),
("rerank_id", "rerank"),
("reflection_model_id", "llm"), # reflection uses LLM
("emotion_model_id", "llm"), # emotion uses LLM
("emotion_model_id", "llm"), # emotion uses LLM
]
configs_updated = 0
for memory_config in configs:
updated_fields = []
for config_field, workspace_field in model_field_mappings:
config_value = getattr(memory_config, config_field, None)
workspace_value = getattr(workspace, workspace_field, None)
if not config_value and workspace_value:
setattr(memory_config, config_field, workspace_value)
updated_fields.append(config_field)
if updated_fields:
configs_updated += 1
business_logger.debug(
f"Updated memory config {memory_config.config_id} fields: {updated_fields}"
)
if configs_updated > 0:
try:
db.commit()
@@ -1032,14 +1041,14 @@ def _fill_workspace_configs_model_defaults(
def _create_default_memory_config(
db: Session,
workspace_id: uuid.UUID,
workspace_name: str,
llm_id: Optional[uuid.UUID] = None,
embedding_id: Optional[uuid.UUID] = None,
rerank_id: Optional[uuid.UUID] = None,
scene_id: Optional[uuid.UUID] = None,
pruning_scene_name: Optional[str] = None,
db: Session,
workspace_id: uuid.UUID,
workspace_name: str,
llm_id: Optional[uuid.UUID] = None,
embedding_id: Optional[uuid.UUID] = None,
rerank_id: Optional[uuid.UUID] = None,
scene_id: Optional[uuid.UUID] = None,
pruning_scene_name: Optional[str] = None,
) -> None:
"""Create a default memory config for a newly created workspace.
@@ -1054,9 +1063,9 @@ def _create_default_memory_config(
pruning_scene_name: Optional pruning scene name取自 ontology_scene.scene_name
"""
from app.models.memory_config_model import MemoryConfig
config_id = uuid.uuid4()
default_config = MemoryConfig(
config_id=config_id,
config_name=f"{workspace_name} 默认配置",
@@ -1070,10 +1079,10 @@ def _create_default_memory_config(
state=True, # Active by default
is_default=True, # Mark as workspace default
)
db.add(default_config)
db.flush() # 使用 flush 而不是 commit让调用者统一提交
business_logger.info(
"Created default memory config for workspace",
extra={
@@ -1084,6 +1093,7 @@ def _create_default_memory_config(
}
)
# ==================== 检查配置相关服务 ====================
def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
@@ -1096,19 +1106,19 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
workspace: The workspace to check
"""
from app.models.memory_config_model import MemoryConfig
# Check if default config exists for this workspace
existing_default = db.query(MemoryConfig).filter(
MemoryConfig.workspace_id == workspace.id,
MemoryConfig.is_default == True
).first()
if not existing_default:
# No default config exists, create one
business_logger.info(
f"Workspace {workspace.id} missing default memory config, creating one"
)
# 尝试获取默认场景ID优先教育场景其次情感陪伴场景
default_scene_id = None
try:
@@ -1118,7 +1128,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
EMOTIONAL_COMPANION_SCENE,
get_scene_name
)
scene_repo = OntologySceneRepository(db)
# 尝试中文和英文场景名称
for language in ["zh", "en"]:
@@ -1131,7 +1141,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
f"找到教育场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={education_scene_name}"
)
break
# 如果教育场景不存在,尝试情感陪伴场景
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
companion_scene = scene_repo.get_by_name(companion_scene_name, workspace.id)
@@ -1145,7 +1155,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
business_logger.warning(
f"获取默认场景失败,将创建不关联场景的记忆配置: {str(scene_error)}"
)
try:
_create_default_memory_config(
db=db,
@@ -1160,7 +1170,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
business_logger.error(
f"Failed to create default memory config for workspace {workspace.id}: {str(e)}"
)
# Fill empty model fields for ALL configs in this workspace
_fill_workspace_configs_model_defaults(db, workspace)
@@ -1209,4 +1219,3 @@ def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
business_logger.error(
f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}"
)

View File

@@ -2228,6 +2228,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
from app.repositories.implicit_emotions_storage_repository import (
ImplicitEmotionsStorageRepository,
TimeFilterUnavailableError,
)
from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.services.implicit_memory_service import ImplicitMemoryService
@@ -2256,7 +2257,14 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
_redis_client = get_sync_redis_client()
# 只处理 last_done > updated_at 的用户(有新记忆写入的用户)
for end_user_id in repo.get_users_needing_refresh(_redis_client, batch_size=100):
# Redis 不可用时回退到全量处理
try:
refresh_iter = repo.get_users_needing_refresh(_redis_client, batch_size=100)
except TimeFilterUnavailableError as e:
logger.warning(f"时间轴筛选不可用,回退到全量刷新: {e}")
refresh_iter = repo.get_all_user_ids(batch_size=100)
for end_user_id in refresh_iter:
logger.info(f"开始处理用户: {end_user_id}")
user_start_time = time.time()
@@ -2605,3 +2613,110 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
"elapsed_time": time.time() - start_time,
"task_id": self.request.id,
}
# =============================================================================
@celery_app.task(
name="app.tasks.init_interest_distribution_for_users",
bind=True,
ignore_result=True,
max_retries=0,
acks_late=False,
time_limit=3600,
soft_time_limit=3300,
)
def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
"""事件触发任务:检查指定用户列表的兴趣分布缓存,无缓存则生成并写入 Redis。
由 /dashboard/end_users 接口触发,已有缓存的用户直接跳过。
默认生成中文zh兴趣分布数据。
Args:
end_user_ids: 需要检查的用户ID列表
Returns:
包含任务执行结果的字典
"""
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
from app.services.memory_agent_service import MemoryAgentService
logger = get_logger(__name__)
logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")
initialized = 0
failed = 0
skipped = 0
language = "zh"
service = MemoryAgentService()
with get_db_context() as db:
for end_user_id in end_user_ids:
# 存在性检查:缓存有数据则跳过
cached = await InterestMemoryCache.get_interest_distribution(
end_user_id=end_user_id,
language=language,
)
if cached is not None:
skipped += 1
continue
logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成")
try:
result = await service.get_interest_distribution_by_user(
end_user_id=end_user_id,
limit=5,
language=language,
)
await InterestMemoryCache.set_interest_distribution(
end_user_id=end_user_id,
language=language,
data=result,
expire=INTEREST_CACHE_EXPIRE,
)
initialized += 1
logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功")
except Exception as e:
failed += 1
logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}")
logger.info(f"兴趣分布按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
return {
"status": "SUCCESS",
"initialized": initialized,
"skipped": skipped,
"failed": failed,
}
try:
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time
result["task_id"] = self.request.id
return result
except Exception as e:
return {
"status": "FAILURE",
"error": str(e),
"elapsed_time": time.time() - start_time,
"task_id": self.request.id,
}

View File

@@ -75,7 +75,7 @@ REFRESH_TOKEN_EXPIRE_DAYS=7
ENABLE_SINGLE_SESSION=
# File Upload
MAX_FILE_SIZE=52428800 # 50MB:10 * 1024 * 1024
MAX_FILE_SIZE=52428800 # 50MB:50 * 1024 * 1024
FILE_PATH=/files
FILE_LOCAL_SERVER_URL="http://localhost:8000/api"

View File

@@ -0,0 +1,34 @@
"""202603101453
Revision ID: fb834419b18f
Revises: 1ac07dc7366f
Create Date: 2026-03-10 14:46:48.038643
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'fb834419b18f'
down_revision: Union[str, None] = '1ac07dc7366f'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('end_users', sa.Column('rag_tags', sa.Text(), nullable=True, comment='RAG模式下提取的标签列表JSON格式'))
op.add_column('end_users', sa.Column('rag_personas', sa.Text(), nullable=True, comment='RAG模式下提取的人物形象列表JSON格式'))
op.add_column('end_users', sa.Column('rag_summary_updated_at', sa.DateTime(), nullable=True, comment='RAG摘要/标签/人物形象最后更新时间'))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('end_users', 'rag_summary_updated_at')
op.drop_column('end_users', 'rag_personas')
op.drop_column('end_users', 'rag_tags')
# ### end Alembic commands ###

View File

@@ -145,6 +145,8 @@ dependencies = [
"lxml>=4.9.0",
"httpx>=0.28.0",
"modelscope>=1.34.0",
"python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'",
"python-magic-bin>=0.4.14; sys_platform=='win32'",
]
[tool.pytest.ini_options]

View File

@@ -135,4 +135,12 @@ export const getExperienceConfig = (share_token: string) => {
'Authorization': `Bearer ${localStorage.getItem(`shareToken_${share_token}`)}`
}
})
}
// Export application
export const appExport = (app_id: string, appName: string, data?: { release_version: string }) => {
return request.getDownloadFile(`/apps/${app_id}/export`, `${appName}.yml`, data)
}
// Import application
export const appImport = (formData: FormData) => {
return request.uploadFile(`/apps/import`, formData)
}

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 14:00:06
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-04 10:58:41
* @Last Modified time: 2026-03-12 18:25:06
*/
import { request } from '@/utils/request'
import type {
@@ -118,8 +118,9 @@ export const getChunkInsight = (end_user_id: string) => {
return request.get(`/dashboard/chunk_insight`, { end_user_id })
}
// RAG User Memory - Storage content
export const getRagContent = (end_user_id: string) => {
return request.get(`/dashboard/rag_content`, { end_user_id, limit: 20 })
export const getRagContentUrl = '/dashboard/rag_content'
export const getRagContent = (end_user_id: string, page = 1, pagesize = 20) => {
return request.get(getRagContentUrl, { end_user_id, page, pagesize })
}
// Emotion distribution analysis
export const getWordCloud = (end_user_id: string) => {
@@ -224,6 +225,10 @@ export const getConversationDetail = (end_user_id: string, conversation_id: stri
export const forgetTrigger = (data: { max_merge_batch_size: number; min_days_since_access: number; end_user_id: string;}) => {
return request.post(`/memory/forget-memory/trigger`, data)
}
// RAG type - Refresh RAG user summary and memory insight
export const generateRagProfile = (end_user_id: string) => {
return request.post(`/dashboard/generate_rag_profile`, { end_user_id })
}
/*************** end User Memory APIs ******************************/
/****************** Memory Management APIs *******************************/

View File

@@ -41,12 +41,12 @@ export const deleteCompositeModel = (model_id: string) => {
return request.delete(`/models/composite/${model_id}`)
}
// Create API keys for all matching models by provider
export const updateProviderApiKeys = (data: KeyConfigModalForm) => {
return request.post('/models/provider/apikeys', data)
export const updateProviderApiKeys = (data: KeyConfigModalForm, signal?: AbortSignal) => {
return request.post('/models/provider/apikeys', data, { signal })
}
// Create model API key
export const addModelApiKey = (model_id: string, data: MultiKeyForm) => {
return request.post(`/models/${model_id}/apikeys`, data)
export const addModelApiKey = (model_id: string, data: MultiKeyForm, signal?: AbortSignal) => {
return request.post(`/models/${model_id}/apikeys`, data, { signal })
}
// Delete model API key
export const deleteModelApiKey = (api_key_id: string) => {
@@ -65,10 +65,10 @@ export const addModelPlaza = (model_base_id: string) => {
return request.post(`/models/model_plaza/${model_base_id}/add`)
}
// Create custom model
export const addCustomModel = (data: CustomModelForm) => {
return request.post('/models', data)
export const addCustomModel = (data: CustomModelForm, signal?: AbortSignal) => {
return request.post('/models', data, { signal })
}
// Update custom model
export const updateCustomModel = (model_base_id: string, data: CustomModelForm) => {
return request.put(`/models/${model_base_id}`, data)
export const updateCustomModel = (model_base_id: string, data: CustomModelForm, signal?: AbortSignal) => {
return request.put(`/models/${model_base_id}`, data, { signal })
}

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-02 15:18:19
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 15:44:42
* @Last Modified time: 2026-03-12 18:36:19
*/
/**
* PageScrollList Component
@@ -60,8 +60,8 @@ interface PageScrollListProps<T, Q = Record<string, unknown>> {
/** Infinite scroll list component with pagination support */
const PageScrollList = forwardRef(<T, Q = Record<string, unknown>>({
renderItem,
query,
renderItem,
query,
url,
column = 4,
className = '',
@@ -69,68 +69,70 @@ const PageScrollList = forwardRef(<T, Q = Record<string, unknown>>({
}: PageScrollListProps<T, Q>, ref: React.Ref<PageScrollListRef>) => {
/** Expose refresh method to parent component */
useImperativeHandle(ref, () => ({
refresh,
refresh: () => {
pageRef.current = 1;
loadingRef.current = false;
setHasMore(true);
setData([]);
loadMoreData(true);
},
}));
const [loading, setLoading] = useState(false);
const [data, setData] = useState<T[]>([]);
const [page, setPage] = useState(1);
const [hasMore, setHasMore] = useState(true);
const scrollRef = useRef<HTMLDivElement>(null);
const pageRef = useRef(1);
const loadingRef = useRef(false);
const hasMoreRef = useRef(true);
/** Load more data from API with pagination */
const loadMoreData = (flag?: boolean) => {
if (!flag && (loading || !hasMore)) {
return;
}
const loadMoreData = (reset?: boolean) => {
if (loadingRef.current || (!reset && !hasMoreRef.current)) return;
loadingRef.current = true;
setLoading(true);
const currentPage = reset ? 1 : pageRef.current;
request.get(url, {
page: page,
page: currentPage,
pagesize: PAGE_SIZE,
...(query||{}),
...(query || {}),
})
.then((res) => {
const response = res as ApiResponse<T>;
const results = Array.isArray(response.items) ? response.items : Array.isArray(response) ? response as T[] : [];
// Replace data if flag is true, otherwise append
if (flag) {
setData(results);
} else {
setData(data.concat(results));
}
setPage(response.page.page + 1);
pageRef.current = response.page.page + 1;
setData(prev => reset ? results : [...prev, ...results]);
hasMoreRef.current = response.page?.hasnext;
setHasMore(response.page?.hasnext);
setLoading(false);
console.log(`${results.length} more items loaded!`);
})
.catch(() => {
setLoading(false);
hasMoreRef.current = false;
setHasMore(false);
console.error('Failed to load data');
})
.finally(() => {
loadingRef.current = false;
setLoading(false);
// 内容不足以填满容器时,主动继续加载
setTimeout(() => {
const el = scrollRef.current;
console.log(el, el?.scrollHeight, el?.clientHeight, hasMoreRef.current)
if (el && hasMoreRef.current && el.scrollHeight <= el.clientHeight) {
loadMoreData();
}
}, 0);
});
};
/** Reset list to initial state and reload data */
const refresh = () => {
setPage(1);
/** Reset and reload when query parameters change */
const queryKey = JSON.stringify(query);
useEffect(() => {
pageRef.current = 1;
loadingRef.current = false;
hasMoreRef.current = true;
setHasMore(true);
setData([]);
}
loadMoreData(true);
}, [queryKey]);
/** Refresh when query parameters change */
useEffect(() => {
refresh()
}, [query]);
/** Load initial data when list is reset */
useEffect(() => {
if (page === 1 && hasMore && data.length === 0) {
loadMoreData(true);
}
}, [page, hasMore, data])
return (
<>
<div
@@ -140,7 +142,7 @@ const PageScrollList = forwardRef(<T, Q = Record<string, unknown>>({
>
<InfiniteScroll
dataLength={data.length}
next={loadMoreData}
next={() => loadMoreData()}
hasMore={hasMore}
loader={loading && needLoading ? <PageLoading /> : false}
// endMessage={<Divider plain>It is all, nothing more 🤐</Divider>}

View File

@@ -1370,7 +1370,7 @@ export const en = {
gotoList: 'Return to Application List',
gotoDetail: 'View Details',
dify: 'Dify',
pleaseUploadFile: 'Please upload workflow file',
pleaseUploadFile: 'Please upload file',
},
userMemory: {
userMemory: 'User Memory',
@@ -1820,6 +1820,10 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
marketRefresh: 'Refresh',
marketConfigBtn: 'Configure',
marketConfigConnection: 'Configure Connection',
marketNoData: 'No Data',
marketNoDataDesc: 'This market currently has no available services',
marketNoSearchResult: 'No Search Results',
marketNoSearchResultDesc: 'No matching services found, please try other keywords',
marketNoServices: 'No MCP Services Available',
marketNotConnected: 'Not Connected to This Market',
marketNoServicesDesc: 'This market currently has no available services',
@@ -2040,6 +2044,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
self_optimization: 'Self Optimization',
process_evolution: 'Process Evolution',
unknown: 'Unknown Node',
notes: 'Sticky Note',
clickToConfigure: 'Click to configure node parameters',
nodeProperties: 'Node Properties',
@@ -2227,6 +2232,12 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
output_variables: 'Output Variables',
refreshTip: 'Sync function signature to code',
},
notes: {
showAuth: 'Show Author',
enterLink: 'Enter Link URL',
placeholder: 'Enter note...',
removeLink: 'Remove Link',
},
name: 'Key',
type: 'Type',
value: 'Value',

View File

@@ -754,7 +754,7 @@ export const zh = {
gotoList: '返回应用列表',
gotoDetail: '查看详情',
dify: 'Dify',
pleaseUploadFile: '请上传工作流文件',
pleaseUploadFile: '请上传文件',
},
table: {
totalRecords: '共 {{total}} 条记录'
@@ -1816,6 +1816,10 @@ export const zh = {
marketRefresh: '刷新',
marketConfigBtn: '配置',
marketConfigConnection: '配置连接',
marketNoData: '暂无数据',
marketNoDataDesc: '该市场暂时没有可用的服务',
marketNoSearchResult: '无搜索结果',
marketNoSearchResultDesc: '未找到匹配的服务,请尝试其他关键词',
marketNoServices: '暂无可用的 MCP 服务',
marketNotConnected: '尚未连接此市场',
marketNoServicesDesc: '该市场暂时没有可用的服务',
@@ -2036,6 +2040,7 @@ export const zh = {
self_optimization: '自我优化',
process_evolution: '流程演化',
unknown: '未知节点',
notes: '便签',
clickToConfigure: '点击配置节点参数',
nodeProperties: '节点属性',
@@ -2226,6 +2231,12 @@ export const zh = {
unknown: {
replaceNodeType: '替换节点'
},
notes: {
showAuth: '显示作者',
enterLink: '输入链接 URL',
placeholder: '输入注释...',
removeLink: '取消链接',
},
name: '键',
type: '类型',
value: '值',

View File

@@ -347,6 +347,23 @@ export const request = {
document.body.removeChild(link);
callback?.()
});
},
getDownloadFile(url: string, fileName: string, data?: unknown, callback?: () => void) {
service.get(url, {
params: paramFilter(data as Record<string, string | number | boolean | ObjectWithPush | null | undefined>),
responseType: "blob",
})
.then(res => {
const link = document.createElement("a");
const blob = new Blob([res as unknown as BlobPart]);
link.style.display = "none";
link.href = URL.createObjectURL(blob);
link.setAttribute("download", decodeURI(fileName || fileName));
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
callback?.()
});
}
};

View File

@@ -11,7 +11,7 @@ import { Button, Space, Input, Form, App } from 'antd';
import Tag, { type TagProps } from './components/Tag'
import RbCard from '@/components/RbCard/Card'
import { getReleaseList, rollbackRelease } from '@/api/application'
import { getReleaseList, rollbackRelease, appExport } from '@/api/application'
import ReleaseModal from './components/ReleaseModal'
import ReleaseShareModal from './components/ReleaseShareModal'
import type { Release, ReleaseModalRef, ReleaseShareModalRef } from './types'
@@ -67,6 +67,9 @@ const ReleasePage: FC<{data: Application; refresh: () => void}> = ({data, refres
message.success(t('common.operateSuccess'))
})
}
const handleExport = () => {
appExport(data.id, data.name)
}
return (
<div className="rb:flex rb:h-[calc(100vh-64px)]">
<div className="rb:h-full rb:overflow-y-auto rb:w-108 rb:flex-[0_0_auto] rb:border-r rb:border-[#DFE4ED] rb:p-4">
@@ -123,7 +126,7 @@ const ReleasePage: FC<{data: Application; refresh: () => void}> = ({data, refres
<Space size={10}>
{selectedVersion && <>
{/* <Button>{t('application.exportDSLFile')}</Button> */}
{data?.type !== 'multi_agent' && <Button onClick={handleExport}>{t('common.export')}</Button>}
{data.current_release_id !== selectedVersion.id && <Button onClick={handleRollback}>{t('application.willRollToThisVersion')}</Button>}
<Button type="primary" ghost onClick={() => releaseShareModalRef.current?.handleOpen()}>{t('application.share')}</Button>
</>}

View File

@@ -19,9 +19,8 @@ import deleteIcon from '@/assets/images/delete_hover.svg'
import type { Application, ApplicationModalRef } from '@/views/ApplicationManagement/types';
import ApplicationModal from '@/views/ApplicationManagement/components/ApplicationModal'
import type { CopyModalRef, AgentRef, ClusterRef, WorkflowRef } from '../types'
import { deleteApplication } from '@/api/application'
import { deleteApplication, appExport } from '@/api/application'
import CopyModal from './CopyModal'
import { exportToYaml } from '@/utils/yamlExport';
const { Header } = Layout;
@@ -85,16 +84,16 @@ const ConfigHeader: FC<ConfigHeaderProps> = ({
* Handle menu item click
*/
const handleClick: MenuProps['onClick'] = ({ key }) => {
if (!application) return
switch (key) {
case 'edit':
applicationModalRef.current?.handleOpen(application as Application)
applicationModalRef.current?.handleOpen(application)
break;
case 'copy':
copyModalRef.current?.handleOpen()
break;
case 'export':
console.log('export', workflowRef?.current?.config)
exportToYaml(workflowRef?.current?.config, application?.name ?`${application?.name}.yml`: undefined)
appExport(application.id, application.name)
break;
case 'delete':
handleDelete()
@@ -153,7 +152,7 @@ const ConfigHeader: FC<ConfigHeaderProps> = ({
* Format dropdown menu items
*/
const formatMenuItems = useMemo(() => {
const items = (application?.type === 'workflow' ? ['edit', 'copy', 'export', 'delete'] : ['edit', 'copy', 'delete']).map(key => ({
const items = (application?.type !== 'multi_agent' ? ['edit', 'copy', 'export', 'delete'] : ['edit', 'copy', 'delete']).map(key => ({
key,
icon: <img src={menuIcons[key]} className="rb:w-4 rb:h-4 rb:mr-2" />,
label: t(`common.${key}`),

View File

@@ -0,0 +1,256 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-28 14:08:14
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-12 17:19:46
*/
/**
* UploadModal Component
*
* This component provides a modal for uploading workflow files with a multi-step process:
* 1. Upload - Select platform and file
* 2. Complex - Show warnings and errors if any
* 3. SureInfo - Confirm and edit workflow information
* 4. Completed - Show success message and options
*/
import { forwardRef, useImperativeHandle, useState, useMemo } from 'react';
import { Form, Steps, Flex, Alert, Button, Result, message } from 'antd';
import { useTranslation } from 'react-i18next';
import type { Application, UploadModalRef } from '../types'
import RbModal from '@/components/RbModal'
import UploadFiles from '@/components/Upload/UploadFiles'
import { appImport } from '@/api/application'
/**
* Props for UploadModal component
*/
interface UploadModalProps {
/** Function to refresh the parent component after workflow import */
refresh: () => void;
}
/**
* Steps definition for the upload process
*/
const steps = [
'upload', // Step 1: File upload
'complex', // Step 2: Error/warning display
'completed' // Step 4: Success message
]
/**
* UploadModal component
*
* @param {UploadModalProps} props - Component props
* @param {React.Ref<UploadModalRef>} ref - Ref for imperative methods
*/
const UploadModal = forwardRef<UploadModalRef, UploadModalProps>(({
refresh
}, ref) => {
const { t } = useTranslation();
// State management
const [visible, setVisible] = useState(false); // Modal visibility
const [form] = Form.useForm<{ file: File[] }>(); // Form instance
const [loading, setLoading] = useState(false); // Loading state
const [current, setCurrent] = useState<number>(0); // Current step
const [appId, setAppId] = useState<string | null>(null); // Imported application ID
const [warnings, setWarnings] = useState<string[]>([])
/**
* Handle modal close
* Resets all states and form fields
*/
const handleClose = () => {
refresh()
setVisible(false);
form.resetFields();
setCurrent(0);
setAppId(null);
setLoading(false);
setWarnings([])
};
/**
* Handle modal open
* Resets form fields and shows modal
*/
const handleOpen = () => {
form.resetFields();
setVisible(true);
};
/**
* Handle save/submit action
* Processes different logic based on current step
*/
const handleSave = () => {
const values = form.getFieldsValue();
switch(current) {
case 0: // Step 1: Upload file
if (!values.file || values.file.length === 0) {
message.warning(t('application.pleaseUploadFile'));
return;
}
const formData = new FormData();
formData.append('file', values.file[0]);
setLoading(true)
// Call import API
appImport(formData)
.then(res => {
const { warnings, app } = res as { warnings: string[]; app: Application };
setAppId(app?.id)
if (warnings.length) {
setCurrent(1)
setWarnings(warnings)
} else {
setCurrent(2)
}
})
.finally(() => setLoading(false));
break;
case 2:
break;
}
};
// Expose methods to parent component via ref
useImperativeHandle(ref, () => ({
handleOpen,
handleClose
}));
/**
* Handle navigation after successful import
* @param {string} type - Navigation type ('detail' or 'list')
*/
const handleJump = (type: string) => {
handleClose();
refresh();
setTimeout(() => {
switch (type) {
case 'detail':
// Open application detail page in new tab
window.open(`/#/application/config/${appId}`, '_blank');
break;
}
}, 100)
};
/**
* Generate modal footer based on current step
*/
const getFooter = useMemo(() => {
switch (current) {
case 0: // Step 1: Upload
return [
<Button key="back" onClick={handleClose}>
{t('common.cancel')}
</Button>,
<Button
key="confirm"
type="primary"
loading={loading}
onClick={handleSave}
>
{t('common.confirm')}
</Button>
];
case 1:
return [
<Button key="back" onClick={() => handleJump('list')}>
{t('application.gotoList')}
</Button>,
<Button
key="submit"
type="primary"
loading={loading}
onClick={() => handleJump('detail')}
>
{t('application.gotoDetail')}
</Button>
]
default:
return null;
}
}, [current, loading]);
return (
<RbModal
title={t('application.import')}
open={visible}
onCancel={handleClose}
okText={t('common.confirm')}
onOk={handleSave}
footer={getFooter}
>
{/* Steps indicator */}
<div className='rb:p-3 rb:bg-[#FBFDFF] rb:rounded-lg rb:border rb:border-[#DFE4ED] rb:mb-3'>
<Steps
labelPlacement="vertical"
size="small"
current={current}
items={steps.map(key => ({ title: t(`application.${key}`) }))}
/>
</div>
{current === 0 &&
<Form
form={form}
layout="vertical"
>
<Form.Item
name="file"
valuePropName="fileList"
noStyle
>
<UploadFiles
isAutoUpload={false}
isCanDrag={true}
fileSize={100}
maxCount={1}
fileType={['yml']}
/>
</Form.Item>
</Form>
}
{/* Step 2: Error/warning display */}
{current === 1 &&
<Flex vertical gap={12}>
{warnings.map((vo, index) => (
<Alert
key={index}
message={<div>{vo}</div>}
type="warning"
showIcon
/>
))}
</Flex>
}
{current === 2 &&
<Result
status="success"
title={t('application.importSuccess')}
subTitle={t('application.importSuccessDesc')}
extra={[
<Button key="back" onClick={() => handleJump('list')}>
{t('application.gotoList')}
</Button>,
<Button
key="submit"
type="primary"
loading={loading}
onClick={() => handleJump('detail')}
>
{t('application.gotoDetail')}
</Button>
]}
/>
}
</RbModal>
);
});
export default UploadModal;

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-28 14:08:14
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-06 12:05:46
* @Last Modified time: 2026-03-12 17:19:33
*/
/**
* UploadWorkflowModal Component
@@ -72,6 +72,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
setFirstFormData(null);
setAppId(null);
setLoading(false);
refresh()
};
/**

View File

@@ -25,6 +25,7 @@ import { getApplicationListUrl, deleteApplication } from '@/api/application'
import PageScrollList, { type PageScrollListRef } from '@/components/PageScrollList'
import { formatDateTime } from '@/utils/format';
import UploadWorkflowModal from './components/UploadWorkflowModal'
import UploadModal from './components/UploadModal'
/**
* Application management main component
@@ -37,6 +38,7 @@ const ApplicationManagement: React.FC = () => {
const applicationModalRef = useRef<ApplicationModalRef>(null);
const scrollListRef = useRef<PageScrollListRef>(null)
const uploadWorkflowModalRef = useRef<UploadWorkflowModalRef>(null);
const uploadModalRef = useRef<UploadWorkflowModalRef>(null);
useEffect(() => {
// Convert URLSearchParams to a plain object for easier access
@@ -91,6 +93,8 @@ const ApplicationManagement: React.FC = () => {
case 'thirdParty':
handleImport()
break;
case 'import':
uploadModalRef.current?.handleOpen()
}
}
return (
@@ -121,6 +125,7 @@ const ApplicationManagement: React.FC = () => {
<Dropdown
menu={{ items: [
{ key: 'thirdParty', label: t('application.importWorkflow') },
{ key: 'import', label: t('application.import') },
], onClick: handleClick }}
placement="bottomRight"
>
@@ -186,6 +191,10 @@ const ApplicationManagement: React.FC = () => {
ref={uploadWorkflowModalRef}
refresh={refresh}
/>
<UploadModal
ref={uploadModalRef}
refresh={refresh}
/>
</>
);
};

View File

@@ -233,4 +233,12 @@ export interface UploadData extends WorkflowConfig {
export interface UploadWorkflowModalRef {
/** Open the upload workflow modal */
handleOpen: () => void;
}
/**
* Upload app modal ref interface
*/
export interface UploadModalRef {
/** Open the upload workflow modal */
handleOpen: () => void;
}

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 16:49:28
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-04 11:31:43
* @Last Modified time: 2026-03-11 15:08:24
*/
/**
* Custom Model Modal
@@ -11,7 +11,7 @@
*/
import { forwardRef, useEffect, useImperativeHandle, useState } from 'react';
import { Form, Input, App, Checkbox } from 'antd';
import { Form, Input, App, Checkbox, Button } from 'antd';
import { useTranslation } from 'react-i18next';
import type { CustomModelForm, ModelListItem, CustomModelModalRef, CustomModelModalProps } from '../types';
@@ -35,6 +35,7 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
const [isEdit, setIsEdit] = useState(false);
const [form] = Form.useForm<CustomModelForm>();
const [loading, setLoading] = useState(false)
const [abortController, setAbortController] = useState<AbortController | null>(null)
const modelType = Form.useWatch(['type'], form);
const isOmni = Form.useWatch(['is_omni'], form);
@@ -46,6 +47,8 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
/** Close modal and reset state */
const handleClose = () => {
abortController?.abort()
setAbortController(null)
setModel({} as ModelListItem);
form.resetFields();
setLoading(false)
@@ -73,8 +76,10 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
/** Update or create custom model */
const handleUpdate = (data: CustomModelForm) => {
setLoading(true)
const controller = new AbortController()
setAbortController(controller)
const { type, provider, ...rest} = data
const res = isEdit ? updateCustomModel(model.id, rest) : addCustomModel(data)
const res = isEdit ? updateCustomModel(model.id, rest, controller.signal) : addCustomModel(data, controller.signal)
res.then(() => {
refresh?.(isEdit)
@@ -124,15 +129,15 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
useImperativeHandle(ref, () => ({
handleOpen,
}));
console.log('modelType', modelType)
return (
<RbModal
title={isEdit ? `${model.name} - ${t('modelNew.modelConfiguration')}` : t('modelNew.createCustomModel')}
open={visible}
onCancel={handleClose}
okText={t(`common.${isEdit ? 'save' : 'create'}`)}
onOk={handleSave}
confirmLoading={loading}
footer={[
<Button key="cancel" onClick={handleClose}>{t('common.cancel')}</Button>,
<Button key="confirm" type="primary" loading={loading} onClick={handleSave}>{t(`common.${isEdit ? 'save' : 'create'}`)}</Button>,
]}
>
<Form
form={form}

View File

@@ -1,8 +1,8 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-03 16:49:40
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 16:49:40
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-11 15:12:17
*/
/**
* Key Configuration Modal
@@ -11,7 +11,7 @@
*/
import { forwardRef, useImperativeHandle, useState } from 'react';
import { Form, Input, App } from 'antd';
import { Form, Input, App, Button } from 'antd';
import { useTranslation } from 'react-i18next';
import type { KeyConfigModalForm, ProviderModelItem, KeyConfigModalRef, KeyConfigModalProps } from '../types';
@@ -30,9 +30,12 @@ const KeyConfigModal = forwardRef<KeyConfigModalRef, KeyConfigModalProps>(({
const [model, setModel] = useState<ProviderModelItem>({} as ProviderModelItem);
const [form] = Form.useForm<KeyConfigModalForm>();
const [loading, setLoading] = useState(false)
const [abortController, setAbortController] = useState<AbortController | null>(null)
/** Close modal and reset state */
const handleClose = () => {
abortController?.abort()
setAbortController(null)
setModel({} as ProviderModelItem);
form.resetFields();
setLoading(false)
@@ -51,10 +54,13 @@ const KeyConfigModal = forwardRef<KeyConfigModalRef, KeyConfigModalProps>(({
.then((values) => {
setLoading(true)
const controller = new AbortController()
setAbortController(controller)
updateProviderApiKeys({
...values,
provider: model.provider
}).then((res) => {
}, controller.signal).then((res) => {
if (refresh) {
refresh();
}
@@ -81,9 +87,10 @@ const KeyConfigModal = forwardRef<KeyConfigModalRef, KeyConfigModalProps>(({
title={`${model.provider} - ${t('modelNew.keyConfig')}`}
open={visible}
onCancel={handleClose}
okText={t(`common.save`)}
onOk={handleSave}
confirmLoading={loading}
footer={[
<Button key="cancel" onClick={handleClose}>{t('common.cancel')}</Button>,
<Button key="confirm" type="primary" loading={loading} onClick={handleSave}>{t(`common.save`)}</Button>,
]}
>
<Form
form={form}

View File

@@ -1,8 +1,8 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-03 16:49:55
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 16:49:55
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-11 15:11:06
*/
/**
* Multi-Key Configuration Modal
@@ -28,9 +28,12 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
const [model, setModel] = useState<ModelListItem>({} as ModelListItem);
const [form] = Form.useForm<MultiKeyForm>();
const [loading, setLoading] = useState(false)
const [abortController, setAbortController] = useState<AbortController | null>(null)
/** Close modal and refresh parent */
const handleClose = () => {
abortController?.abort()
setAbortController(null)
setModel({} as ModelListItem);
refresh?.()
@@ -60,12 +63,14 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
.validateFields()
.then((values) => {
setLoading(true)
const controller = new AbortController()
setAbortController(controller)
addModelApiKey(model.id, {
...values,
model_config_id: model.id,
model_name: model.name,
provider: model.provider,
}).then(() => {
}, controller.signal).then(() => {
message.success(t('common.saveSuccess'))
form.resetFields();
getData(model)
@@ -98,7 +103,6 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
open={visible}
onCancel={handleClose}
footer={null}
confirmLoading={loading}
>
{model.api_keys && model.api_keys.length > 0 && (
<div className="rb:mb-4">

View File

@@ -395,6 +395,7 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
{t('tool.marketRefresh')}
</Button>
)}
<Input
prefix={<SearchOutlined />}
placeholder={t('tool.marketSearchPlaceholder')}
@@ -402,7 +403,9 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
onChange={(e) => handleSearchChange(e.target.value)}
allowClear
style={{ width: 200 }}
/>
</div>
<Button icon={<SettingOutlined />} onClick={() => handleOpenConfig(selectedSource)}>
{t('tool.marketConfigBtn')}
@@ -559,9 +562,9 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
<span className="rb:flex-1 rb:font-medium rb:text-[12px] rb:overflow-hidden rb:text-ellipsis rb:whitespace-nowrap">
{source.name}
</span>
<span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full rb:flex-shrink-0">
{/* <span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full rb:flex-shrink-0">
{source.mcp_count}
</span>
</span> */}
{source.connected && (
<span className="rb:text-green-500 rb:text-[8px] rb:flex-shrink-0"></span>
)}

View File

@@ -132,12 +132,14 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
const request = editVo?.id ? updateTool(editVo.id, newService) : addTool(newService)
request.then((res: any) => {
message.success(t('common.saveSuccess'));
testConnection(res.tool_id || editVo?.id)
.finally(() => {
setLoading(false);
handleClose();
refresh()
})
setLoading(false);
handleClose();
refresh();
// 在后台测试连接,不阻塞用户操作
testConnection(res.tool_id || editVo?.id).catch((err) => {
console.error('测试连接失败:', err);
});
})
.catch(() => {
setLoading(false);

View File

@@ -146,4 +146,5 @@ export interface MarketQuery {
mcp_market_config_id?: string;
page?: number;
pagesize?: number;
keywords?: string;
}

View File

@@ -1,8 +1,8 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-03 17:57:11
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 17:57:11
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-12 18:00:11
*/
/**
* RAG User Memory Detail View
@@ -13,7 +13,8 @@
import { type FC, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import clsx from 'clsx'
import { Row, Col, Skeleton } from 'antd'
import { Row, Col, Skeleton, Spin, Flex, Tooltip } from 'antd'
import { LoadingOutlined } from '@ant-design/icons';
import { useParams } from 'react-router-dom'
import aboutUs from '@/assets/images/userMemory/aboutUs.svg'
@@ -26,6 +27,7 @@ import {
getUserProfile,
getTotalRagMemoryCountByUser,
getChunkInsight,
generateRagProfile
} from '@/api/memory'
import Empty from '@/components/Empty'
import ConversationMemory from './components/ConversationMemory'
@@ -133,16 +135,46 @@ const Rag: FC = () => {
})
}
const name = loading.detail ? '' : data?.name && data?.name !== '' ? data.name : id
const [refreshLoading, setRefreshLoading] = useState(false)
const handleRefresh = () => {
if (refreshLoading || !id) return
setRefreshLoading(true)
generateRagProfile(id as string)
.then(() => {
getSummary()
getInsightReport()
})
.finally(() => {
setRefreshLoading(false)
})
}
return (
<Row gutter={[16, 16]} className="rb:pb-6">
<Row gutter={[16, 16]} className="rb:h-full!">
<Col span={8}>
<RbCard>
<RbCard
className="rb:h-[calc(100vh-104px)]!"
bodyClassName="rb:overflow-y-auto! rb:h-full!"
>
<div className="rb:flex rb:items-center">
<div className="rb:flex-[0_0_auto] rb:w-20 rb:h-20 rb:text-center rb:font-semibold rb:text-[28px] rb:leading-20 rb:rounded-lg rb:text-[#FBFDFF] rb:bg-[#155EEF]">{name?.[0]}</div>
<div className="rb:text-[24px] rb:font-semibold rb:leading-8 rb:ml-4">
{name}<br/>
<div className="rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 rb:mt-2">{personas?.join(' | ')}</div>
</div>
<Flex>
<div className="rb:text-[24px] rb:font-semibold rb:leading-8 rb:ml-4 rb:flex-1">
{name}<br />
<div className="rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 rb:mt-2">{personas?.join(' | ')}</div>
</div>
<Tooltip title={t('common.refresh')}>
{refreshLoading
? <Spin indicator={<LoadingOutlined spin />} />
: (
<div
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/refresh.svg')] rb:hover:bg-[url('@/assets/images/refresh_hover.svg')]"
onClick={handleRefresh}
></div>
)
}
</Tooltip>
</Flex>
</div>
<div className="rb:flex rb:gap-2 rb:mb-2 rb:flex-wrap rb:mt-6.25">

View File

@@ -1,74 +1,43 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-03 18:34:04
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 18:34:04
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-12 18:34:52
*/
/**
* Conversation Memory Component
* Displays RAG conversation memory content list
*/
import { type FC, useEffect, useState } from 'react'
import { type FC } from 'react'
import { useTranslation } from 'react-i18next'
import { useParams } from 'react-router-dom'
import { Skeleton, List } from 'antd';
import RbCard from '@/components/RbCard/Card'
import Empty from '@/components/Empty';
import PageScrollList from '@/components/PageScrollList'
import Markdown from '@/components/Markdown'
import {
getRagContent
} from '@/api/memory'
import { getRagContentUrl } from '@/api/memory'
const ConversationMemory:FC = () => {
const ConversationMemory: FC = () => {
const { t } = useTranslation()
const { id } = useParams()
const [loading, setLoading] = useState<boolean>(true)
const [list, setList] = useState<string[]>([])
useEffect(() => {
if (!id) return
getList()
}, [id])
/** Fetch conversation memory list */
const getList = () => {
if (!id) return
setLoading(true)
getRagContent(id).then((res) => {
setList((res as { contents?: [] }).contents || [])
})
.finally(() => {
setLoading(false)
})
}
return (
<RbCard
<RbCard
title={t('userMemory.conversationMemory')}
headerClassName="rb:text-[18px]! rb:leading-[24px]"
bodyClassName="rb:h-[100%]! rb:overflow-hidden rb:py-0!"
bodyClassName="rb:h-[calc(100%-56px)]! rb:overflow-hidden"
className="rb:h-[calc(100vh-104px)]!"
>
{loading
? <Skeleton />
: list.length > 0
? <List
dataSource={list}
grid={{ gutter: 12, column: 1 }}
renderItem={(item, index) => (
<List.Item>
<div
key={index}
className="rb:rounded-lg rb:border rb:border-[#DFE4ED] rb:px-4 rb:py-3 rb:bg-[#F0F3F8] rb:mt-2 rb:text-gray-800 rb:text-sm"
>
<Markdown content={item} />
</div>
</List.Item>
)}
/>
: <Empty className="rb:h-full" />
}
<PageScrollList<string>
url={getRagContentUrl}
query={{ end_user_id: id }}
column={1}
renderItem={(item) => (
<div className="rb:rounded-lg rb:border rb:border-[#DFE4ED] rb:px-4 rb:py-3 rb:bg-[#F0F3F8] rb:text-gray-800 rb:text-sm">
<Markdown content={item} />
</div>
)}
className="rb:h-full!"
// className="rb:h-[calc(100%-24px)]!"
/>
</RbCard>
)
}
export default ConversationMemory
export default ConversationMemory

View File

@@ -1,8 +1,8 @@
import type { FC } from 'react';
import { Select } from 'antd';
import { Select, Divider } from 'antd';
// import { Node } from '@antv/x6';
import type { GraphRef } from '../types'
import { PlusOutlined, MinusOutlined } from '@ant-design/icons'
import { PlusOutlined, MinusOutlined, FileAddOutlined } from '@ant-design/icons'
interface CanvasToolbarProps {
miniMapRef: React.RefObject<HTMLDivElement>;
@@ -14,6 +14,7 @@ interface CanvasToolbarProps {
canRedo: boolean;
onUndo: () => void;
onRedo: () => void;
addNotes: () => void;
}
const CanvasToolbar: FC<CanvasToolbarProps> = ({
@@ -26,6 +27,7 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
// canRedo,
// onUndo,
// onRedo,
addNotes,
}) => {
// 整理布局函数
/*
@@ -152,7 +154,7 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
{/* 小地图 */}
<div ref={miniMapRef} className="rb:absolute rb:bottom-15 rb:right-8 rb:z-1000 rb:rounded-lg rb:overflow-hidden"></div>
{/* 缩放控制按钮 */}
<div className="rb:h-8.5 rb:bg-[#FFFFFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:shadow-[0px_2px_6px_0px_rgba(33,35,50,0.15)] rb:px-3 rb:py-2 rb:absolute rb:bottom-5 rb:right-8 rb:flex rb:flex-row rb:gap-4 rb:z-1000">
<div className="rb:h-8.5 rb:bg-[#FFFFFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:shadow-[0px_2px_6px_0px_rgba(33,35,50,0.15)] rb:px-3 rb:py-2 rb:absolute rb:bottom-5 rb:right-8 rb:flex rb:flex-row rb:items-center rb:gap-4 rb:z-1000">
<MinusOutlined className="rb:text-[16px] rb:cursor-pointer" onClick={() => graphRef.current?.zoom(-0.1)} />
<Select
value={Math.round(zoomLevel * 100)}
@@ -182,6 +184,8 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
size="small"
/>
<PlusOutlined className="rb:text-[16px] rb:cursor-pointer" onClick={() => graphRef.current?.zoom(0.1)} />
<Divider type="vertical" className="rb:h-4" />
<FileAddOutlined onClick={addNotes} />
</div>
</>
);

View File

@@ -0,0 +1,108 @@
import { useEffect, useRef } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { FORMAT_TEXT_COMMAND, $getSelection, $isRangeSelection, $setSelection, $isTextNode, type BaseSelection } from 'lexical';
import { $patchStyleText } from '@lexical/selection';
import { INSERT_UNORDERED_LIST_COMMAND, REMOVE_LIST_COMMAND, ListNode } from '@lexical/list';
import { TOGGLE_LINK_COMMAND, LinkNode } from '@lexical/link';
import { $getNearestNodeOfType } from '@lexical/utils';
export const NOTE_FORMAT_EVENT = 'note:format';
export interface FormatState {
bold: boolean;
italic: boolean;
strikethrough: boolean;
list: boolean;
fontSize?: number;
linkUrl?: string | null;
}
const NoteFormatPlugin = ({ nodeId, onFormatChange, fontSize = 12 }: { nodeId: string; fontSize?: number; onFormatChange?: (state: FormatState) => void }) => {
const [editor] = useLexicalComposerContext();
const savedSelection = useRef<BaseSelection | null>(null);
useEffect(() => {
return editor.registerUpdateListener(({ editorState }) => {
editorState.read(() => {
const selection = $getSelection();
if (!$isRangeSelection(selection)) return;
savedSelection.current = selection.clone();
const anchorNode = selection.anchor.getNode();
const style = 'getStyle' in anchorNode ? (anchorNode as { getStyle(): string }).getStyle() : '';
const match = style.match(/font-size:\s*([\d.]+)px/);
const nodeFontSize = match ? Number(match[1]) : fontSize;
const linkNode = $getNearestNodeOfType(anchorNode, LinkNode);
onFormatChange?.({
bold: selection.hasFormat('bold'),
italic: selection.hasFormat('italic'),
strikethrough: selection.hasFormat('strikethrough'),
list: !!$getNearestNodeOfType(anchorNode, ListNode),
...(nodeFontSize ? { fontSize: nodeFontSize } : {}),
linkUrl: linkNode ? linkNode.getURL() : null,
});
});
});
}, [editor, onFormatChange]);
useEffect(() => {
const handler = (e: Event) => {
const { id, format, value } = (e as CustomEvent).detail;
if (id !== nodeId) return;
const sel = savedSelection.current;
const hasSelection = $isRangeSelection(sel) && !sel.isCollapsed();
if (format === 'link' && value === null) {
// remove link: select the entire LinkNode first
editor.focus(() => {
editor.update(() => {
const s = $getSelection();
const anchorNode = $isRangeSelection(s)
? s.anchor.getNode()
: savedSelection.current && $isRangeSelection(savedSelection.current)
? savedSelection.current.anchor.getNode()
: null;
const linkNode = anchorNode ? $getNearestNodeOfType(anchorNode, LinkNode) : null;
if (linkNode) {
const children = linkNode.getChildren();
if (children.length > 0) {
const first = children[0];
const last = children[children.length - 1];
if ($isTextNode(first) && $isTextNode(last)) {
const range = first.select(0, 0);
range.focus.set(last.getKey(), last.getTextContentSize(), 'text');
}
}
}
});
editor.dispatchCommand(TOGGLE_LINK_COMMAND, null);
});
} else if (format === 'list') {
editor.focus(() => {
if (sel) editor.update(() => $setSelection(sel));
editor.dispatchCommand(value ? INSERT_UNORDERED_LIST_COMMAND : REMOVE_LIST_COMMAND, undefined);
editor.update(() => $setSelection(null));
});
} else if (hasSelection) {
editor.focus(() => {
editor.update(() => $setSelection(sel));
if (format === 'bold' || format === 'italic' || format === 'strikethrough') {
editor.dispatchCommand(FORMAT_TEXT_COMMAND, format);
} else if (format === 'link') {
editor.dispatchCommand(TOGGLE_LINK_COMMAND, value as string | null);
} else if (format === 'fontSize') {
editor.update(() => {
$setSelection(sel);
$patchStyleText(sel!, { 'font-size': `${value}px` });
});
}
editor.update(() => $setSelection(null));
});
}
};
window.addEventListener(NOTE_FORMAT_EVENT, handler);
return () => window.removeEventListener(NOTE_FORMAT_EVENT, handler);
}, [editor, nodeId]);
return null;
};
export default NoteFormatPlugin;

View File

@@ -0,0 +1,74 @@
import { type FC, useState } from 'react';
import { createPortal } from 'react-dom';
import { useTranslation } from 'react-i18next';
import { Flex, Button, Input } from 'antd';
import { EditOutlined, DisconnectOutlined } from '@ant-design/icons';
const POPOVER_STYLE: React.CSSProperties = {
position: 'fixed',
zIndex: 1000,
background: '#fff',
border: '1px solid #e5e7eb',
borderRadius: 8,
boxShadow: '0 2px 8px rgba(0,0,0,0.12)',
whiteSpace: 'nowrap',
};
interface LinkPopoverProps {
url: string;
rect: DOMRect;
onEdit: () => void;
onRemove: () => void;
}
export const LinkPopover: FC<LinkPopoverProps> = ({ url, rect, onEdit, onRemove }) => {
const { t } = useTranslation();
return createPortal(
<div
style={{ ...POPOVER_STYLE, left: rect.left, top: rect.bottom + 4, padding: '4px 10px', fontSize: 12 }}
onMouseDown={e => e.stopPropagation()}
>
<Flex align="center" gap={8}>
<a href={url} target="_blank" rel="noreferrer" style={{ color: '#2563eb', maxWidth: 160, overflow: 'hidden', textOverflow: 'ellipsis', display: 'inline-block' }}>
{url}
</a>
<Button size="small" type="text" icon={<EditOutlined />} onClick={onEdit}>{t('common.edit')}</Button>
<Button size="small" type="text" icon={<DisconnectOutlined />} onClick={onRemove}>{t('workflow.config.notes.removeLink')}</Button>
</Flex>
</div>,
document.body
);
};
interface EditLinkPopoverProps {
rect: DOMRect;
initialUrl: string;
onConfirm: (url: string) => void;
}
export const EditLinkPopover: FC<EditLinkPopoverProps> = ({ rect, initialUrl, onConfirm }) => {
const { t } = useTranslation();
const [url, setUrl] = useState(initialUrl);
const confirm = () => onConfirm(url);
return createPortal(
<div
style={{ ...POPOVER_STYLE, left: rect.left, top: rect.bottom + 4, padding: '8px' }}
onMouseDown={e => e.stopPropagation()}
>
<Flex gap={8}>
<Input
size="small"
className="rb:w-60!"
placeholder={t('workflow.config.notes.enterLink')}
value={url}
onChange={e => setUrl(e.target.value)}
onKeyDown={e => e.stopPropagation()}
onPressEnter={confirm}
autoFocus
/>
<Button size="small" type="primary" onClick={confirm}>{t('common.confirm')}</Button>
</Flex>
</div>,
document.body
);
};

View File

@@ -0,0 +1,184 @@
import { type FC, useState, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { LexicalComposer } from '@lexical/react/LexicalComposer';
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin';
import { ContentEditable } from '@lexical/react/LexicalContentEditable';
import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin';
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary';
import { ListPlugin } from '@lexical/react/LexicalListPlugin';
import { LinkPlugin } from '@lexical/react/LexicalLinkPlugin';
import { ListNode, ListItemNode } from '@lexical/list';
import { LinkNode } from '@lexical/link';
import { OnChangePlugin } from '@lexical/react/LexicalOnChangePlugin';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { useEffect, useRef } from 'react';
import NoteFormatPlugin from './NoteFormatPlugin';
import type { FormatState } from './NoteFormatPlugin';
import { LinkPopover, EditLinkPopover } from './NoteLinkPopovers';
const theme = {
paragraph: 'editor-paragraph',
text: {
bold: 'editor-text-bold',
italic: 'editor-text-italic',
strikethrough: 'note-text-strikethrough',
},
list: { ul: 'note-list-ul', listitem: 'note-list-item' },
link: 'note-link',
};
const NOTE_NODES = [ListNode, ListItemNode, LinkNode];
const NOTE_STYLES = `
.editor-text-bold { font-weight: bold; }
.editor-text-italic { font-style: italic; }
.note-text-strikethrough { text-decoration: line-through; }
.note-list-ul { list-style-type: disc; padding-left: 1.2em; margin: 0; }
.note-list-item { margin: 2px 0; }
.note-link { color: #2563eb; text-decoration: underline; cursor: pointer; }
`;
const NoteInitPlugin: FC<{ value: string }> = ({ value }) => {
const [editor] = useLexicalComposerContext();
const initialized = useRef(false);
useEffect(() => {
if (initialized.current || !value) return;
initialized.current = true;
try {
const parsed = JSON.parse(value);
if (parsed?.root) {
const state = editor.parseEditorState(JSON.stringify(parsed));
editor.setEditorState(state);
return;
}
} catch {}
}, [editor, value]);
return null;
};
interface NoteEditorProps {
nodeId: string;
value: string;
fontSize?: number;
onChange: (val: string) => void;
onFormatChange?: (state: FormatState) => void;
}
const NoteEditor: FC<NoteEditorProps> = ({ nodeId, value, fontSize = 12, onChange, onFormatChange }) => {
const { t } = useTranslation();
const [linkState, setLinkState] = useState<{ url: string; rect: DOMRect } | null>(null);
const [editLinkRect, setEditLinkRect] = useState<{ url: string; rect: DOMRect } | null>(null);
const removingLink = useRef(false);
useEffect(() => {
if (!linkState) return;
const handler = () => setLinkState(null);
window.addEventListener('mousedown', handler);
return () => window.removeEventListener('mousedown', handler);
}, [!!linkState]);
useEffect(() => {
const handler = (e: Event) => {
const { id, url, rect: passedRect } = (e as CustomEvent).detail;
if (id !== nodeId) return;
if (passedRect) {
setEditLinkRect({ url: url || '', rect: passedRect });
return;
}
const sel = window.getSelection();
if (sel && sel.rangeCount > 0) {
const r = sel.getRangeAt(0).getBoundingClientRect();
if (r.width > 0 || r.height > 0) { setEditLinkRect({ url: url || '', rect: r }); return; }
}
const linkEl = document.querySelector(`[data-note-id="${nodeId}"] a.note-link`) as HTMLElement;
const rect = linkEl?.getBoundingClientRect() ?? new DOMRect(window.innerWidth / 2, 200, 0, 0);
setEditLinkRect({ url: url || '', rect });
};
window.addEventListener('note:edit-link', handler);
return () => window.removeEventListener('note:edit-link', handler);
}, [nodeId]);
const handleFormatChange = useCallback((state: FormatState) => {
onFormatChange?.(state);
if (state.linkUrl) {
requestAnimationFrame(() => {
if (removingLink.current) { removingLink.current = false; return; }
const sel = window.getSelection();
if (sel && sel.rangeCount > 0) {
const rect = sel.getRangeAt(0).getBoundingClientRect();
if (rect.width > 0 || rect.height > 0) {
setLinkState({ url: state.linkUrl!, rect });
return;
}
}
// fallback: find the link element in the correct editor
const editorEl = document.querySelector(`[data-note-id="${nodeId}"] a.note-link`) as HTMLElement;
if (editorEl) {
setLinkState({ url: state.linkUrl!, rect: editorEl.getBoundingClientRect() });
}
});
} else {
setLinkState(null);
}
}, [onFormatChange]);
return (
<>
<style>{NOTE_STYLES}</style>
<LexicalComposer initialConfig={{ namespace: `note-${nodeId}`, theme, nodes: NOTE_NODES, onError: console.error }}>
<div style={{ position: 'relative' }} data-note-id={nodeId}>
<RichTextPlugin
contentEditable={
<ContentEditable
style={{ minHeight: 60, outline: 'none', resize: 'none', fontSize: '12px', lineHeight: '18px', color: '#374151', overflow: 'auto', cursor: 'auto' }}
/>
}
placeholder={
<div style={{ position: 'absolute', top: 0, left: 0, color: '#9CA3AF', lineHeight: '18px', pointerEvents: 'none' }}>
{t('workflow.config.notes.placeholder')}
</div>
}
ErrorBoundary={LexicalErrorBoundary}
/>
<HistoryPlugin />
<ListPlugin />
<LinkPlugin />
<OnChangePlugin onChange={(editorState) => onChange(JSON.stringify(editorState.toJSON()))} />
<NoteInitPlugin value={value} />
<NoteFormatPlugin nodeId={nodeId} fontSize={fontSize} onFormatChange={handleFormatChange} />
{editLinkRect && (
<EditLinkPopover
rect={editLinkRect.rect}
initialUrl={editLinkRect.url}
onConfirm={(url) => {
removingLink.current = true;
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'link', value: url || null } }));
setEditLinkRect(null);
}}
/>
)}
{linkState && (
<LinkPopover
url={linkState.url}
rect={linkState.rect}
onEdit={() => {
removingLink.current = true;
const { rect, url } = linkState;
setLinkState(null);
setEditLinkRect({ url, rect });
}}
onRemove={() => {
removingLink.current = true;
setLinkState(null);
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'link', value: null } }));
}}
/>
)}
</div>
</LexicalComposer>
</>
);
};
export default NoteEditor;

View File

@@ -0,0 +1,163 @@
import { type FC } from 'react';
import { Flex, Dropdown, type MenuProps, Switch, Button, Divider } from 'antd';
import { UnorderedListOutlined, BoldOutlined, ItalicOutlined, StrikethroughOutlined, LinkOutlined, DashOutlined } from '@ant-design/icons';
import { Node } from '@antv/x6';
import { useTranslation } from 'react-i18next'
import { THEME_MAP } from '../../../constant';
const FONT_SIZES = [
{ label: '小', value: 12 },
{ label: '中', value: 14 },
{ label: '大', value: 16 },
];
interface NoteNodeToolbarProps {
node: Node;
onFormat: (type: string, value?: unknown) => void;
toolConfig: Record<string, number | boolean>;
nodeId: string;
}
const NoteNodeToolbar: FC<NoteNodeToolbarProps> = ({ node, onFormat, toolConfig, nodeId }) => {
const data = node?.getData() || {};
const { t } = useTranslation();
const colorItems: MenuProps['items'] = Object.entries(THEME_MAP).map(([key, theme]) => ({
key,
label: (
<div
className="rb:w-5 rb:h-5 rb:rounded-full rb:cursor-pointer rb:border rb:border-gray-200"
style={{ background: theme.bg }}
onClick={() => onFormat('color', key)}
/>
),
}));
const fontSizeItems: MenuProps['items'] = FONT_SIZES.map(({ label, value }) => ({
key: value,
label: <span onClick={() => onFormat('fontSize', value)}>{label}</span>,
}));
const currentFontSize = FONT_SIZES.find(f => f.value === toolConfig.fontSize)?.label ?? '小';
const handleClick: MenuProps['onClick'] = (e) => {
switch (e.key) {
case 'delete':
node.remove()
break;
case 'copy':
break;
}
}
const handleChange = (type: string) => {
let show_author = data.config.show_author.defaultValue
if(type === 'showAuth'){
show_author = !show_author
}
node.setData({
...data,
config: {
...data.config,
show_author: {
...data.config.show_author,
defaultValue: show_author
}
}
})
}
return (
<Flex
align="center"
gap={8}
className="rb:absolute rb:-top-11 rb:left-1/2 rb:-translate-x-1/2 rb:bg-white rb:z-10 rb:whitespace-nowrap rb:rounded-lg rb:py-1! rb:px-3!"
onClick={e => e.stopPropagation()}
>
{/* Color picker */}
<Dropdown menu={{ items: colorItems }} trigger={['click']}>
<div
className="rb:w-5 rb:h-5 rb:rounded-full rb:cursor-pointer rb:border rb:border-gray-200"
style={{ background: THEME_MAP[data.bgColor]?.bg || THEME_MAP.blue.bg }}
/>
</Dropdown>
<Divider type="vertical" />
{/* Font size */}
<Dropdown menu={{ items: fontSizeItems }} trigger={['click']}>
<Flex align="center" gap={4} className="rb:cursor-pointer rb:text-xs rb:text-gray-600 rb:select-none">
<span className="rb:text-xs">Aa</span>
<span className="rb:text-xs">{currentFontSize}</span>
</Flex>
</Dropdown>
<Divider type="vertical" />
{/* Bold */}
<Button
type={toolConfig.bold ? 'primary' : 'text'}
icon={<BoldOutlined />}
onClick={() => onFormat('bold')}
/>
{/* Italic */}
<Button
type={toolConfig.italic ? 'primary' : 'text'}
icon={<ItalicOutlined />}
onClick={() => onFormat('italic')}
/>
{/* Strikethrough */}
<Button
type={toolConfig.strikethrough ? 'primary' : 'text'}
icon={<StrikethroughOutlined />}
onClick={() => onFormat('strikethrough')}
/>
{/* Link */}
<Button
type={toolConfig.link ? 'primary' : 'text'}
icon={<LinkOutlined />}
onClick={() => {
const sel = window.getSelection();
const rect = sel && sel.rangeCount > 0 ? sel.getRangeAt(0).getBoundingClientRect() : undefined;
window.dispatchEvent(new CustomEvent('note:edit-link', { detail: { id: nodeId, url: '', rect } }));
}}
/>
{/* List */}
<Button
type={toolConfig.list ? 'primary' : 'text'}
icon={<UnorderedListOutlined />}
onClick={() => onFormat('list')}
/>
<Divider type="vertical" />
<Dropdown
menu={{
items: [
// { key: 'copy', label: t('common.copy') },
{
key: 'showAuth',
label: <Flex align="center" gap={24}>
{t('workflow.config.notes.showAuth')}
<Switch
size="small"
checked={data.config.show_author.defaultValue}
onChange={() => handleChange('showAuth')}
/>
</Flex>
},
{ key: 'delete', label: <Flex>{t('common.delete')}</Flex> },
],
onClick: handleClick
}}
>
<DashOutlined />
</Dropdown>
</Flex>
);
};
export default NoteNodeToolbar;

View File

@@ -0,0 +1,155 @@
import { useRef, useState } from 'react';
import type { ReactShapeConfig } from '@antv/x6-react-shape';
import { Flex } from 'antd';
import NoteEditor from './NoteEditor';
import NoteNodeToolbar from './NoteNodeToolbar';
import { THEME_MAP } from '../../../constant'
const MIN_W = 240;
const MIN_H = 120;
const NoteNode: ReactShapeConfig['component'] = ({ node }) => {
const data = node?.getData() || {};
const nodeId = node?.id || '';
const startRef = useRef<{ x: number; y: number; w: number; h: number } | null>(null);
const [toolConfig, setToolConfig] = useState({
fontSize: 12,
bold: false,
italic: false,
strikethrough: false,
list: false,
})
const handleFormat = (type: string, value?: unknown) => {
console.log('handleFormat', type, value)
if (type === 'color') {
node?.setData({
...data,
config: {
...data.config,
theme: {
...data.config.theme,
defaultValue: value
}
}
});
} else if (type === 'fontSize') {
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'fontSize', value } }));
} else if (type === 'link') {
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'link', value: value || null } }));
} else if (type === 'list') {
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'list', value: !toolConfig.list } }));
} else {
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: type } }));
}
setToolConfig(prev => ({ ...prev, [type]: value || !prev[type as unknown as keyof typeof toolConfig] }))
};
const onResizeMouseDown = (e: React.MouseEvent) => {
e.stopPropagation();
e.preventDefault();
const size = node?.getSize();
if (!size) return;
startRef.current = { x: e.clientX, y: e.clientY, w: size.width, h: size.height };
const onMouseMove = (ev: MouseEvent) => {
if (!startRef.current) return;
const w = Math.max(MIN_W, startRef.current.w + ev.clientX - startRef.current.x);
const h = Math.max(MIN_H, startRef.current.h + ev.clientY - startRef.current.y);
node?.setData({
...data,
config: {
...data.config,
width: {
...data.config.width,
defaultValue: w
},
height: {
...data.config.height,
defaultValue: h
}
}
});
node?.prop('size', { width: w, height: h });
};
const onMouseUp = () => {
startRef.current = null;
window.removeEventListener('mousemove', onMouseMove);
window.removeEventListener('mouseup', onMouseUp);
};
window.addEventListener('mousemove', onMouseMove);
window.addEventListener('mouseup', onMouseUp);
};
const updateText = (value: string) => {
node.setData({
...data,
config: {
...data.config,
text: {
...data.config.text,
defaultValue: value
}
}
})
}
const theme = THEME_MAP[data.config?.theme?.defaultValue || 'blue'] || THEME_MAP['blue']
return (
<div
className="rb:relative rb:h-full rb:w-full rb:rounded-2xl rb:border"
style={{
background: theme.bg,
borderColor: data.isSelected ? theme.outer : theme.border,
}}
>
<div className="rb:h-4 rb:rounded-tl-2xl rb:rounded-tr-2xl"
style={{
background: theme.title
}}
></div>
{data.isSelected && <NoteNodeToolbar node={node!} nodeId={nodeId} toolConfig={toolConfig} onFormat={handleFormat} />}
<div
className="rb:w-full rb:h-[calc(100%-36px)] rb:p-2.5 rb:overflow-auto"
onMouseDown={e => {
e.stopPropagation()
node?.setData({ ...node.getData(), isSelected: true })
}}
onWheel={e => e.stopPropagation()}
>
<NoteEditor
nodeId={nodeId}
value={data.config.text.defaultValue || ''}
fontSize={toolConfig.fontSize}
onChange={updateText}
onFormatChange={(state) => setToolConfig(prev => ({ ...prev, ...state }))}
/>
</div>
<Flex align="center" justify="space-between" className="rb:pl-2.5! rb:pr-1!">
<div className="rb:text-[12px] rb:text-[#5B6167]">
{data.config.show_author.defaultValue
? data.config.author.defaultValue
: undefined
}
</div>
{/* <div className="rb:size-4 rb:border-b-[4px] rb:border-r-[4px] rb:border-[#EBEBEB] rb:rounded-2xl"></div> */}
<div
onMouseDown={onResizeMouseDown}
>
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18" fill="none">
<path fillRule="evenodd" clipRule="evenodd" d="M12 9.75V6H13.5V9.75C13.5 11.8211 11.8211 13.5 9.75 13.5H6V12H9.75C10.9926 12 12 10.9926 12 9.75Z" fill="black" fillOpacity="0.16"></path>
</svg>
</div>
</Flex>
</div>
);
};
export default NoteNode;

View File

@@ -2,13 +2,14 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 15:06:18
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-11 12:07:20
* @Last Modified time: 2026-03-09 13:41:19
*/
import LoopNode from './components/Nodes/LoopNode';
import NormalNode from './components/Nodes/NormalNode';
import ConditionNode from './components/Nodes/ConditionNode';
import GroupStartNode from './components/Nodes/GroupStartNode';
import AddNode from './components/Nodes/AddNode'
import NoteNode from './components/Nodes/NoteNode';
import type { PortMetadata, GroupMetadata } from '@antv/x6/lib/model/port';
import type { ReactShapeConfig } from '@antv/x6-react-shape';
@@ -525,6 +526,73 @@ export const nodeLibrary: NodeLibrary[] = [
// ]
// },
];
export const THEME_MAP: Record<string, { outer: string; title: string; bg: string; border: string }> = {
blue: {
outer: '#2E90FA',
title: '#D1E9FF',
bg: '#EFF8FF',
border: '#84CAFF',
},
cyan: {
outer: '#06AED4',
title: '#CFF9FE',
bg: '#ECFDFF',
border: '#67E3F9',
},
green: {
outer: '#16B364',
title: '#D3F8DF',
bg: '#EDFCF2',
border: '#73E2A3',
},
yellow: {
outer: '#EAAA08',
title: '#FEF7C3',
bg: '#FEFBE8',
border: '#FDE272',
},
pink: {
outer: '#EE46BC',
title: '#FCE7F6',
bg: '#FDF2FA',
border: '#FAA7E0',
},
violet: {
outer: '#875BF7',
title: '#ECE9FE',
bg: '#F5F3FF',
border: '#C3B5FD',
},
}
export const notesConfig = {
type: "notes", icon: templateRenderingIcon,
config: {
text: {
type: 'define',
},
theme: {
type: 'define',
defaultValue: 'blue',
},
width: {
type: 'define',
width: 240,
},
height: {
type: 'define',
height: 120,
},
author: {
type: 'define',
},
show_author: {
type: 'define',
defaultValue: true
}
}
}
export const unknownNode = {
type: 'unknown',
icon: unknownIcon
@@ -576,6 +644,12 @@ export const nodeRegisterLibrary: ReactShapeConfig[] = [
height: 44,
component: AddNode,
},
{
shape: 'notes-node',
width: nodeWidth,
height: 120,
component: NoteNode,
},
];
/**
@@ -801,6 +875,11 @@ export const graphNodeLibrary: Record<string, NodeConfig> = {
groups: {left: { position: 'left', markup: portMarkup, attrs: portAttrs }},
items: [{ group: 'left' }],
},
},
notes: {
width: nodeWidth,
height: 120,
shape: 'notes-node',
}
}

View File

@@ -12,9 +12,10 @@ import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from '
import { register } from '@antv/x6-react-shape';
import type { PortMetadata } from '@antv/x6/lib/model/port';
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, noteNode } from '../constant';
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, noteNode, notesConfig } from '../constant';
import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types';
import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application'
import { useUser } from '@/store/user';
/**
* Props for useWorkflowGraph hook
@@ -64,6 +65,8 @@ export interface UseWorkflowGraphReturn {
chatVariables: ChatVariable[];
/** Function to update chat variables */
setChatVariables: React.Dispatch<React.SetStateAction<ChatVariable[]>>;
handleAddNotes: () => void;
}
/**
@@ -80,6 +83,7 @@ export const useWorkflowGraph = ({
const { id } = useParams();
const { message } = App.useApp();
const { t } = useTranslation()
const { user } = useUser();
// Refs
const graphRef = useRef<Graph>();
@@ -128,7 +132,7 @@ export const useWorkflowGraph = ({
if (nodes.length) {
const nodeList = nodes.map(node => {
const { id, type, name, position, config = {} } = node
let nodeLibraryConfig = [...nodeLibrary, { nodes: [unknownNode, noteNode] }]
let nodeLibraryConfig = [...nodeLibrary, { nodes: [unknownNode, notesConfig] }]
.flatMap(category => category.nodes)
.find(n => n.type === type)
nodeLibraryConfig = JSON.parse(JSON.stringify({ config: {}, ...nodeLibraryConfig })) as NodeProperties
@@ -197,6 +201,13 @@ export const useWorkflowGraph = ({
data: { ...node, ...nodeLibraryConfig},
...position,
}
if (type === 'notes') {
const w = config.width;
const h = config.height;
if (w) nodeConfig.width = w as number;
if (h) nodeConfig.height = h as number;
}
// Generate ports dynamically for if-else node based on cases
if (type === 'if-else' && config.cases && Array.isArray(config.cases)) {
@@ -461,11 +472,12 @@ export const useWorkflowGraph = ({
*/
const nodeClick = ({ node }: { node: Node }) => {
// Ignore add-node type node clicks
if (node.getData()?.type === 'add-node' || node.getData().type === 'break' || node.getData().type === 'cycle-start') {
const nodeData = node.getData()
if (nodeData?.type === 'add-node' || nodeData.type === 'break' || nodeData.type === 'cycle-start') {
setSelectedNode(null)
return;
}
const nodes = graphRef.current?.getNodes();
nodes?.forEach(vo => {
@@ -478,10 +490,12 @@ export const useWorkflowGraph = ({
}
});
node.setData({
...node.getData(),
...nodeData,
isSelected: true,
});
setSelectedNode(node);
if (nodeData.type !== 'notes') {
setSelectedNode(node);
}
};
/**
* Handle edge click event
@@ -859,8 +873,31 @@ export const useWorkflowGraph = ({
init();
window.addEventListener('resize', handleResize);
const handleNoteKeydown = (e: KeyboardEvent) => {
if (!graphRef.current) return;
const selectedNote = graphRef.current.getNodes().find(n => n.getData()?.isSelected && n.getData()?.type === 'notes');
if (!selectedNote) return;
const isMeta = e.ctrlKey || e.metaKey;
if (e.key === 'Delete' || e.key === 'Backspace') {
// Only delete node when editor is not focused on text
const active = document.activeElement;
if (active && (active as HTMLElement).isContentEditable) return;
deleteEvent();
} else if (isMeta && e.key === 'c') {
copyEvent();
} else if (isMeta && e.key === 'v') {
parseEvent();
} else if (isMeta && e.key === 'd') {
e.preventDefault();
deleteEvent();
}
};
window.addEventListener('keydown', handleNoteKeydown);
return () => {
window.removeEventListener('resize', handleResize);
window.removeEventListener('keydown', handleNoteKeydown);
graphRef.current?.dispose();
};
}, []);
@@ -884,7 +921,7 @@ export const useWorkflowGraph = ({
.flatMap(category => category.nodes)
.find(n => n.type === dragData.type);
nodeLibraryConfig = JSON.parse(JSON.stringify({ config: {}, ...nodeLibraryConfig })) as NodeProperties
// Create clean node data, only keep necessary fields
const cleanNodeData = {
id: `${dragData.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
@@ -1103,6 +1140,32 @@ export const useWorkflowGraph = ({
})
}
const handleAddNotes = () => {
if (!graphRef.current) return;
const nodeConfig: NodeProperties = JSON.parse(JSON.stringify(notesConfig));
nodeConfig.config = {
...nodeConfig.config,
author: { type: 'define', defaultValue: user?.username || '' },
};
const cleanNodeData = {
id: `notes_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
name: t('workflow.notes'),
...nodeConfig,
};
const container = graphRef.current.container;
const nodeW = graphNodeLibrary.notes?.width || nodeWidth;
const nodeH = graphNodeLibrary.notes?.height || 100;
const rect = container.getBoundingClientRect();
const center = graphRef.current.clientToLocal(rect.left + rect.width / 2, rect.top + rect.height / 2);
graphRef.current.addNode({
...(graphNodeLibrary.notes || graphNodeLibrary.default),
x: center.x - nodeW / 2,
y: center.y - nodeH / 2,
id: cleanNodeData.id,
data: { ...cleanNodeData },
});
}
return {
config,
setConfig,
@@ -1120,6 +1183,7 @@ export const useWorkflowGraph = ({
parseEvent,
handleSave,
chatVariables,
setChatVariables
setChatVariables,
handleAddNotes
};
};

View File

@@ -38,7 +38,8 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
parseEvent,
handleSave,
chatVariables,
setChatVariables
setChatVariables,
handleAddNotes
} = useWorkflowGraph({ containerRef, miniMapRef });
const onDragOver = (event: React.DragEvent) => {
@@ -95,6 +96,7 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
canRedo={canRedo}
onUndo={onUndo}
onRedo={onRedo}
addNotes={handleAddNotes}
/>
</div>