Merge branch 'release/v0.2.7' into feature/tool_yjp
This commit is contained in:
@@ -45,7 +45,8 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
||||
apt install -y libjemalloc-dev && \
|
||||
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
||||
apt install -y ghostscript
|
||||
apt install -y ghostscript && \
|
||||
apt install -y libmagic1
|
||||
|
||||
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
|
||||
|
||||
@@ -65,7 +65,7 @@ celery_app.conf.update(
|
||||
|
||||
# 时区
|
||||
timezone='Asia/Shanghai',
|
||||
enable_utc=True,
|
||||
enable_utc=False,
|
||||
|
||||
# 任务追踪
|
||||
task_track_started=True,
|
||||
@@ -114,6 +114,7 @@ celery_app.conf.update(
|
||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import uuid
|
||||
import io
|
||||
from typing import Optional, Annotated
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -25,6 +27,7 @@ from app.services.app_service import AppService
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.services.workflow_import_service import WorkflowImportService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_dsl_service import AppDslService
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -1010,3 +1013,57 @@ def get_workspace_api_statistics(
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
|
||||
@cur_workspace_access_guard()
|
||||
async def export_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
release_id: Optional[uuid.UUID] = None
|
||||
):
|
||||
"""导出 agent / multi_agent / workflow 应用配置为 YAML 文件流。
|
||||
release_id: 指定发布版本id,不传则导出当前草稿配置。
|
||||
"""
|
||||
yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)
|
||||
encoded = quote(filename, safe=".")
|
||||
yaml_bytes = yaml_str.encode("utf-8")
|
||||
file_stream = io.BytesIO(yaml_bytes)
|
||||
file_stream.seek(0)
|
||||
return StreamingResponse(
|
||||
file_stream,
|
||||
media_type="application/octet-stream; charset=utf-8",
|
||||
headers={"Content-Disposition": f"attachment; filename={encoded}",
|
||||
"Content-Length": str(len(yaml_bytes))}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/import", summary="从 YAML 文件导入应用")
|
||||
@cur_workspace_access_guard()
|
||||
async def import_app(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
||||
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
||||
"""
|
||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||
return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)
|
||||
|
||||
raw = (await file.read()).decode("utf-8")
|
||||
dsl = yaml.safe_load(raw)
|
||||
if not dsl or "app" not in dsl:
|
||||
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
||||
|
||||
new_app, warnings = AppDslService(db).import_dsl(
|
||||
dsl=dsl,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
tenant_id=current_user.tenant_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
return success(
|
||||
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
|
||||
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
||||
)
|
||||
|
||||
@@ -65,13 +65,18 @@ async def get_mcp_servers(
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 3. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
@@ -141,13 +146,18 @@ async def get_operational_mcp_servers(
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
url = f'{api.mcp_base_url}/operational'
|
||||
@@ -199,13 +209,18 @@ async def get_mcp_server(
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. Get detailed information for a specific MCP Server
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
result = api.get_mcp_server(server_id=server_id)
|
||||
@@ -263,7 +278,7 @@ async def get_mcp_market_config(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
@@ -296,7 +311,7 @@ async def get_mcp_market_config_by_mcp_market_id(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
@@ -325,7 +340,7 @@ async def update_mcp_market_config(
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
@@ -382,7 +397,7 @@ async def delete_mcp_market_config(
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from app.core.response_utils import success
|
||||
@@ -156,9 +157,13 @@ async def get_workspace_end_users(
|
||||
"app.tasks.init_implicit_emotions_for_users",
|
||||
kwargs={"end_user_ids": end_user_ids},
|
||||
)
|
||||
api_logger.info(f"已触发隐性记忆按需初始化任务,候选用户数: {len(end_user_ids)}")
|
||||
_celery_app.send_task(
|
||||
"app.tasks.init_interest_distribution_for_users",
|
||||
kwargs={"end_user_ids": end_user_ids},
|
||||
)
|
||||
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发隐性记忆按需初始化任务失败(不影响主流程): {e}")
|
||||
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
|
||||
|
||||
# 并发执行配置查询和记忆数量查询
|
||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||
@@ -398,14 +403,15 @@ def get_current_user_rag_total_num(
|
||||
@router.get("/rag_content", response_model=ApiResponse)
|
||||
def get_rag_content(
|
||||
end_user_id: str = Query(..., description="宿主ID"),
|
||||
limit: int = Query(15, description="返回记录数"),
|
||||
page: int = Query(1, gt=0, description="页码,从1开始"),
|
||||
pagesize: int = Query(15, gt=0, le=100, description="每页返回记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取当前宿主知识库中的chunk内容
|
||||
获取当前宿主知识库中的chunk内容(分页)
|
||||
"""
|
||||
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
|
||||
data = memory_dashboard_service.get_rag_content(end_user_id, page, pagesize, db, current_user)
|
||||
return success(data=data, msg="宿主RAGchunk数据获取成功")
|
||||
|
||||
|
||||
@@ -418,26 +424,18 @@ async def get_chunk_summary_tag(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取chunk总结、提取的标签和人物形象
|
||||
|
||||
读取RAG摘要、标签和人物形象(纯读库,不触发生成)。
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"summary": "chunk内容的总结",
|
||||
"tags": [
|
||||
{"tag": "标签1", "frequency": 5},
|
||||
{"tag": "标签2", "frequency": 3},
|
||||
...
|
||||
],
|
||||
"personas": [
|
||||
"产品设计师",
|
||||
"旅行爱好者",
|
||||
"摄影发烧友",
|
||||
...
|
||||
]
|
||||
"summary": "用户摘要",
|
||||
"tags": [{"tag": "标签1", "frequency": 5}, ...],
|
||||
"personas": ["产品设计师", ...],
|
||||
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk摘要、标签和人物形象")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG摘要/标签/人物形象")
|
||||
|
||||
data = await memory_dashboard_service.get_chunk_summary_and_tags(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -445,9 +443,8 @@ async def get_chunk_summary_tag(
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
|
||||
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
|
||||
|
||||
return success(data=data, msg="获取成功")
|
||||
|
||||
|
||||
@router.get("/chunk_insight", response_model=ApiResponse)
|
||||
@@ -458,24 +455,57 @@ async def get_chunk_insight(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取chunk的洞察内容
|
||||
|
||||
读取RAG洞察报告(纯读库,不触发生成)。
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"insight": "对chunk内容的深度洞察分析"
|
||||
"insight": "总体概述",
|
||||
"behavior_pattern": "行为模式",
|
||||
"key_findings": "关键发现",
|
||||
"growth_trajectory": "成长轨迹",
|
||||
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk洞察")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG洞察")
|
||||
|
||||
data = await memory_dashboard_service.get_chunk_insight(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info("成功获取chunk洞察")
|
||||
return success(data=data, msg="chunk洞察获取成功")
|
||||
|
||||
return success(data=data, msg="获取成功")
|
||||
|
||||
|
||||
class GenerateRagProfileRequest(BaseModel):
|
||||
end_user_id: str = Field(..., description="宿主ID")
|
||||
limit: int = Field(15, description="参与生成的chunk数量上限")
|
||||
max_tags: int = Field(10, description="最大标签数量")
|
||||
|
||||
|
||||
@router.post("/generate_rag_profile", response_model=ApiResponse)
|
||||
async def generate_rag_profile(
|
||||
body: GenerateRagProfileRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
生产接口:为RAG存储模式的宿主全量重新生成完整画像并持久化到end_user表。
|
||||
每次请求都会重新生成,覆盖已有数据。
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 触发RAG画像生产: end_user_id={body.end_user_id}")
|
||||
|
||||
data = await memory_dashboard_service.generate_rag_profile(
|
||||
end_user_id=body.end_user_id,
|
||||
limit=body.limit,
|
||||
max_tags=body.max_tags,
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
api_logger.info(f"RAG画像生产完成: {data}")
|
||||
return success(data=data, msg="RAG画像生产完成")
|
||||
|
||||
|
||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.models import User
|
||||
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
||||
from app.services.tool_service import ToolService
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.exceptions import BusinessException
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
||||
|
||||
@@ -103,7 +104,7 @@ async def create_tool(
|
||||
val = getattr(request, key, None)
|
||||
if val is not None:
|
||||
request.config[key] = val
|
||||
tool_id = service.create_tool(
|
||||
tool_id = await service.create_tool(
|
||||
name=request.name,
|
||||
tool_type=request.tool_type,
|
||||
tenant_id=current_user.tenant_id,
|
||||
@@ -113,6 +114,8 @@ async def create_tool(
|
||||
tags=request.tags
|
||||
)
|
||||
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
||||
except BusinessException as e:
|
||||
raise HTTPException(status_code=400, detail=e.message)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field, TypeAdapter
|
||||
@@ -115,6 +114,7 @@ class Settings:
|
||||
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
|
||||
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
|
||||
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
|
||||
S3_ENDPOINT_URL: str = os.getenv("S3_ENDPOINT_URL", "")
|
||||
|
||||
# VOLC ASR settings
|
||||
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
|
||||
|
||||
@@ -33,6 +33,7 @@ class DialogExtractionResponse(BaseModel):
|
||||
|
||||
- is_related:对话与场景的相关性判定。
|
||||
- times / ids / amounts / contacts / addresses / keywords:重要信息片段,用来在不相关对话中保留关键消息。
|
||||
- preserve_keywords:情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。
|
||||
"""
|
||||
is_related: bool = Field(...)
|
||||
times: List[str] = Field(default_factory=list)
|
||||
@@ -41,6 +42,7 @@ class DialogExtractionResponse(BaseModel):
|
||||
contacts: List[str] = Field(default_factory=list)
|
||||
addresses: List[str] = Field(default_factory=list)
|
||||
keywords: List[str] = Field(default_factory=list)
|
||||
preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留")
|
||||
|
||||
|
||||
class MessageImportanceResponse(BaseModel):
|
||||
@@ -86,26 +88,17 @@ class SemanticPruner:
|
||||
self._detailed_prune_logging = True # 是否启用详细日志
|
||||
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
|
||||
|
||||
# 加载场景特定配置(内置场景走专门规则,自定义场景 fallback 到通用规则)
|
||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
|
||||
self.config.pruning_scene,
|
||||
fallback_to_generic=True
|
||||
)
|
||||
# 加载统一填充词库
|
||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene)
|
||||
|
||||
# 判断是否为内置专门场景
|
||||
self._is_builtin_scene = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
|
||||
|
||||
# 自定义场景的本体类型列表(用于注入提示词)
|
||||
# 本体类型列表(用于注入提示词,所有场景均支持)
|
||||
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
|
||||
|
||||
if self._is_builtin_scene:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用内置专门配置")
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}")
|
||||
if self._ontology_classes:
|
||||
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 为自定义场景,使用通用规则 + 本体类型提示词注入")
|
||||
if self._ontology_classes:
|
||||
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
||||
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
||||
|
||||
# Load Jinja2 template
|
||||
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
||||
@@ -117,107 +110,27 @@ class SemanticPruner:
|
||||
# 运行日志:收集关键终端输出,便于写入 JSON
|
||||
self.run_logs: List[str] = []
|
||||
|
||||
def _is_important_message(self, message: ConversationMessage) -> bool:
|
||||
"""基于启发式规则识别重要信息消息,优先保留。
|
||||
|
||||
改进版:使用场景特定的模式进行识别
|
||||
- 根据 pruning_scene 动态加载对应的识别规则
|
||||
- 支持教育、在线服务、外呼三个场景的特定模式
|
||||
"""
|
||||
text = message.msg.strip()
|
||||
if not text:
|
||||
return False
|
||||
|
||||
# 使用场景特定的模式
|
||||
all_patterns = (
|
||||
self.scene_config.high_priority_patterns +
|
||||
self.scene_config.medium_priority_patterns +
|
||||
self.scene_config.low_priority_patterns
|
||||
)
|
||||
|
||||
for pattern, _ in all_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
return True
|
||||
|
||||
# 检查是否为问句(以问号结尾或包含疑问词)
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
return True
|
||||
|
||||
# 检查是否包含问句关键词
|
||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
||||
return True
|
||||
|
||||
# 检查是否包含决策性关键词
|
||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _importance_score(self, message: ConversationMessage) -> int:
|
||||
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
|
||||
|
||||
改进版:使用场景特定的权重体系(0-10分)
|
||||
- 根据场景动态调整不同信息类型的权重
|
||||
- 高优先级模式:4-6分
|
||||
- 中优先级模式:2-3分
|
||||
- 低优先级模式:1分
|
||||
"""
|
||||
text = message.msg.strip()
|
||||
score = 0
|
||||
|
||||
# 使用场景特定的权重
|
||||
for pattern, weight in self.scene_config.high_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
for pattern, weight in self.scene_config.medium_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
for pattern, weight in self.scene_config.low_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
# 问句加分
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
score += 2
|
||||
|
||||
# 包含问句关键词加分
|
||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
||||
score += 1
|
||||
|
||||
# 包含决策性关键词加分
|
||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
||||
score += 2
|
||||
|
||||
# 长度加分(较长的消息通常包含更多信息)
|
||||
if len(text) > 50:
|
||||
score += 1
|
||||
if len(text) > 100:
|
||||
score += 1
|
||||
|
||||
return min(score, 10) # 最高10分
|
||||
# _is_important_message 和 _importance_score 已移除:
|
||||
# 重要性判断完全由 extracat_Pruning.jinja2 提示词 + LLM 的 preserve_tokens 机制承担。
|
||||
# LLM 根据注入的本体工程类型语义识别需要保护的内容,无需硬编码正则规则。
|
||||
|
||||
def _is_filler_message(self, message: ConversationMessage) -> bool:
|
||||
"""检测典型寒暄/口头禅/确认类短消息。
|
||||
|
||||
改进版:更严格的填充消息判断,避免误删场景相关内容
|
||||
满足以下之一视为填充消息:
|
||||
- 纯标点或空白
|
||||
- 在场景特定填充词库中(精确匹配)
|
||||
- 纯表情符号
|
||||
- 常见寒暄(精确匹配短语)
|
||||
|
||||
注意:不再使用长度判断,避免误删短但重要的消息
|
||||
判断顺序:
|
||||
1. 空消息
|
||||
2. 场景特定填充词库精确匹配
|
||||
3. 常见寒暄精确匹配
|
||||
4. 纯表情/标点
|
||||
"""
|
||||
t = message.msg.strip()
|
||||
if not t:
|
||||
return True
|
||||
|
||||
|
||||
# 检查是否在场景特定填充词库中(精确匹配)
|
||||
if t in self.scene_config.filler_phrases:
|
||||
return True
|
||||
|
||||
|
||||
# 常见寒暄和问候(精确匹配,避免误删)
|
||||
common_greetings = {
|
||||
"在吗", "在不在", "在呢", "在的",
|
||||
@@ -229,25 +142,11 @@ class SemanticPruner:
|
||||
}
|
||||
if t in common_greetings:
|
||||
return True
|
||||
|
||||
|
||||
# 检查是否为纯表情符号(方括号包裹)
|
||||
if re.fullmatch(r"(\[[^\]]+\])+", t):
|
||||
return True
|
||||
|
||||
# 检查是否为纯emoji(Unicode表情)
|
||||
emoji_pattern = re.compile(
|
||||
"["
|
||||
"\U0001F600-\U0001F64F" # 表情符号
|
||||
"\U0001F300-\U0001F5FF" # 符号和象形文字
|
||||
"\U0001F680-\U0001F6FF" # 交通和地图符号
|
||||
"\U0001F1E0-\U0001F1FF" # 旗帜
|
||||
"\U00002702-\U000027B0"
|
||||
"\U000024C2-\U0001F251"
|
||||
"]+", flags=re.UNICODE
|
||||
)
|
||||
if emoji_pattern.fullmatch(t):
|
||||
return True
|
||||
|
||||
# 纯标点符号
|
||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
|
||||
return True
|
||||
@@ -432,14 +331,12 @@ class SemanticPruner:
|
||||
|
||||
rendered = self.template.render(
|
||||
pruning_scene=self.config.pruning_scene,
|
||||
is_builtin_scene=self._is_builtin_scene,
|
||||
ontology_classes=self._ontology_classes,
|
||||
dialog_text=dialog_text,
|
||||
language=self.language
|
||||
)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {
|
||||
"pruning_scene": self.config.pruning_scene,
|
||||
"is_builtin_scene": self._is_builtin_scene,
|
||||
"ontology_classes_count": len(self._ontology_classes),
|
||||
"language": self.language
|
||||
})
|
||||
@@ -504,62 +401,56 @@ class SemanticPruner:
|
||||
# 相关对话不剪枝
|
||||
return dialog
|
||||
|
||||
# 在不相关对话中,识别重要/不重要消息
|
||||
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
|
||||
# 在不相关对话中,LLM 已通过 preserve_tokens 标记需要保护的内容
|
||||
preserve_tokens = (
|
||||
extraction.times + extraction.ids + extraction.amounts +
|
||||
extraction.contacts + extraction.addresses + extraction.keywords +
|
||||
extraction.preserve_keywords
|
||||
)
|
||||
msgs = dialog.context.msgs
|
||||
imp_unrel_msgs: List[ConversationMessage] = []
|
||||
unimp_unrel_msgs: List[ConversationMessage] = []
|
||||
|
||||
# 分类:填充 / 其他可删(LLM保护消息通过不加入任何桶来隐式保护)
|
||||
filler_ids: set = set()
|
||||
deletable: List[ConversationMessage] = []
|
||||
|
||||
for m in msgs:
|
||||
if self._msg_matches_tokens(m, tokens) or self._is_important_message(m):
|
||||
imp_unrel_msgs.append(m)
|
||||
if self._msg_matches_tokens(m, preserve_tokens):
|
||||
pass # 保护消息:不加入任何桶,不会被删除
|
||||
elif self._is_filler_message(m):
|
||||
filler_ids.add(id(m))
|
||||
else:
|
||||
unimp_unrel_msgs.append(m)
|
||||
# 计算总删除目标数量
|
||||
deletable.append(m)
|
||||
|
||||
# 计算删除目标
|
||||
total_unrel = len(msgs)
|
||||
delete_target = int(total_unrel * proportion)
|
||||
if proportion > 0 and total_unrel > 0 and delete_target == 0:
|
||||
delete_target = 1
|
||||
imp_del_cap = min(int(len(imp_unrel_msgs) * proportion), len(imp_unrel_msgs))
|
||||
unimp_del_cap = len(unimp_unrel_msgs)
|
||||
max_capacity = max(0, len(msgs) - 1)
|
||||
max_deletable = min(imp_del_cap + unimp_del_cap, max_capacity)
|
||||
max_deletable = min(len(filler_ids) + len(deletable), max(0, total_unrel - 1))
|
||||
delete_target = min(delete_target, max_deletable)
|
||||
# 删除配额分配
|
||||
del_unimp = min(delete_target, unimp_del_cap)
|
||||
rem = delete_target - del_unimp
|
||||
del_imp = min(rem, imp_del_cap)
|
||||
|
||||
# 选取删除集合
|
||||
unimp_delete_ids = []
|
||||
imp_delete_ids = []
|
||||
if del_unimp > 0:
|
||||
# 按出现顺序选取前 del_unimp 条不重要消息进行删除(确定性、可复现)
|
||||
unimp_delete_ids = [id(m) for m in unimp_unrel_msgs[:del_unimp]]
|
||||
if del_imp > 0:
|
||||
imp_sorted = sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))
|
||||
imp_delete_ids = [id(m) for m in imp_sorted[:del_imp]]
|
||||
|
||||
# 统计实际删除数量(重要/不重要)
|
||||
actual_unimp_deleted = 0
|
||||
actual_imp_deleted = 0
|
||||
kept_msgs = []
|
||||
delete_targets = set(unimp_delete_ids) | set(imp_delete_ids)
|
||||
# 优先删填充,再删其他可删消息(按出现顺序)
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
mid = id(m)
|
||||
if mid in delete_targets:
|
||||
if mid in set(unimp_delete_ids) and actual_unimp_deleted < del_unimp:
|
||||
actual_unimp_deleted += 1
|
||||
continue
|
||||
if mid in set(imp_delete_ids) and actual_imp_deleted < del_imp:
|
||||
actual_imp_deleted += 1
|
||||
continue
|
||||
kept_msgs.append(m)
|
||||
if len(to_delete_ids) >= delete_target:
|
||||
break
|
||||
if id(m) in filler_ids:
|
||||
to_delete_ids.add(id(m))
|
||||
for m in deletable:
|
||||
if len(to_delete_ids) >= delete_target:
|
||||
break
|
||||
to_delete_ids.add(id(m))
|
||||
|
||||
kept_msgs = [m for m in msgs if id(m) not in to_delete_ids]
|
||||
if not kept_msgs and msgs:
|
||||
kept_msgs = [msgs[0]]
|
||||
|
||||
deleted_total = actual_unimp_deleted + actual_imp_deleted
|
||||
deleted_total = len(msgs) - len(kept_msgs)
|
||||
protected_count = len(msgs) - len(filler_ids) - len(deletable)
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} 删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
|
||||
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} "
|
||||
f"(保护={protected_count} 填充={len(filler_ids)} 可删={len(deletable)}) "
|
||||
f"删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
|
||||
)
|
||||
|
||||
dialog.context = ConversationContext(msgs=kept_msgs)
|
||||
@@ -594,51 +485,64 @@ class SemanticPruner:
|
||||
result: List[DialogData] = []
|
||||
total_original_msgs = 0
|
||||
total_deleted_msgs = 0
|
||||
|
||||
for d_idx, dd in enumerate(dialogs):
|
||||
|
||||
# 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息)
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
async def extract_with_semaphore(dd: DialogData) -> DialogExtractionResponse:
|
||||
async with semaphore:
|
||||
try:
|
||||
return await self._extract_dialog_important(dd.content)
|
||||
except Exception as e:
|
||||
self._log(f"[剪枝-LLM] 对话抽取失败,使用降级策略: {str(e)[:100]}")
|
||||
return DialogExtractionResponse(is_related=True)
|
||||
|
||||
extraction_tasks = [extract_with_semaphore(dd) for dd in dialogs]
|
||||
extraction_results: List[DialogExtractionResponse] = await asyncio.gather(*extraction_tasks)
|
||||
|
||||
for d_idx, (dd, extraction) in enumerate(zip(dialogs, extraction_results)):
|
||||
msgs = dd.context.msgs
|
||||
original_count = len(msgs)
|
||||
total_original_msgs += original_count
|
||||
|
||||
# ========== 问答对保护(已注释,暂不启用,留作观察) ==========
|
||||
# qa_pairs = self._identify_qa_pairs(msgs)
|
||||
# protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0)
|
||||
# ========================================================
|
||||
|
||||
# 消息级分类:每条消息独立判断
|
||||
important_msgs = [] # 重要消息(保留)
|
||||
unimportant_msgs = [] # 不重要消息(可删除)
|
||||
filler_msgs = [] # 填充消息(优先删除)
|
||||
|
||||
# 判断是否需要详细日志(仅对前N条消息记录)
|
||||
|
||||
# 从 LLM 抽取结果中获取所有需要保留的 token
|
||||
preserve_tokens = (
|
||||
extraction.times + extraction.ids + extraction.amounts +
|
||||
extraction.contacts + extraction.addresses + extraction.keywords +
|
||||
extraction.preserve_keywords # 情绪/兴趣/爱好关键词
|
||||
)
|
||||
|
||||
# 判断是否需要详细日志
|
||||
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
|
||||
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
|
||||
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
|
||||
|
||||
|
||||
if extraction.preserve_keywords:
|
||||
self._log(f" 对话[{d_idx}] LLM抽取到情绪/兴趣保护词: {extraction.preserve_keywords}")
|
||||
|
||||
# 消息级分类:LLM保护 / 填充 / 其他可删
|
||||
llm_protected_msgs = [] # LLM 保护消息(preserve_tokens 命中):绝对不可删除
|
||||
filler_msgs = [] # 填充消息(优先删除)
|
||||
deletable_msgs = [] # 其余消息(按比例删除)
|
||||
|
||||
for idx, m in enumerate(msgs):
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# ========== 问答对保护判断(已注释) ==========
|
||||
# if idx in protected_indices:
|
||||
# important_msgs.append((idx, m))
|
||||
# self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)")
|
||||
# ==========================================
|
||||
|
||||
# 填充消息(寒暄、表情等)
|
||||
if self._is_filler_message(m):
|
||||
|
||||
if self._msg_matches_tokens(m, preserve_tokens):
|
||||
llm_protected_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 保护(LLM,不可删)")
|
||||
elif self._is_filler_message(m):
|
||||
filler_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
|
||||
# 重要信息(学号、成绩、时间、金额等)
|
||||
elif self._is_important_message(m):
|
||||
important_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)")
|
||||
# 其他消息
|
||||
else:
|
||||
unimportant_msgs.append((idx, m))
|
||||
deletable_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要")
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 可删")
|
||||
|
||||
# important_msgs 仅用于日志统计
|
||||
important_msgs = llm_protected_msgs
|
||||
|
||||
# 计算删除配额
|
||||
delete_target = int(original_count * proportion)
|
||||
@@ -649,37 +553,23 @@ class SemanticPruner:
|
||||
max_deletable = max(0, original_count - 1)
|
||||
delete_target = min(delete_target, max_deletable)
|
||||
|
||||
# 删除策略:优先删除填充消息,再删除不重要消息
|
||||
# 删除策略:优先删填充消息,再按出现顺序删其余可删消息
|
||||
to_delete_indices = set()
|
||||
deleted_details = [] # 记录删除的消息详情
|
||||
|
||||
deleted_details = []
|
||||
|
||||
# 第一步:删除填充消息
|
||||
filler_to_delete = min(len(filler_msgs), delete_target)
|
||||
for i in range(filler_to_delete):
|
||||
idx, msg = filler_msgs[i]
|
||||
for idx, msg in filler_msgs:
|
||||
if len(to_delete_indices) >= delete_target:
|
||||
break
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
|
||||
|
||||
# 第二步:如果还需要删除,删除不重要消息
|
||||
remaining_quota = delete_target - len(to_delete_indices)
|
||||
if remaining_quota > 0:
|
||||
unimp_to_delete = min(len(unimportant_msgs), remaining_quota)
|
||||
for i in range(unimp_to_delete):
|
||||
idx, msg = unimportant_msgs[i]
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'")
|
||||
|
||||
# 第三步:如果还需要删除,按重要性分数删除重要消息
|
||||
remaining_quota = delete_target - len(to_delete_indices)
|
||||
if remaining_quota > 0 and important_msgs:
|
||||
# 按重要性分数排序(分数低的优先删除)
|
||||
imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1]))
|
||||
imp_to_delete = min(len(imp_sorted), remaining_quota)
|
||||
for i in range(imp_to_delete):
|
||||
idx, msg = imp_sorted[i]
|
||||
to_delete_indices.add(idx)
|
||||
score = self._importance_score(msg)
|
||||
deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'")
|
||||
|
||||
# 第二步:如果还需要删除,按出现顺序删可删消息
|
||||
for idx, msg in deletable_msgs:
|
||||
if len(to_delete_indices) >= delete_target:
|
||||
break
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 可删: '{msg.msg[:50]}'")
|
||||
|
||||
# 执行删除
|
||||
kept_msgs = []
|
||||
@@ -707,7 +597,7 @@ class SemanticPruner:
|
||||
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
|
||||
f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) "
|
||||
f"(保护={len(important_msgs)} 填充={len(filler_msgs)} 可删={len(deletable_msgs)}) "
|
||||
f"删除={deleted_count} 保留={len(kept_msgs)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,66 +1,25 @@
|
||||
"""
|
||||
场景特定配置 - 为不同场景提供定制化的剪枝规则
|
||||
场景特定配置 - 统一填充词库
|
||||
|
||||
功能:
|
||||
- 场景特定的重要信息识别模式
|
||||
- 场景特定的重要性评分权重
|
||||
- 场景特定的填充词库
|
||||
- 场景特定的问答对识别规则
|
||||
重要性判断已完全交由 extracat_Pruning.jinja2 提示词 + LLM preserve_tokens 机制承担。
|
||||
本模块仅保留统一填充词库(filler_phrases),用于识别无意义寒暄/表情/口头禅。
|
||||
所有场景共用同一份词库,场景差异由 LLM 语义判断处理。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from typing import List, Set
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenePatterns:
|
||||
"""场景特定的识别模式"""
|
||||
|
||||
# 重要信息的正则模式(优先级从高到低)
|
||||
high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight)
|
||||
medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
||||
low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
||||
|
||||
# 填充词库(无意义对话)
|
||||
"""场景特定的识别模式(仅保留填充词库)"""
|
||||
filler_phrases: Set[str] = field(default_factory=set)
|
||||
|
||||
# 问句关键词(用于识别问答对)
|
||||
question_keywords: Set[str] = field(default_factory=set)
|
||||
|
||||
# 决策性/承诺性关键词
|
||||
decision_keywords: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class SceneConfigRegistry:
|
||||
"""场景配置注册表 - 管理所有场景的特定配置"""
|
||||
|
||||
# 基础通用模式(所有场景共享)
|
||||
BASE_HIGH_PRIORITY = [
|
||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 5),
|
||||
(r"金额|费用|价格|¥|¥|\d+元", 5),
|
||||
(r"\d{11}", 4), # 手机号
|
||||
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱
|
||||
]
|
||||
|
||||
BASE_MEDIUM_PRIORITY = [
|
||||
(r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期
|
||||
(r"\d{4}年\d{1,2}月\d{1,2}日", 3),
|
||||
(r"电话|手机号|微信|QQ|联系方式", 3),
|
||||
(r"地址|地点|位置", 2),
|
||||
(r"时间|日期|有效期|截止", 2),
|
||||
(r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重)
|
||||
(r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3),
|
||||
(r"今年|去年|明年", 3),
|
||||
]
|
||||
|
||||
BASE_LOW_PRIORITY = [
|
||||
(r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM
|
||||
(r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点
|
||||
(r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充)
|
||||
(r"AM|PM|am|pm", 1),
|
||||
]
|
||||
|
||||
BASE_FILLERS = {
|
||||
"""场景配置注册表 - 所有场景共用统一填充词库"""
|
||||
|
||||
BASE_FILLERS: Set[str] = {
|
||||
# 基础寒暄
|
||||
"你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦",
|
||||
"好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢",
|
||||
@@ -69,7 +28,26 @@ class SceneConfigRegistry:
|
||||
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
|
||||
"额", "呃", "啊", "诶", "唉", "哎", "嗯哼",
|
||||
# 确认词
|
||||
"是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了",
|
||||
"是的", "对", "对的", "没错", "好嘞", "收到", "明白", "了解", "知道了",
|
||||
# 服务类套话
|
||||
"请问", "请稍等", "稍等", "马上", "立即",
|
||||
"正在查询", "正在处理", "正在为您", "帮您查一下",
|
||||
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
|
||||
"感谢您的耐心等待", "抱歉让您久等了",
|
||||
"已记录", "已反馈", "已转接", "已升级",
|
||||
"祝您生活愉快", "欢迎下次咨询",
|
||||
# 外呼套话
|
||||
"喂", "hello", "打扰了", "不好意思",
|
||||
"方便接电话吗", "现在方便吗", "占用您一点时间",
|
||||
"我是", "我们是", "我们公司", "我们这边",
|
||||
"了解一下", "介绍一下", "简单说一下",
|
||||
"考虑考虑", "想一想", "再说", "再看看",
|
||||
"不需要", "不感兴趣", "没兴趣", "不用了",
|
||||
"没问题", "那就这样", "再联系", "回头聊", "有需要再说",
|
||||
# 教育场景套话
|
||||
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
|
||||
"举手", "请坐", "很好", "不错", "继续",
|
||||
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
|
||||
# 标点和符号
|
||||
"。。。", "...", "???", "???", "!!!", "!!!",
|
||||
# 表情符号
|
||||
@@ -81,246 +59,8 @@ class SceneConfigRegistry:
|
||||
"hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
|
||||
"emmm", "emm", "em", "mmp", "wtf", "omg",
|
||||
}
|
||||
|
||||
BASE_QUESTION_KEYWORDS = {
|
||||
"什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗"
|
||||
}
|
||||
|
||||
BASE_DECISION_KEYWORDS = {
|
||||
"必须", "一定", "务必", "需要", "要求", "规定", "应该",
|
||||
"承诺", "保证", "确保", "负责", "同意", "答应"
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_education_config(cls) -> ScenePatterns:
|
||||
"""教育场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 成绩相关(最高优先级)
|
||||
(r"成绩|分数|得分|满分|及格|不及格", 6),
|
||||
(r"GPA|绩点|学分|平均分", 6),
|
||||
(r"\d+分|\d+\.?\d*分", 5), # 具体分数
|
||||
(r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等
|
||||
|
||||
# 学籍信息
|
||||
(r"学号|学生证|教师工号|工号", 5),
|
||||
(r"班级|年级|专业|院系", 4),
|
||||
|
||||
# 课程相关
|
||||
(r"课程|科目|学科|必修|选修", 4),
|
||||
(r"教材|课本|教科书|参考书", 4),
|
||||
(r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等
|
||||
|
||||
# 学科内容(新增)
|
||||
(r"微积分|导数|积分|函数|极限|微分", 4),
|
||||
(r"代数|几何|三角|概率|统计", 4),
|
||||
(r"物理|化学|生物|历史|地理", 4),
|
||||
(r"英语|语文|数学|政治|哲学", 4),
|
||||
(r"定义|定理|公式|概念|原理|法则", 3),
|
||||
(r"例题|解题|证明|推导|计算", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 教学活动
|
||||
(r"作业|练习|习题|题目", 3),
|
||||
(r"考试|测验|测试|考核|期中|期末", 3),
|
||||
(r"上课|下课|课堂|讲课", 2),
|
||||
(r"提问|回答|发言|讨论", 2),
|
||||
(r"问一下|请教|咨询|询问", 2), # 新增:问询相关
|
||||
(r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态
|
||||
|
||||
# 时间安排
|
||||
(r"课表|课程表|时间表", 3),
|
||||
(r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"老师|教师|同学|学生", 1),
|
||||
(r"教室|实验室|图书馆", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义)
|
||||
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
|
||||
"举手", "请坐", "很好", "不错", "继续",
|
||||
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"为啥", "咋", "咋办", "怎样", "如何做",
|
||||
"能不能", "可不可以", "行不行", "对不对", "是不是",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"必考", "重点", "考点", "难点", "关键",
|
||||
"记住", "背诵", "掌握", "理解", "复习",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_online_service_config(cls) -> ScenePatterns:
|
||||
"""在线服务场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 工单相关(最高优先级)
|
||||
(r"工单号|工单编号|ticket|TK\d+", 6),
|
||||
(r"工单状态|处理中|已解决|已关闭|待处理", 5),
|
||||
(r"优先级|紧急|高优先级|P0|P1|P2", 5),
|
||||
|
||||
# 产品信息
|
||||
(r"产品型号|型号|SKU|产品编号", 5),
|
||||
(r"序列号|SN|设备号", 5),
|
||||
(r"版本号|软件版本|固件版本", 4),
|
||||
|
||||
# 问题描述
|
||||
(r"故障|错误|异常|bug|问题", 4),
|
||||
(r"错误代码|故障代码|error code", 5),
|
||||
(r"无法|不能|失败|报错", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 服务相关
|
||||
(r"退款|退货|换货|补发", 4),
|
||||
(r"发票|收据|凭证", 3),
|
||||
(r"物流|快递|运单号", 3),
|
||||
(r"保修|质保|售后", 3),
|
||||
|
||||
# 时效相关
|
||||
(r"SLA|响应时间|处理时长", 4),
|
||||
(r"超时|延迟|等待", 2),
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"客服|工程师|技术支持", 1),
|
||||
(r"用户|客户|会员", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 在线服务特有填充词
|
||||
"您好", "请问", "请稍等", "稍等", "马上", "立即",
|
||||
"正在查询", "正在处理", "正在为您", "帮您查一下",
|
||||
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
|
||||
"感谢您的耐心等待", "抱歉让您久等了",
|
||||
"已记录", "已反馈", "已转接", "已升级",
|
||||
"祝您生活愉快", "再见", "欢迎下次咨询",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"能否", "可否", "是否", "有没有", "能不能",
|
||||
"怎么办", "如何处理", "怎么解决",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"立即处理", "马上解决", "尽快", "优先",
|
||||
"升级", "转接", "派单", "跟进",
|
||||
"补偿", "赔偿", "退款", "换货",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_outbound_config(cls) -> ScenePatterns:
|
||||
"""外呼场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 意向相关(最高优先级)
|
||||
(r"意向|意愿|兴趣|感兴趣", 6),
|
||||
(r"A类|B类|C类|D类|高意向|低意向", 6),
|
||||
(r"成交|签约|下单|购买|确认", 6),
|
||||
|
||||
# 联系信息(外呼场景中更重要)
|
||||
(r"预约|约定|安排|确定时间", 5),
|
||||
(r"下次联系|回访|跟进", 5),
|
||||
(r"方便|有空|可以|时间", 4),
|
||||
|
||||
# 通话状态
|
||||
(r"接通|未接通|占线|关机|停机", 4),
|
||||
(r"通话时长|通话时间", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 客户信息
|
||||
(r"姓名|称呼|先生|女士", 3),
|
||||
(r"公司|单位|职位|职务", 3),
|
||||
(r"需求|要求|期望", 3),
|
||||
|
||||
# 跟进状态
|
||||
(r"跟进状态|进展|进度", 3),
|
||||
(r"已联系|待联系|联系中", 2),
|
||||
(r"拒绝|不感兴趣|考虑|再说", 3),
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"销售|客户经理|业务员", 1),
|
||||
(r"产品|服务|方案", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 外呼场景特有填充词
|
||||
"您好", "喂", "hello", "打扰了", "不好意思",
|
||||
"方便接电话吗", "现在方便吗", "占用您一点时间",
|
||||
"我是", "我们是", "我们公司", "我们这边",
|
||||
"了解一下", "介绍一下", "简单说一下",
|
||||
"考虑考虑", "想一想", "再说", "再看看",
|
||||
"不需要", "不感兴趣", "没兴趣", "不用了",
|
||||
"好的", "行", "可以", "没问题", "那就这样",
|
||||
"再联系", "回头聊", "有需要再说",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"有没有", "需不需要", "要不要", "考虑不考虑",
|
||||
"了解吗", "知道吗", "听说过吗",
|
||||
"方便吗", "有空吗", "在吗",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"确定", "决定", "选择", "购买", "下单",
|
||||
"预约", "安排", "约定", "确认",
|
||||
"跟进", "回访", "联系", "沟通",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns:
|
||||
"""根据场景名称获取配置
|
||||
|
||||
Args:
|
||||
scene: 场景名称 ('education', 'online_service', 'outbound' 或其他)
|
||||
fallback_to_generic: 如果场景不存在,是否降级到通用配置
|
||||
|
||||
Returns:
|
||||
对应场景的配置,如果场景不存在:
|
||||
- fallback_to_generic=True: 返回通用配置(仅基础规则)
|
||||
- fallback_to_generic=False: 抛出异常
|
||||
"""
|
||||
scene_map = {
|
||||
'education': cls.get_education_config,
|
||||
'online_service': cls.get_online_service_config,
|
||||
'outbound': cls.get_outbound_config,
|
||||
}
|
||||
|
||||
if scene in scene_map:
|
||||
return scene_map[scene]()
|
||||
|
||||
if fallback_to_generic:
|
||||
# 返回通用配置(仅包含基础规则,不包含场景特定规则)
|
||||
return cls.get_generic_config()
|
||||
else:
|
||||
raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}")
|
||||
|
||||
@classmethod
|
||||
def get_generic_config(cls) -> ScenePatterns:
|
||||
"""通用场景配置 - 仅包含基础规则,适用于未定义的场景
|
||||
|
||||
这是一个保守的配置,只使用最通用的规则,避免误删重要信息
|
||||
"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY,
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY,
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY,
|
||||
filler_phrases=cls.BASE_FILLERS,
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS,
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_all_scenes(cls) -> List[str]:
|
||||
"""获取所有预定义场景的列表"""
|
||||
return ['education', 'online_service', 'outbound']
|
||||
|
||||
@classmethod
|
||||
def is_scene_supported(cls, scene: str) -> bool:
|
||||
"""检查场景是否有专门的配置支持
|
||||
|
||||
Args:
|
||||
scene: 场景名称
|
||||
|
||||
Returns:
|
||||
True: 有专门配置
|
||||
False: 将使用通用配置
|
||||
"""
|
||||
return scene in cls.get_all_scenes()
|
||||
def get_config(cls, scene: str = "") -> ScenePatterns:
|
||||
"""所有场景统一返回同一份填充词库"""
|
||||
return ScenePatterns(filler_phrases=cls.BASE_FILLERS)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{#
|
||||
对话级抽取与相关性判定模板(用于剪枝加速)
|
||||
输入:pruning_scene, is_builtin_scene, ontology_classes, dialog_text, language
|
||||
输入:pruning_scene, ontology_classes, dialog_text, language
|
||||
输出:严格 JSON(不要包含任何多余文本),字段:
|
||||
- is_related: bool,是否与所选场景相关
|
||||
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
||||
@@ -9,64 +9,71 @@
|
||||
- contacts: [string],联系方式(电话/手机号/邮箱/微信/QQ等)
|
||||
- addresses: [string],地址/地点相关文本
|
||||
- keywords: [string],其它有助于保留的重要关键词(与场景强相关的术语)
|
||||
- preserve_keywords: [string],必须保留的情绪/兴趣/爱好/个人偏好相关词或短语片段
|
||||
|
||||
要求:
|
||||
- 必须只输出上述 JSON,且键名一致;不得输出解释、前后缀;不得包含注释。
|
||||
- times/ids/amounts/contacts/addresses/keywords 仅抽取原文片段或规范化后的简单字符串。
|
||||
- times/ids/amounts/contacts/addresses/keywords/preserve_keywords 仅抽取原文片段或规范化后的简单字符串。
|
||||
- 仅输出上述键;避免多余解释或字段。
|
||||
#}
|
||||
|
||||
{# ── 内置场景的固定说明 ── #}
|
||||
{% set builtin_scene_instructions = {
|
||||
'education': {
|
||||
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
||||
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
|
||||
},
|
||||
'online_service': {
|
||||
'zh': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
|
||||
'en': 'Online Service Scenario: Customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.'
|
||||
},
|
||||
'outbound': {
|
||||
'zh': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。',
|
||||
'en': 'Outbound Scenario: Outbound calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up records, etc.'
|
||||
}
|
||||
} %}
|
||||
|
||||
{# ── 确定最终使用的场景说明 ── #}
|
||||
{% if is_builtin_scene %}
|
||||
{# 内置专门场景:使用固定说明 #}
|
||||
{% set scene_key = pruning_scene %}
|
||||
{% if scene_key not in builtin_scene_instructions %}{% set scene_key = 'education' %}{% endif %}
|
||||
{% set instruction = builtin_scene_instructions[scene_key][language] if language in ['zh', 'en'] else builtin_scene_instructions[scene_key]['zh'] %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% else %}
|
||||
{# 自定义场景:使用场景名称 + 本体类型列表构建说明 #}
|
||||
{% if ontology_classes and ontology_classes | length > 0 %}
|
||||
{% if language == 'en' %}
|
||||
{% set custom_types_str = ontology_classes | join(', ') %}
|
||||
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
|
||||
{% else %}
|
||||
{% set custom_types_str = ontology_classes | join('、') %}
|
||||
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
|
||||
{% endif %}
|
||||
{# ── 确定场景说明 ── #}
|
||||
{% if ontology_classes and ontology_classes | length > 0 %}
|
||||
{% if language == 'en' %}
|
||||
{% set custom_types_str = ontology_classes | join(', ') %}
|
||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
|
||||
{% else %}
|
||||
{# 无本体类型时退化为通用说明 #}
|
||||
{% if language == 'en' %}
|
||||
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||
{% else %}
|
||||
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||
{% endif %}
|
||||
{% set custom_types_str = ontology_classes | join('、') %}
|
||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
|
||||
{% endif %}
|
||||
{% else %}
|
||||
{% if language == 'en' %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||
{% else %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
|
||||
你是一个对话内容分析助手。请对下方对话全文进行一次性分析,完成两项任务:
|
||||
1. 判断对话是否与指定场景相关;
|
||||
2. 从对话中抽取所有需要保留的重要信息片段。
|
||||
|
||||
场景说明:{{ instruction }}
|
||||
{% if not is_builtin_scene and custom_types_str %}
|
||||
{% if custom_types_str %}
|
||||
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }})相关的内容,即判定为相关(is_related=true)。
|
||||
{% endif %}
|
||||
|
||||
---
|
||||
【必须保留的内容(不可删除)】
|
||||
以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段:
|
||||
- 时间信息:日期、时间点、时间段、有效期 → times 字段
|
||||
- 编号信息:学号、工号、订单号、申请号、账号、ID → ids 字段
|
||||
- 金额信息:价格、费用、金额(含货币符号或单位) → amounts 字段
|
||||
- 联系方式:电话、手机号、邮箱、微信、QQ → contacts 字段
|
||||
- 地址信息:地点、地址、位置 → addresses 字段
|
||||
- 场景关键词:与场景强相关的专业术语、事件名称 → keywords 字段
|
||||
- **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段
|
||||
- **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段
|
||||
- **个人观点与态度**:对某事物的明确看法、评价、立场 → preserve_keywords 字段
|
||||
|
||||
【可以删除的内容】
|
||||
以下类型的内容属于低价值信息,可以在剪枝时删除:
|
||||
- 纯寒暄问候:如"你好"、"在吗"、"拜拜"、"嗯"、"好的"、"哦"等无实质内容的短语
|
||||
- 纯表情/符号:如"[微笑]"、"😊"、"哈哈"等
|
||||
- 重复确认:如"对对对"、"是的是的"、"嗯嗯嗯"等无新增信息的重复
|
||||
- 无意义填充:如"啊"、"呢"、"嘛"等语气词单独成句
|
||||
|
||||
**注意:即使消息很短,只要包含情绪、兴趣、爱好、个人观点等有价值信息,就必须保留,不得删除。**
|
||||
例如:
|
||||
- "我好开心呀" → 包含情绪(开心),必须保留,preserve_keywords 中加入"开心"
|
||||
- "好喜欢打羽毛球呀" → 包含兴趣爱好(喜欢打羽毛球),必须保留,preserve_keywords 中加入"喜欢打羽毛球"
|
||||
- "我好难过" → 包含情绪(难过),必须保留,preserve_keywords 中加入"难过"
|
||||
- "太好啦!看到你开心,我也跟着心情亮起来" → 包含情绪,必须保留,preserve_keywords 中加入"开心"
|
||||
|
||||
---
|
||||
对话全文:
|
||||
"""
|
||||
{{ dialog_text }}
|
||||
@@ -80,15 +87,46 @@
|
||||
"amounts": [<string>...],
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...]
|
||||
"keywords": [<string>...],
|
||||
"preserve_keywords": [<string>...]
|
||||
}
|
||||
{% else %}
|
||||
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
|
||||
You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks:
|
||||
1. Determine whether the dialogue is relevant to the specified scene;
|
||||
2. Extract all important information fragments that must be preserved.
|
||||
|
||||
Scenario Description: {{ instruction }}
|
||||
{% if not is_builtin_scene and custom_types_str %}
|
||||
{% if custom_types_str %}
|
||||
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
|
||||
{% endif %}
|
||||
|
||||
---
|
||||
[MUST PRESERVE (cannot be deleted)]
|
||||
The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields:
|
||||
- Time information: dates, time points, durations, expiry dates → times field
|
||||
- ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field
|
||||
- Amount information: prices, fees, amounts (with currency symbols or units) → amounts field
|
||||
- Contact information: phone numbers, emails, WeChat, QQ → contacts field
|
||||
- Address information: locations, addresses, places → addresses field
|
||||
- Scene keywords: professional terms and event names strongly related to the scene → keywords field
|
||||
- **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, sadness, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field
|
||||
- **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field
|
||||
- **Personal opinions and attitudes**: clear views, evaluations, or stances on something → preserve_keywords field
|
||||
|
||||
[CAN BE DELETED]
|
||||
The following types of content are low-value and can be removed during pruning:
|
||||
- Pure greetings: e.g., "hello", "are you there", "bye", "ok", "yeah" — short phrases with no substantive content
|
||||
- Pure emojis/symbols: e.g., "[smile]", "😊", "haha"
|
||||
- Repetitive confirmations: e.g., "yes yes yes", "right right", "uh huh" — repetitions with no new information
|
||||
- Meaningless fillers: standalone interjections like "ah", "well", "hmm"
|
||||
|
||||
**Note: Even if a message is short, if it contains emotions, interests, hobbies, or personal opinions, it MUST be preserved.**
|
||||
Examples:
|
||||
- "I'm so happy!" → contains emotion (happy), must preserve; add "happy" to preserve_keywords
|
||||
- "I love playing badminton!" → contains interest (love playing badminton), must preserve; add "love playing badminton" to preserve_keywords
|
||||
- "I feel so sad" → contains emotion (sad), must preserve; add "sad" to preserve_keywords
|
||||
|
||||
---
|
||||
Full Dialogue:
|
||||
"""
|
||||
{{ dialog_text }}
|
||||
@@ -102,6 +140,7 @@ Output strict JSON only (fixed keys, order doesn't matter):
|
||||
"amounts": [<string>...],
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...]
|
||||
"keywords": [<string>...],
|
||||
"preserve_keywords": [<string>...]
|
||||
}
|
||||
{% endif %}
|
||||
|
||||
@@ -4,11 +4,12 @@ RAG chunk analysis utilities.
|
||||
|
||||
from .chunk_summary import generate_chunk_summary
|
||||
from .chunk_tags import extract_chunk_tags, extract_chunk_persona
|
||||
from .chunk_insight import generate_chunk_insight
|
||||
from .chunk_insight import generate_chunk_insight, generate_chunk_insight_sections
|
||||
|
||||
__all__ = [
|
||||
"generate_chunk_summary",
|
||||
"extract_chunk_tags",
|
||||
"extract_chunk_persona",
|
||||
"generate_chunk_insight",
|
||||
"generate_chunk_insight_sections",
|
||||
]
|
||||
|
||||
@@ -1,213 +1,207 @@
|
||||
"""
|
||||
Generate insights from RAG chunks.
|
||||
Generate memory insight report for RAG chunks using memory_insight.jinja2 prompt template.
|
||||
|
||||
This module provides functionality to analyze chunk content and generate insights using LLM.
|
||||
The memory_insight.jinja2 template produces a four-section report:
|
||||
【总体概述】 → memory_insight
|
||||
【行为模式】 → behavior_pattern
|
||||
【关键发现】 → key_findings
|
||||
【成长轨迹】 → growth_trajectory
|
||||
|
||||
generate_chunk_insight() returns the full raw text (stored in end_user.memory_insight).
|
||||
generate_chunk_insight_sections() returns a dict with all four fields for richer storage.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
|
||||
# ── LLM client helper ────────────────────────────────────────────────────────
|
||||
|
||||
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
if end_user_id:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
if config_id or workspace_id:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(memory_config.llm_model_id)
|
||||
except Exception as e:
|
||||
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
|
||||
class ChunkInsight(BaseModel):
|
||||
"""Pydantic model for chunk insight."""
|
||||
insight: str = Field(..., description="对chunk内容的深度洞察分析")
|
||||
# ── Domain analysis helpers (kept for building prompt inputs) ─────────────────
|
||||
|
||||
async def _classify_domain(chunk: str, llm_client) -> str:
|
||||
"""Classify a single chunk into a domain category."""
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class DomainClassification(BaseModel):
|
||||
"""Pydantic model for domain classification."""
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="内容所属的领域分类",
|
||||
examples=["技术", "商业", "教育", "生活", "娱乐", "健康", "其他"]
|
||||
)
|
||||
class _Domain(BaseModel):
|
||||
domain: str = Field(..., description="领域分类")
|
||||
|
||||
|
||||
async def classify_chunk_domain(chunk: str) -> str:
|
||||
"""
|
||||
Classify a chunk into a specific domain.
|
||||
|
||||
Args:
|
||||
chunk: Chunk content string
|
||||
|
||||
Returns:
|
||||
Domain name
|
||||
"""
|
||||
try:
|
||||
llm_client = _get_llm_client()
|
||||
|
||||
prompt = f"""请将以下文本内容归类到最合适的领域中。
|
||||
|
||||
可选领域及其关键词:
|
||||
- 技术:编程、软件、硬件、算法、数据、网络、系统、开发、工程等
|
||||
- 商业:市场、销售、管理、财务、投资、创业、营销、战略等
|
||||
- 教育:学习、课程、培训、教学、知识、技能、考试、研究等
|
||||
- 生活:日常、家庭、饮食、购物、旅行、休闲、娱乐等
|
||||
- 娱乐:游戏、电影、音乐、体育、艺术、文化等
|
||||
- 健康:医疗、养生、运动、心理、保健、疾病等
|
||||
- 其他:无法归入以上类别的内容
|
||||
|
||||
文本内容: {chunk[:500]}...
|
||||
|
||||
请直接返回最合适的领域名称。"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的文本分类助手。请仔细分析文本内容,选择最合适的领域分类。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
classification = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=DomainClassification
|
||||
prompt = (
|
||||
"请将以下文本归类到最合适的领域(技术/商业/教育/生活/娱乐/健康/其他)。\n\n"
|
||||
f"文本: {chunk[:500]}\n\n直接返回领域名称。"
|
||||
)
|
||||
|
||||
return classification.domain if classification else "其他"
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"分类chunk领域失败: {str(e)}")
|
||||
result = await llm_client.response_structured(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_model=_Domain,
|
||||
)
|
||||
return result.domain if result else "其他"
|
||||
except Exception:
|
||||
return "其他"
|
||||
|
||||
|
||||
async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -> Dict[str, float]:
|
||||
async def _build_insight_inputs(
|
||||
chunks: List[str],
|
||||
max_chunks: int,
|
||||
end_user_id: Optional[str],
|
||||
) -> Dict[str, Optional[str]]:
|
||||
"""
|
||||
Analyze the domain distribution of chunks.
|
||||
|
||||
Args:
|
||||
chunks: List of chunk content strings
|
||||
max_chunks: Maximum number of chunks to analyze
|
||||
|
||||
Returns:
|
||||
Dictionary of domain -> percentage
|
||||
Derive domain_distribution, active_periods, social_connections strings
|
||||
to feed into the memory_insight.jinja2 template.
|
||||
"""
|
||||
if not chunks:
|
||||
return {}
|
||||
|
||||
try:
|
||||
# 限制分析的chunk数量
|
||||
chunks_to_analyze = chunks[:max_chunks]
|
||||
|
||||
# 为每个chunk分类
|
||||
domain_counts = Counter()
|
||||
for chunk in chunks_to_analyze:
|
||||
domain = await classify_chunk_domain(chunk)
|
||||
domain_counts[domain] += 1
|
||||
|
||||
# 计算百分比
|
||||
total = sum(domain_counts.values())
|
||||
domain_distribution = {
|
||||
domain: count / total
|
||||
for domain, count in domain_counts.items()
|
||||
}
|
||||
|
||||
# 按百分比降序排序
|
||||
return dict(sorted(domain_distribution.items(), key=lambda x: x[1], reverse=True))
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"分析领域分布失败: {str(e)}")
|
||||
return {}
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
chunks_sample = chunks[:max_chunks]
|
||||
|
||||
# Domain distribution
|
||||
domain_counts: Counter = Counter()
|
||||
for chunk in chunks_sample:
|
||||
domain = await _classify_domain(chunk, llm_client)
|
||||
domain_counts[domain] += 1
|
||||
|
||||
total = sum(domain_counts.values()) or 1
|
||||
domain_distribution = ", ".join(
|
||||
f"{d}({c / total:.0%})" for d, c in domain_counts.most_common(3)
|
||||
)
|
||||
|
||||
return {
|
||||
"domain_distribution": domain_distribution,
|
||||
"active_periods": None, # RAG模式暂无时间维度数据
|
||||
"social_connections": None, # RAG模式暂无社交关联数据
|
||||
}
|
||||
|
||||
|
||||
async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str:
|
||||
# ── Section parser ────────────────────────────────────────────────────────────
|
||||
|
||||
_ZH_SECTIONS = {
|
||||
"memory_insight": r"【总体概述】(.*?)(?=【|$)",
|
||||
"behavior_pattern": r"【行为模式】(.*?)(?=【|$)",
|
||||
"key_findings": r"【关键发现】(.*?)(?=【|$)",
|
||||
"growth_trajectory": r"【成长轨迹】(.*?)(?=【|$)",
|
||||
}
|
||||
|
||||
_EN_SECTIONS = {
|
||||
"memory_insight": r"【Overview】(.*?)(?=【|$)",
|
||||
"behavior_pattern": r"【Behavior Pattern】(.*?)(?=【|$)",
|
||||
"key_findings": r"【Key Findings】(.*?)(?=【|$)",
|
||||
"growth_trajectory": r"【Growth Trajectory】(.*?)(?=【|$)",
|
||||
}
|
||||
|
||||
|
||||
def _parse_sections(text: str, language: str = "zh") -> Dict[str, str]:
|
||||
"""Extract the four sections from the LLM output."""
|
||||
patterns = _ZH_SECTIONS if language == "zh" else _EN_SECTIONS
|
||||
result = {}
|
||||
for key, pattern in patterns.items():
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
result[key] = match.group(1).strip() if match else ""
|
||||
return result
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
async def generate_chunk_insight(
|
||||
chunks: List[str],
|
||||
max_chunks: int = 15,
|
||||
end_user_id: Optional[str] = None,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Generate insights from the given chunks.
|
||||
|
||||
Args:
|
||||
chunks: List of chunk content strings
|
||||
max_chunks: Maximum number of chunks to analyze
|
||||
|
||||
Returns:
|
||||
A comprehensive insight report
|
||||
Generate a memory insight report from RAG chunks.
|
||||
|
||||
Returns the full raw report text (suitable for end_user.memory_insight).
|
||||
Use generate_chunk_insight_sections() when you need all four dimensions.
|
||||
"""
|
||||
sections = await generate_chunk_insight_sections(
|
||||
chunks=chunks,
|
||||
max_chunks=max_chunks,
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
)
|
||||
return sections.get("memory_insight") or sections.get("_raw", "洞察生成失败")
|
||||
|
||||
|
||||
async def generate_chunk_insight_sections(
|
||||
chunks: List[str],
|
||||
max_chunks: int = 15,
|
||||
end_user_id: Optional[str] = None,
|
||||
language: str = "zh",
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Generate a four-section memory insight report from RAG chunks.
|
||||
|
||||
Returns a dict with keys:
|
||||
memory_insight, behavior_pattern, key_findings, growth_trajectory
|
||||
(plus '_raw' containing the full LLM output for debugging)
|
||||
"""
|
||||
if not chunks:
|
||||
business_logger.warning("没有提供chunk内容用于生成洞察")
|
||||
return "暂无足够数据生成洞察报告"
|
||||
|
||||
empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
|
||||
empty["_raw"] = "暂无足够数据生成洞察报告"
|
||||
return empty
|
||||
|
||||
try:
|
||||
# 1. 分析领域分布
|
||||
domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks)
|
||||
|
||||
# 2. 统计基本信息
|
||||
total_chunks = len(chunks)
|
||||
avg_length = sum(len(chunk) for chunk in chunks) / total_chunks if total_chunks > 0 else 0
|
||||
|
||||
# 3. 构建洞察prompt
|
||||
prompt_parts = []
|
||||
|
||||
if domain_dist:
|
||||
top_domains = ", ".join([f"{k}({v:.0%})" for k, v in list(domain_dist.items())[:3]])
|
||||
prompt_parts.append(f"- 内容领域分布: {top_domains}")
|
||||
|
||||
prompt_parts.append(f"- 内容规模: 共{total_chunks}个知识片段,平均长度{avg_length:.0f}字")
|
||||
|
||||
# 添加部分chunk内容作为参考
|
||||
sample_chunks = chunks[:5]
|
||||
sample_content = "\n".join([f"示例{i+1}: {chunk[:200]}..." for i, chunk in enumerate(sample_chunks)])
|
||||
prompt_parts.append(f"\n内容示例:\n{sample_content}")
|
||||
|
||||
system_prompt = """你是一位专业的知识内容分析师。你的任务是根据提供的信息,生成一段简洁、有洞察力的分析报告。
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
|
||||
|
||||
重要规则:
|
||||
1. 报告需要将所有要点流畅地串联成一个段落
|
||||
2. 语言风格要专业、客观,同时易于理解
|
||||
3. 不要添加任何额外的解释或标题,直接输出报告内容
|
||||
4. 基于提供的数据和示例内容进行分析,不要编造信息
|
||||
5. 重点关注内容的主题、特点和价值
|
||||
6. 报告长度控制在150-200字
|
||||
# Build template inputs from chunk analysis
|
||||
inputs = await _build_insight_inputs(chunks, max_chunks, end_user_id)
|
||||
|
||||
例如,如果输入是:
|
||||
- 内容领域分布: 技术(60%), 商业(25%), 教育(15%)
|
||||
- 内容规模: 共50个知识片段,平均长度320字
|
||||
内容示例: [示例内容...]
|
||||
rendered_prompt = await render_memory_insight_prompt(
|
||||
domain_distribution=inputs["domain_distribution"],
|
||||
active_periods=inputs["active_periods"],
|
||||
social_connections=inputs["social_connections"],
|
||||
language=language,
|
||||
)
|
||||
|
||||
你的输出应该类似:
|
||||
"该知识库主要聚焦于技术领域(60%),涵盖商业(25%)和教育(15%)相关内容。共包含50个知识片段,平均每个片段约320字,内容详实。从示例来看,内容涉及[具体主题],体现了[特点],对[目标用户]具有较高的参考价值。"
|
||||
"""
|
||||
|
||||
user_prompt = "\n".join(prompt_parts)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
# 调用LLM生成洞察
|
||||
llm_client = _get_llm_client()
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
response = await llm_client.chat(messages=messages)
|
||||
|
||||
insight = response.content.strip()
|
||||
business_logger.info(f"成功生成chunk洞察,分析了 {min(len(chunks), max_chunks)} 个片段")
|
||||
|
||||
return insight
|
||||
|
||||
raw_text = response.content.strip() if response and response.content else ""
|
||||
|
||||
sections = _parse_sections(raw_text, language=language)
|
||||
sections["_raw"] = raw_text
|
||||
|
||||
business_logger.info(
|
||||
f"成功生成chunk洞察四维度,分析了 {min(len(chunks), max_chunks)} 个片段"
|
||||
)
|
||||
return sections
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"生成chunk洞察失败: {str(e)}")
|
||||
return "洞察生成失败"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
test_chunks = [
|
||||
"Python是一种高级编程语言,以其简洁的语法和强大的功能而闻名。它广泛应用于Web开发、数据分析、人工智能等领域。",
|
||||
"机器学习算法可以从数据中自动学习模式,无需显式编程。常见的算法包括决策树、随机森林、神经网络等。",
|
||||
"深度学习是机器学习的一个分支,使用多层神经网络来学习数据的层次化表示。它在图像识别、语音识别等任务中表现出色。",
|
||||
"自然语言处理技术使计算机能够理解和生成人类语言。应用包括机器翻译、情感分析、文本摘要等。",
|
||||
"数据科学结合了统计学、计算机科学和领域知识,用于从数据中提取有价值的洞察。"
|
||||
]
|
||||
|
||||
print("开始生成chunk洞察...")
|
||||
insight = asyncio.run(generate_chunk_insight(test_chunks))
|
||||
print(f"\n生成的洞察:\n{insight}")
|
||||
empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
|
||||
empty["_raw"] = "洞察生成失败"
|
||||
return empty
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""
|
||||
Generate summary for RAG chunks.
|
||||
|
||||
This module provides functionality to summarize chunk content using LLM.
|
||||
Generate summary for RAG chunks using memory_summary.jinja2 prompt template.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
@@ -14,94 +13,135 @@ from pydantic import BaseModel, Field
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
|
||||
|
||||
class ChunkSummary(BaseModel):
|
||||
"""Pydantic model for chunk summary."""
|
||||
summary: str = Field(..., description="简洁的chunk内容摘要")
|
||||
# ── Schema ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class MemorySummaryStatement(BaseModel):
|
||||
"""Single labelled statement extracted by memory_summary.jinja2."""
|
||||
statement: str = Field(..., description="提取的陈述内容")
|
||||
label: Optional[str] = Field(None, description="陈述标签")
|
||||
|
||||
|
||||
async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str:
|
||||
class MemorySummaryResponse(BaseModel):
|
||||
"""
|
||||
Generate a summary for the given chunks.
|
||||
|
||||
Structured output expected from memory_summary.jinja2.
|
||||
The template asks for a JSON array of labelled statements;
|
||||
we wrap it in an object so response_structured can parse it.
|
||||
"""
|
||||
statements: List[MemorySummaryStatement] = Field(
|
||||
default_factory=list,
|
||||
description="从chunk中提取的陈述列表"
|
||||
)
|
||||
summary: Optional[str] = Field(None, description="整体摘要文本(可选)")
|
||||
|
||||
|
||||
# ── LLM client helper ────────────────────────────────────────────────────────
|
||||
|
||||
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
if end_user_id:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
if config_id or workspace_id:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(memory_config.llm_model_id)
|
||||
except Exception as e:
|
||||
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
|
||||
# ── Core function ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def generate_chunk_summary(
|
||||
chunks: List[str],
|
||||
max_chunks: int = 10,
|
||||
end_user_id: Optional[str] = None,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Generate a user summary from RAG chunks using the memory_summary.jinja2 template.
|
||||
|
||||
The template extracts labelled statements from the chunks; we then join them
|
||||
into a coherent summary string that can be stored in end_user.user_summary.
|
||||
|
||||
Args:
|
||||
chunks: List of chunk content strings
|
||||
max_chunks: Maximum number of chunks to process (default: 10)
|
||||
|
||||
max_chunks: Maximum number of chunks to process
|
||||
end_user_id: Optional end-user ID for model selection
|
||||
language: Output language ("zh" or "en")
|
||||
|
||||
Returns:
|
||||
A concise summary of the chunks
|
||||
Summary string (joined statements or fallback text)
|
||||
"""
|
||||
if not chunks:
|
||||
business_logger.warning("没有提供chunk内容用于生成摘要")
|
||||
return "暂无内容"
|
||||
|
||||
|
||||
try:
|
||||
# 限制处理的chunk数量,避免token过多
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||
|
||||
chunks_to_process = chunks[:max_chunks]
|
||||
|
||||
# 合并chunk内容
|
||||
combined_content = "\n\n".join([f"片段{i+1}: {chunk}" for i, chunk in enumerate(chunks_to_process)])
|
||||
|
||||
# 构建prompt
|
||||
system_prompt = (
|
||||
"你是一位专业的文本摘要助手。请基于提供的文本片段,生成简洁的摘要。要求:\n"
|
||||
"- 摘要长度控制在100-150字;\n"
|
||||
"- 提取核心信息和关键要点;\n"
|
||||
"- 使用客观、清晰的语言;\n"
|
||||
"- 避免冗余和重复;\n"
|
||||
"- 如果内容涉及多个主题,按重要性排序呈现。"
|
||||
chunk_texts = "\n\n".join(
|
||||
[f"片段{i + 1}: {chunk}" for i, chunk in enumerate(chunks_to_process)]
|
||||
)
|
||||
|
||||
json_schema = MemorySummaryResponse.model_json_schema()
|
||||
|
||||
rendered_prompt = await render_memory_summary_prompt(
|
||||
chunk_texts=chunk_texts,
|
||||
json_schema=json_schema,
|
||||
max_words=200,
|
||||
language=language,
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
|
||||
# Try structured output; fall back to plain chat only for LLMClientException
|
||||
# (indicates the model/provider doesn't support structured output).
|
||||
# All other exceptions are re-raised so config/schema errors stay visible.
|
||||
try:
|
||||
response: MemorySummaryResponse = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=MemorySummaryResponse,
|
||||
)
|
||||
if response.summary:
|
||||
summary = response.summary.strip()
|
||||
elif response.statements:
|
||||
summary = ";".join(s.statement for s in response.statements)
|
||||
else:
|
||||
summary = "暂无内容"
|
||||
except Exception as e:
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
if isinstance(e, LLMClientException):
|
||||
business_logger.warning(
|
||||
f"结构化输出不可用,降级为普通对话: end_user_id={end_user_id}, reason={e}"
|
||||
)
|
||||
raw = await llm_client.chat(messages=messages)
|
||||
summary = raw.content.strip() if raw and raw.content else "暂无内容"
|
||||
else:
|
||||
business_logger.error(f"生成摘要时发生非预期异常: {e}")
|
||||
raise
|
||||
|
||||
business_logger.info(
|
||||
f"成功生成chunk摘要,处理了 {len(chunks_to_process)} 个片段"
|
||||
)
|
||||
|
||||
user_prompt = f"请为以下文本片段生成摘要:\n\n{combined_content}"
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
# 调用LLM生成摘要
|
||||
llm_client = _get_llm_client()
|
||||
response = await llm_client.chat(messages=messages)
|
||||
|
||||
summary = response.content.strip()
|
||||
business_logger.info(f"成功生成chunk摘要,处理了 {len(chunks_to_process)} 个片段")
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"生成chunk摘要失败: {str(e)}")
|
||||
return "摘要生成失败"
|
||||
|
||||
|
||||
async def generate_chunk_summary_batch(chunks_list: List[List[str]]) -> List[str]:
|
||||
"""
|
||||
Generate summaries for multiple chunk lists in batch.
|
||||
|
||||
Args:
|
||||
chunks_list: List of chunk lists
|
||||
|
||||
Returns:
|
||||
List of summaries
|
||||
"""
|
||||
tasks = [generate_chunk_summary(chunks) for chunks in chunks_list]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
test_chunks = [
|
||||
"这是第一段测试内容,讲述了关于机器学习的基础知识。",
|
||||
"第二段内容介绍了深度学习的应用场景和发展历史。",
|
||||
"第三段讨论了自然语言处理技术的最新进展。"
|
||||
]
|
||||
|
||||
print("开始生成chunk摘要...")
|
||||
summary = asyncio.run(generate_chunk_summary(test_chunks))
|
||||
print(f"\n生成的摘要:\n{summary}")
|
||||
|
||||
@@ -5,8 +5,9 @@ This module provides functionality to extract meaningful tags from chunk content
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections import Counter
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
|
||||
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
if end_user_id:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
if config_id or workspace_id:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(memory_config.llm_model_id)
|
||||
except Exception as e:
|
||||
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
|
||||
class ExtractedTags(BaseModel):
|
||||
@@ -33,7 +53,7 @@ class ExtractedPersona(BaseModel):
|
||||
personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师'、'旅行爱好者'等")
|
||||
|
||||
|
||||
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10) -> List[Tuple[str, int]]:
|
||||
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10, end_user_id: Optional[str] = None) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
Extract meaningful tags from the given chunks.
|
||||
|
||||
@@ -64,7 +84,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
|
||||
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
|
||||
)
|
||||
|
||||
llm_client = _get_llm_client()
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
|
||||
# 为每个chunk单独提取标签,然后统计频率
|
||||
all_tags = []
|
||||
@@ -116,7 +136,7 @@ async def extract_chunk_tags_with_frequency(chunks: List[str], max_tags: int = 1
|
||||
return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks))
|
||||
|
||||
|
||||
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20) -> List[str]:
|
||||
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20, end_user_id: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Extract persona (人物形象) from the given chunks.
|
||||
|
||||
@@ -159,7 +179,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
|
||||
]
|
||||
|
||||
# 调用LLM提取人物形象
|
||||
llm_client = _get_llm_client()
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
structured_response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=ExtractedPersona
|
||||
|
||||
@@ -85,6 +85,7 @@ class StorageFactory:
|
||||
access_key_id=settings.S3_ACCESS_KEY_ID,
|
||||
secret_access_key=settings.S3_SECRET_ACCESS_KEY,
|
||||
bucket_name=settings.S3_BUCKET_NAME,
|
||||
endpoint_url=settings.S3_ENDPOINT_URL,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -35,6 +35,19 @@ class S3Storage(StorageBackend):
|
||||
bucket_name: The name of the S3 bucket.
|
||||
region: The AWS region.
|
||||
"""
|
||||
AMAZON_S3_ENDPOINT_MAP = {
|
||||
"us-east-1": "https://s3.us-east-1.amazonaws.com", # 特殊:无地域后缀
|
||||
"us-east-2": "https://s3.us-east-2.amazonaws.com",
|
||||
"us-west-1": "https://s3.us-west-1.amazonaws.com",
|
||||
"us-west-2": "https://s3.us-west-2.amazonaws.com",
|
||||
"ap-east-1": "https://s3.ap-east-1.amazonaws.com", # 香港
|
||||
"ap-southeast-1": "https://s3.ap-southeast-1.amazonaws.com", # 新加坡
|
||||
"ap-southeast-2": "https://s3.ap-southeast-2.amazonaws.com", # 悉尼
|
||||
"ap-northeast-1": "https://s3.ap-northeast-1.amazonaws.com", # 东京
|
||||
"eu-central-1": "https://s3.eu-central-1.amazonaws.com", # 法兰克福
|
||||
"eu-west-1": "https://s3.eu-west-1.amazonaws.com", # 爱尔兰
|
||||
# 可根据需要扩展其他地域
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -42,6 +55,7 @@ class S3Storage(StorageBackend):
|
||||
access_key_id: str,
|
||||
secret_access_key: str,
|
||||
bucket_name: str,
|
||||
endpoint_url: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize the S3Storage backend.
|
||||
@@ -51,6 +65,7 @@ class S3Storage(StorageBackend):
|
||||
access_key_id: The AWS access key ID.
|
||||
secret_access_key: The AWS secret access key.
|
||||
bucket_name: The name of the S3 bucket.
|
||||
endpoint_url: The complete URL to use for the constructed client.
|
||||
|
||||
Raises:
|
||||
StorageConfigError: If any required configuration is missing.
|
||||
@@ -69,10 +84,19 @@ class S3Storage(StorageBackend):
|
||||
self.region = region
|
||||
self.bucket_name = bucket_name
|
||||
|
||||
if not endpoint_url:
|
||||
# 优先匹配内置映射表(解决特殊地域)
|
||||
if region in self.AMAZON_S3_ENDPOINT_MAP:
|
||||
endpoint_url = self.AMAZON_S3_ENDPOINT_MAP[region]
|
||||
# 兜底:通用拼接(适配未配置的新地域)
|
||||
else:
|
||||
endpoint_url = f"https://s3.{region}.amazonaws.com"
|
||||
|
||||
try:
|
||||
self.client = boto3.client(
|
||||
"s3",
|
||||
region_name=region,
|
||||
endpoint_url=endpoint_url,
|
||||
aws_access_key_id=access_key_id,
|
||||
aws_secret_access_key=secret_access_key,
|
||||
)
|
||||
|
||||
@@ -4,65 +4,145 @@
|
||||
# @Time : 2026/2/25 14:11
|
||||
from typing import Any
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.adapters.base_adapter import (
|
||||
PlatformMetadata,
|
||||
PlatformType,
|
||||
BasePlatformAdapter,
|
||||
WorkflowParserResult
|
||||
)
|
||||
from app.schemas.workflow_schema import ExecutionConfig
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
|
||||
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
VALID_NODE_TYPES = frozenset(t.value for t in NodeType if t != NodeType.UNKNOWN)
|
||||
|
||||
|
||||
class MemoryBearAdapter(BasePlatformAdapter):
|
||||
NODE_TYPE_MAPPING = {}
|
||||
class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||
NODE_TYPE_MAPPING = {t.value: t for t in NodeType}
|
||||
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
MemoryBearConverter.__init__(self)
|
||||
BasePlatformAdapter.__init__(self, config)
|
||||
|
||||
@property
|
||||
def origin_nodes(self):
|
||||
return self.config.get("workflow").get("nodes")
|
||||
return self.config.get("workflow").get("nodes") or []
|
||||
|
||||
@property
|
||||
def origin_edges(self):
|
||||
return self.config.get("workflow").get("edges")
|
||||
return self.config.get("workflow").get("edges") or []
|
||||
|
||||
@property
|
||||
def origin_variables(self):
|
||||
return self.config.get("workflow").get("variables")
|
||||
return self.config.get("workflow").get("variables") or []
|
||||
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
platform_name=PlatformType.MEMORY_BEAR,
|
||||
version="0.2.5",
|
||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||
support_node_types=list(VALID_NODE_TYPES)
|
||||
)
|
||||
|
||||
def map_node_type(self, platform_node_type) -> str:
|
||||
return platform_node_type
|
||||
def map_node_type(self, platform_node_type: str) -> NodeType:
|
||||
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
||||
|
||||
@staticmethod
|
||||
def _valid_nodes(node: dict[str, Any]):
|
||||
if "type" not in node["data"]:
|
||||
return False
|
||||
def _valid_node(node: dict[str, Any]) -> bool:
|
||||
if "id" not in node or "type" not in node:
|
||||
return False
|
||||
if not isinstance(node.get("config"), dict):
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
require_fields = frozenset({'app', 'workflow'})
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
|
||||
for node in self.origin_nodes:
|
||||
if not self._valid_nodes(node):
|
||||
if not self._valid_node(node):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||
node_id = node.get("id")
|
||||
node_name = node.get("name")
|
||||
try:
|
||||
node_type = self.map_node_type(node["type"])
|
||||
if node_type == NodeType.UNKNOWN:
|
||||
self.errors.append(UnsupportNodeType(
|
||||
node_id=node_id,
|
||||
node_type=node["type"]
|
||||
))
|
||||
return None
|
||||
|
||||
config = node.get("config") or {}
|
||||
converter = self.get_node_convert(node_type)
|
||||
converter(node_id, node_name, config) # validates and appends errors if invalid
|
||||
|
||||
return NodeDefinition(**node)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
detail=f"convert node error - {e}"
|
||||
))
|
||||
logger.debug(f"MemoryBear convert node error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
|
||||
try:
|
||||
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"edge {edge.get('id')} skipped: source or target node not found"
|
||||
))
|
||||
return None
|
||||
return EdgeDefinition(**edge)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"convert edge error - {e}"
|
||||
))
|
||||
logger.debug(f"MemoryBear convert edge error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _convert_variable(self, variable: dict[str, Any]) -> VariableDefinition | None:
|
||||
try:
|
||||
return VariableDefinition(**variable)
|
||||
except Exception as e:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
name=variable.get("name"),
|
||||
detail=f"convert variable error - {e}"
|
||||
))
|
||||
logger.debug(f"MemoryBear convert variable error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def parse_workflow(self) -> WorkflowParserResult:
|
||||
self.nodes = self.origin_nodes
|
||||
self.edges = self.origin_edges
|
||||
self.conv_variables = self.origin_variables
|
||||
for node in self.origin_nodes:
|
||||
converted = self._convert_node(node)
|
||||
if converted:
|
||||
self.nodes.append(converted)
|
||||
|
||||
valid_node_ids = {n.id for n in self.nodes}
|
||||
|
||||
for edge in self.origin_edges:
|
||||
converted = self._convert_edge(edge, valid_node_ids)
|
||||
if converted:
|
||||
self.edges.append(converted)
|
||||
|
||||
for variable in self.origin_variables:
|
||||
converted = self._convert_variable(variable)
|
||||
if converted:
|
||||
self.conv_variables.append(converted)
|
||||
|
||||
return WorkflowParserResult(
|
||||
success=True,
|
||||
success=not self.errors and not self.warnings,
|
||||
platform=self.get_metadata(),
|
||||
execution_config=ExecutionConfig(),
|
||||
origin_config=self.config,
|
||||
@@ -72,5 +152,4 @@ class MemoryBearAdapter(BasePlatformAdapter):
|
||||
variables=self.conv_variables,
|
||||
warnings=self.warnings,
|
||||
errors=self.errors,
|
||||
|
||||
)
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.configs import (
|
||||
StartNodeConfig,
|
||||
EndNodeConfig,
|
||||
LLMNodeConfig,
|
||||
AgentNodeConfig,
|
||||
IfElseNodeConfig,
|
||||
KnowledgeRetrievalNodeConfig,
|
||||
AssignerNodeConfig,
|
||||
CodeNodeConfig,
|
||||
HttpRequestNodeConfig,
|
||||
JinjaRenderNodeConfig,
|
||||
VariableAggregatorNodeConfig,
|
||||
ParameterExtractorNodeConfig,
|
||||
LoopNodeConfig,
|
||||
IterationNodeConfig,
|
||||
QuestionClassifierNodeConfig,
|
||||
ToolNodeConfig,
|
||||
MemoryReadNodeConfig,
|
||||
MemoryWriteNodeConfig,
|
||||
NoteNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class MemoryBearConverter(BaseConverter):
|
||||
errors: list
|
||||
warnings: list
|
||||
|
||||
CONFIG_CLASS_MAP: dict[NodeType, type[BaseNodeConfig]] = {
|
||||
NodeType.START: StartNodeConfig,
|
||||
NodeType.END: EndNodeConfig,
|
||||
NodeType.ANSWER: EndNodeConfig,
|
||||
NodeType.LLM: LLMNodeConfig,
|
||||
NodeType.AGENT: AgentNodeConfig,
|
||||
NodeType.IF_ELSE: IfElseNodeConfig,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNodeConfig,
|
||||
NodeType.ASSIGNER: AssignerNodeConfig,
|
||||
NodeType.CODE: CodeNodeConfig,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNodeConfig,
|
||||
NodeType.JINJARENDER: JinjaRenderNodeConfig,
|
||||
NodeType.VAR_AGGREGATOR: VariableAggregatorNodeConfig,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNodeConfig,
|
||||
NodeType.LOOP: LoopNodeConfig,
|
||||
NodeType.ITERATION: IterationNodeConfig,
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNodeConfig,
|
||||
NodeType.TOOL: ToolNodeConfig,
|
||||
NodeType.MEMORY_READ: MemoryReadNodeConfig,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
|
||||
NodeType.NOTES: NoteNodeConfig,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _convert_file(var):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_file(var):
|
||||
return []
|
||||
|
||||
def config_validate(self, node_id: str, node_name: str, config_cls: type[BaseNodeConfig], value: dict):
|
||||
try:
|
||||
return config_cls.model_validate(value)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
detail=str(e)
|
||||
))
|
||||
return None
|
||||
|
||||
def get_node_convert(self, node_type: NodeType):
|
||||
config_cls = self.CONFIG_CLASS_MAP.get(node_type)
|
||||
if not config_cls:
|
||||
return lambda node_id, node_name, config: config
|
||||
|
||||
def validate(node_id: str, node_name: str, config: dict):
|
||||
self.config_validate(node_id, node_name, config_cls, config)
|
||||
return config
|
||||
|
||||
return validate
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
@@ -643,15 +644,18 @@ class BaseNode(ABC):
|
||||
return content.content_cache[provider]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
||||
message = await multimodel_service.process_files(
|
||||
[FileInput.model_construct(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
transfer_method=content.transfer_method,
|
||||
file_type=content.origin_file_type,
|
||||
upload_file_id=content.file_id
|
||||
)]
|
||||
file_obj = FileInput(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
transfer_method=content.transfer_method,
|
||||
origin_file_type=content.origin_file_type,
|
||||
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
|
||||
)
|
||||
file_obj.set_content(content.get_content())
|
||||
message = await multimodel_service.process_files(
|
||||
[file_obj]
|
||||
)
|
||||
content.set_content(file_obj.get_content())
|
||||
if message:
|
||||
content.content_cache[provider] = message
|
||||
return message
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import Field, BaseModel, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle
|
||||
from app.core.workflow.variable.base_variable import FileObject
|
||||
|
||||
|
||||
class HttpAuthConfig(BaseModel):
|
||||
@@ -260,6 +261,11 @@ class HttpRequestNodeOutput(BaseModel):
|
||||
description="Http response headers"
|
||||
)
|
||||
|
||||
files: list[FileObject] = Field(
|
||||
default_factory=list,
|
||||
description="List of files",
|
||||
)
|
||||
|
||||
output: str = Field(
|
||||
default="SUCCESS",
|
||||
description="HTTP response body",
|
||||
|
||||
@@ -1,24 +1,146 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import uuid
|
||||
import imghdr
|
||||
from email.message import Message
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
import httpx
|
||||
# import filetypes # TODO: File support (Feature)
|
||||
from httpx import AsyncClient, Response, Timeout
|
||||
import magic
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.utils.file_processer import mime_to_file_type
|
||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||
from app.schemas import FileType, TransferMethod
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class HttpResponse:
|
||||
def __init__(self, response: httpx.Response):
|
||||
self.response = response
|
||||
self.headers = dict(response.headers)
|
||||
|
||||
self._is_file: bool | None = None
|
||||
|
||||
@property
|
||||
def content_type(self) -> str:
|
||||
return self.headers.get("content-type", "")
|
||||
|
||||
@property
|
||||
def content_disposition(self) -> Message | None:
|
||||
content_disposition = self.headers.get("content-disposition", "")
|
||||
if content_disposition:
|
||||
msg = Message()
|
||||
msg["content-disposition"] = content_disposition
|
||||
return msg
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_file(self) -> bool:
|
||||
if self._is_file is not None:
|
||||
return self._is_file
|
||||
content_type = self.content_type.split(";")[0].strip().lower()
|
||||
|
||||
parsed_content_disposition = self.content_disposition
|
||||
if parsed_content_disposition:
|
||||
disp_type = parsed_content_disposition.get_content_disposition()
|
||||
filename = parsed_content_disposition.get_filename()
|
||||
if disp_type == "attachment" or filename:
|
||||
self._is_file = True
|
||||
return True
|
||||
|
||||
if content_type.startswith("text/") and "csv" not in content_type:
|
||||
return False
|
||||
|
||||
if content_type.startswith("application/"):
|
||||
if any(
|
||||
text_type in content_type
|
||||
for text_type in {"json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql"}
|
||||
):
|
||||
self._is_file = False
|
||||
return False
|
||||
try:
|
||||
content_sample = self.response.content[:1024]
|
||||
content_sample.decode("utf-8")
|
||||
text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ")
|
||||
if any(marker in content_sample for marker in text_markers):
|
||||
return False
|
||||
except UnicodeDecodeError:
|
||||
self._is_file = True
|
||||
return True
|
||||
|
||||
main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or ""))
|
||||
if main_type:
|
||||
self._is_file = main_type.split("/")[0] in ("application", "image", "audio", "video")
|
||||
return self._is_file
|
||||
self._is_file = any(media_type in content_type for media_type in ("image/", "audio/", "video/"))
|
||||
return self._is_file
|
||||
|
||||
@property
|
||||
def is_image(self):
|
||||
if self.is_file:
|
||||
kind = imghdr.what(None, h=self.response.content)
|
||||
return kind is not None
|
||||
return False
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return str(self.response.url)
|
||||
|
||||
@property
|
||||
def body(self) -> str:
|
||||
if self.is_file:
|
||||
return f"{'!' if self.is_image else ''}[file]({self.url})"
|
||||
return self.response.text
|
||||
|
||||
@staticmethod
|
||||
def get_file_type(file_bytes) -> tuple[FileType | None, str | None]:
|
||||
mime = magic.from_buffer(file_bytes, mime=True)
|
||||
|
||||
if mime.startswith("image"):
|
||||
return FileType.IMAGE, mime
|
||||
elif mime.startswith("video"):
|
||||
return FileType.VIDEO, mime
|
||||
elif mime.startswith("audio"):
|
||||
return FileType.AUDIO, mime
|
||||
elif mime in ["application/pdf",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"text/plain"]:
|
||||
return FileType.DOCUMENT, mime
|
||||
return None, None
|
||||
|
||||
@property
|
||||
def files(self) -> list[FileObject]:
|
||||
file_type, mime_type = self.get_file_type(self.response.content)
|
||||
origin_file_type = mime_to_file_type(mime_type)
|
||||
if self.is_file and file_type and origin_file_type:
|
||||
file_obj = FileObject(
|
||||
type=file_type,
|
||||
url=self.url,
|
||||
transfer_method=TransferMethod.REMOTE_URL.value,
|
||||
origin_file_type=origin_file_type,
|
||||
file_id=None,
|
||||
is_file=True
|
||||
)
|
||||
file_obj.set_content(self.response.content)
|
||||
return [
|
||||
file_obj
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
class HttpRequestNode(BaseNode):
|
||||
"""
|
||||
HTTP Request Workflow Node.
|
||||
@@ -44,6 +166,7 @@ class HttpRequestNode(BaseNode):
|
||||
"body": VariableType.STRING,
|
||||
"status_code": VariableType.NUMBER,
|
||||
"headers": VariableType.OBJECT,
|
||||
"files": VariableType.ARRAY_FILE,
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
@@ -232,10 +355,12 @@ class HttpRequestNode(BaseNode):
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
response = HttpResponse(resp)
|
||||
return HttpRequestNodeOutput(
|
||||
body=resp.text,
|
||||
body=response.body,
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
files=response.files
|
||||
).model_dump()
|
||||
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
||||
logger.error(f"HTTP request node exception: {e}")
|
||||
|
||||
56
api/app/core/workflow/utils/file_processer.py
Normal file
56
api/app/core/workflow/utils/file_processer.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/3/10 13:36
|
||||
TRANSFORM_FILE_TYPE = {
|
||||
'text/plain': 'document/text',
|
||||
'text/markdown': 'document/markdown',
|
||||
'text/x-markdown': 'document/x-markdown',
|
||||
|
||||
'application/pdf': 'document/pdf',
|
||||
|
||||
'application/msword': 'document/doc',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx',
|
||||
|
||||
'application/vnd.ms-powerpoint': 'document/ppt',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx',
|
||||
}
|
||||
ALLOWED_FILE_TYPES = [
|
||||
'text/plain',
|
||||
'text/markdown',
|
||||
'text/x-markdown',
|
||||
'application/pdf',
|
||||
'application/msword',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'application/vnd.ms-powerpoint',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'image/jpg',
|
||||
'image/jpeg',
|
||||
'image/png',
|
||||
'image/gif',
|
||||
'image/bmp',
|
||||
'image/webp',
|
||||
'image/svg+xml',
|
||||
'video/mp4',
|
||||
'video/quicktime',
|
||||
'video/x-msvideo',
|
||||
'video/x-matroska',
|
||||
'video/webm',
|
||||
'video/x-flv',
|
||||
'video/x-ms-wmv',
|
||||
'audio/mpeg',
|
||||
'audio/wav',
|
||||
'audio/ogg',
|
||||
'audio/aac',
|
||||
'audio/flac',
|
||||
'audio/mp4',
|
||||
'audio/x-ms-wma',
|
||||
'audio/x-m4a',
|
||||
]
|
||||
|
||||
|
||||
def mime_to_file_type(mime_type):
|
||||
if mime_type not in ALLOWED_FILE_TYPES:
|
||||
return None
|
||||
|
||||
return TRANSFORM_FILE_TYPE.get(mime_type, mime_type)
|
||||
@@ -114,9 +114,16 @@ class FileObject(BaseModel):
|
||||
file_id: str | None
|
||||
|
||||
content_cache: dict = Field(default_factory=dict)
|
||||
|
||||
is_file: bool
|
||||
|
||||
_byte_content: bytes | None = None
|
||||
|
||||
def get_content(self):
|
||||
return self._byte_content
|
||||
|
||||
def set_content(self, byte_content):
|
||||
self._byte_content = byte_content
|
||||
|
||||
|
||||
class BaseVariable(ABC):
|
||||
"""Abstract base class for all workflow variables.
|
||||
|
||||
@@ -51,6 +51,12 @@ class EndUser(Base):
|
||||
growth_trajectory = Column(Text, nullable=True, comment="成长轨迹")
|
||||
memory_insight_updated_at = Column(DateTime, nullable=True, comment="洞察报告最后更新时间")
|
||||
|
||||
# RAG存储模式专用字段 - RAG Storage Mode Fields
|
||||
# storage_type = Column(String, nullable=True, default="neo4j", comment="存储模式类型: neo4j / rag")
|
||||
rag_tags = Column(Text, nullable=True, comment="RAG模式下提取的标签列表(JSON格式)")
|
||||
rag_personas = Column(Text, nullable=True, comment="RAG模式下提取的人物形象列表(JSON格式)")
|
||||
rag_summary_updated_at = Column(DateTime, nullable=True, comment="RAG摘要/标签/人物形象最后更新时间")
|
||||
|
||||
# 与 App 的反向关系
|
||||
app = relationship(
|
||||
"App",
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from app.models.app_model import App
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.app_model import App
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
@@ -35,11 +36,27 @@ class AppRepository:
|
||||
except Exception as e:
|
||||
raise
|
||||
|
||||
def get_apps_by_name(self, app_name: str, app_type: str, workspace_id: uuid.UUID) -> List[App]:
|
||||
try:
|
||||
stmt = select(App).where(
|
||||
App.name == app_name,
|
||||
App.workspace_id == workspace_id,
|
||||
App.type == app_type,
|
||||
App.is_active.is_(True),
|
||||
)
|
||||
apps = self.db.execute(stmt).scalars().all()
|
||||
return list(apps)
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询名称 {app_name} 应用异常: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
|
||||
"""根据工作空间ID查询应用"""
|
||||
repo = AppRepository(db)
|
||||
return repo.get_apps_by_workspace_id(workspace_id)
|
||||
|
||||
|
||||
def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App:
|
||||
"""根据工作空间ID查询应用"""
|
||||
repo = AppRepository(db)
|
||||
|
||||
@@ -220,6 +220,88 @@ class EndUserRepository:
|
||||
db_logger.error(f"更新终端用户 {end_user_id} 的用户摘要缓存时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_rag_summary_tags(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
user_summary: str,
|
||||
rag_tags: str,
|
||||
rag_personas: str,
|
||||
) -> bool:
|
||||
"""更新RAG模式下的用户摘要、标签和人物形象缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
user_summary: 用户摘要文本
|
||||
rag_tags: 标签列表(JSON字符串)
|
||||
rag_personas: 人物形象列表(JSON字符串)
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
updated_count = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.id == end_user_id)
|
||||
.update(
|
||||
{
|
||||
EndUser.user_summary: user_summary,
|
||||
EndUser.rag_tags: rag_tags,
|
||||
EndUser.rag_personas: rag_personas,
|
||||
EndUser.rag_summary_updated_at: datetime.datetime.now(),
|
||||
},
|
||||
synchronize_session=False
|
||||
)
|
||||
)
|
||||
self.db.commit()
|
||||
if updated_count > 0:
|
||||
db_logger.info(f"成功更新终端用户 {end_user_id} 的RAG摘要/标签/人物形象缓存")
|
||||
return True
|
||||
else:
|
||||
db_logger.warning(f"未找到终端用户 {end_user_id},无法更新RAG摘要缓存")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"更新终端用户 {end_user_id} 的RAG摘要缓存时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_rag_insight(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
memory_insight: str,
|
||||
) -> bool:
|
||||
"""更新RAG模式下的记忆洞察缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_insight: 洞察文本
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
updated_count = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.id == end_user_id)
|
||||
.update(
|
||||
{
|
||||
EndUser.memory_insight: memory_insight,
|
||||
EndUser.memory_insight_updated_at: datetime.datetime.now(),
|
||||
},
|
||||
synchronize_session=False
|
||||
)
|
||||
)
|
||||
self.db.commit()
|
||||
if updated_count > 0:
|
||||
db_logger.info(f"成功更新终端用户 {end_user_id} 的RAG洞察缓存")
|
||||
return True
|
||||
else:
|
||||
db_logger.warning(f"未找到终端用户 {end_user_id},无法更新RAG洞察缓存")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"更新终端用户 {end_user_id} 的RAG洞察缓存时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all_by_workspace(self, workspace_id: uuid.UUID) -> List[EndUser]:
|
||||
"""获取工作空间的所有终端用户
|
||||
|
||||
|
||||
@@ -8,6 +8,13 @@ import logging
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from typing import Generator, Optional
|
||||
|
||||
|
||||
class TimeFilterUnavailableError(Exception):
|
||||
"""redis_client 不可用,无法执行时间轴筛选。
|
||||
|
||||
调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。
|
||||
"""
|
||||
|
||||
import redis
|
||||
from sqlalchemy import exists, not_, select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -113,7 +120,7 @@ class ImplicitEmotionsStorageRepository:
|
||||
logger.error(f"分批获取用户ID失败: offset={offset}, error={e}")
|
||||
break
|
||||
|
||||
def get_users_needing_refresh(self, redis_client: Optional[redis.StrictRedis], batch_size: int = 100) -> Generator[str, None, None]:
|
||||
def get_users_needing_refresh(self, redis_client: redis.StrictRedis, batch_size: int = 100) -> Generator[str, None, None]:
|
||||
"""分批次获取需要刷新隐性记忆/情绪数据的存量用户ID。
|
||||
|
||||
筛选逻辑:
|
||||
@@ -123,27 +130,21 @@ class ImplicitEmotionsStorageRepository:
|
||||
- 若 last_done > updated_at,说明上次刷新后又有新记忆写入,需要刷新
|
||||
- 若 last_done <= updated_at,说明已是最新,跳过
|
||||
|
||||
如果 redis_client 为 None,则降级为返回所有用户(禁用时间过滤)。
|
||||
|
||||
Args:
|
||||
redis_client: 同步 redis.StrictRedis 实例(连接 CELERY_BACKEND DB),如果为 None 则禁用时间过滤
|
||||
redis_client: 同步 redis.StrictRedis 实例(连接 CELERY_BACKEND DB)
|
||||
batch_size: 每批次加载的数量
|
||||
|
||||
Raises:
|
||||
TimeFilterUnavailableError: redis_client 为 None 时抛出,调用方可捕获并回退到 get_all_user_ids
|
||||
|
||||
Yields:
|
||||
需要刷新的用户ID字符串
|
||||
"""
|
||||
from datetime import timezone
|
||||
if redis_client is None:
|
||||
raise TimeFilterUnavailableError("redis_client 不可用,无法执行时间轴筛选")
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
# 如果 Redis 不可用,降级为处理所有用户
|
||||
if redis_client is None:
|
||||
logger.warning(
|
||||
"Redis 客户端不可用,时间过滤已禁用,将处理所有存量用户"
|
||||
)
|
||||
yield from self.get_all_user_ids(batch_size)
|
||||
return
|
||||
|
||||
offset = 0
|
||||
while True:
|
||||
try:
|
||||
@@ -178,16 +179,14 @@ class ImplicitEmotionsStorageRepository:
|
||||
try:
|
||||
CST = timezone(timedelta(hours=8))
|
||||
last_done = datetime.fromisoformat(raw)
|
||||
# 统一转为 CST naive 时间做比较
|
||||
if last_done.tzinfo is None:
|
||||
last_done = last_done.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None)
|
||||
else:
|
||||
# last_done 写入时已是 CST naive,直接使用,无需转换
|
||||
if last_done.tzinfo is not None:
|
||||
last_done = last_done.astimezone(CST).replace(tzinfo=None)
|
||||
|
||||
if updated_at is None:
|
||||
yield end_user_id
|
||||
continue
|
||||
# updated_at 同样转为 CST naive
|
||||
# updated_at 数据库存的是 UTC naive,转为 CST naive 再比较
|
||||
if updated_at.tzinfo is None:
|
||||
updated_at_cst = updated_at.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None)
|
||||
else:
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from app.models.user_model import User
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole
|
||||
from app.schemas.workspace_schema import WorkspaceCreate, WorkspaceUpdate
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.user_model import User
|
||||
from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole
|
||||
from app.schemas.workspace_schema import WorkspaceCreate
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
@@ -19,7 +22,7 @@ class WorkspaceRepository:
|
||||
def create_workspace(self, workspace_data: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace:
|
||||
"""创建工作空间"""
|
||||
db_logger.debug(f"创建工作空间记录: name={workspace_data.name}, tenant_id={tenant_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_workspace = Workspace(
|
||||
name=workspace_data.name,
|
||||
@@ -34,7 +37,8 @@ class WorkspaceRepository:
|
||||
)
|
||||
self.db.add(db_workspace)
|
||||
self.db.flush()
|
||||
db_logger.info(f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}")
|
||||
db_logger.info(
|
||||
f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}")
|
||||
return db_workspace
|
||||
except Exception as e:
|
||||
db_logger.error(f"创建工作空间记录失败: name={workspace_data.name} - {str(e)}")
|
||||
@@ -43,7 +47,7 @@ class WorkspaceRepository:
|
||||
def get_workspace_by_id(self, workspace_id: uuid.UUID) -> Optional[Workspace]:
|
||||
"""根据ID获取工作空间"""
|
||||
db_logger.debug(f"根据ID查询工作空间: workspace_id={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if workspace:
|
||||
@@ -65,7 +69,7 @@ class WorkspaceRepository:
|
||||
包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None
|
||||
"""
|
||||
db_logger.debug(f"查询工作空间模型配置: workspace_id={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if workspace:
|
||||
@@ -89,7 +93,7 @@ class WorkspaceRepository:
|
||||
def get_workspaces_by_user(self, user_id: uuid.UUID) -> List[Workspace]:
|
||||
"""获取用户参与的所有工作空间(包括用户创建的和作为成员的)"""
|
||||
db_logger.debug(f"查询用户参与的工作空间: user_id={user_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 首先获取用户信息以获取 tenant_id
|
||||
from app.models.user_model import User
|
||||
@@ -97,7 +101,7 @@ class WorkspaceRepository:
|
||||
if not user:
|
||||
db_logger.warning(f"用户不存在: user_id={user_id}")
|
||||
return []
|
||||
|
||||
|
||||
if user.is_superuser:
|
||||
# 超级用户获取对应tenantid所有工作空间
|
||||
workspaces = (
|
||||
@@ -109,7 +113,7 @@ class WorkspaceRepository:
|
||||
)
|
||||
db_logger.debug(f"超用户查询所有工作空间: user_id={user_id}, 数量={len(workspaces)}")
|
||||
return workspaces
|
||||
|
||||
|
||||
# 获取用户作为成员的工作空间
|
||||
member_workspaces = (
|
||||
self.db.query(Workspace)
|
||||
@@ -120,7 +124,7 @@ class WorkspaceRepository:
|
||||
.order_by(Workspace.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
db_logger.debug(f"用户工作空间查询成功: user_id={user_id}, 数量={len(member_workspaces)}")
|
||||
return member_workspaces
|
||||
except Exception as e:
|
||||
@@ -130,7 +134,7 @@ class WorkspaceRepository:
|
||||
def get_workspaces_by_tenant(self, tenant_id: uuid.UUID) -> List[Workspace]:
|
||||
"""获取租户的所有工作空间"""
|
||||
db_logger.debug(f"查询租户的工作空间: tenant_id={tenant_id}")
|
||||
|
||||
|
||||
try:
|
||||
workspaces = (
|
||||
self.db.query(Workspace)
|
||||
@@ -144,14 +148,32 @@ class WorkspaceRepository:
|
||||
db_logger.error(f"查询租户工作空间失败: tenant_id={tenant_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember:
|
||||
def get_workspaces_by_name(self, tenant_id: uuid.UUID, workspace_name: str) -> List[Workspace]:
|
||||
try:
|
||||
stmt = (
|
||||
select(Workspace)
|
||||
.where(
|
||||
Workspace.tenant_id == tenant_id,
|
||||
Workspace.name == workspace_name,
|
||||
Workspace.is_active.is_(True)
|
||||
)
|
||||
)
|
||||
|
||||
workspaces = self.db.execute(stmt).scalars().all()
|
||||
return list(workspaces)
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询工作空间失败: workspace_name={workspace_name} - {str(e)}")
|
||||
raise
|
||||
|
||||
def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID,
|
||||
role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember:
|
||||
"""添加工作空间成员"""
|
||||
db_logger.debug(f"添加工作空间成员: user_id={user_id}, workspace_id={workspace_id}, role={role}")
|
||||
|
||||
|
||||
try:
|
||||
db_member = WorkspaceMember(
|
||||
user_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
role=role
|
||||
)
|
||||
self.db.add(db_member)
|
||||
@@ -165,7 +187,7 @@ class WorkspaceRepository:
|
||||
def get_member(self, user_id: uuid.UUID, workspace_id: uuid.UUID) -> Optional[WorkspaceMember]:
|
||||
"""获取工作空间成员"""
|
||||
db_logger.debug(f"查询工作空间成员: user_id={user_id}, workspace_id={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
member = self.db.query(WorkspaceMember).filter(
|
||||
WorkspaceMember.user_id == user_id,
|
||||
@@ -173,7 +195,8 @@ class WorkspaceRepository:
|
||||
WorkspaceMember.is_active.is_(True),
|
||||
).first()
|
||||
if member:
|
||||
db_logger.debug(f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}")
|
||||
db_logger.debug(
|
||||
f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}")
|
||||
else:
|
||||
db_logger.debug(f"工作空间成员不存在: user_id={user_id}, workspace_id={workspace_id}")
|
||||
return member
|
||||
@@ -199,7 +222,7 @@ class WorkspaceRepository:
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询成员列表失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_member_by_id(self, member_id: uuid.UUID) -> WorkspaceMember:
|
||||
"""按成员ID获取工作空间成员,并预加载 user 与 workspace 关系"""
|
||||
db_logger.debug(f"查询成员的工作空间: member_id={member_id}")
|
||||
@@ -214,7 +237,8 @@ class WorkspaceRepository:
|
||||
.first()
|
||||
)
|
||||
if member:
|
||||
db_logger.debug(f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}")
|
||||
db_logger.debug(
|
||||
f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}")
|
||||
else:
|
||||
db_logger.debug(f"成员不存在: member_id={member_id}")
|
||||
return member
|
||||
@@ -222,7 +246,8 @@ class WorkspaceRepository:
|
||||
db_logger.error(f"查询成员列表失败: member_id={member_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
|
||||
def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[
|
||||
WorkspaceMember]:
|
||||
try:
|
||||
member = self.db.query(WorkspaceMember).filter(
|
||||
WorkspaceMember.workspace_id == workspace_id,
|
||||
@@ -255,7 +280,7 @@ class WorkspaceRepository:
|
||||
except Exception as e:
|
||||
db_logger.error(f"删除成员失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_member_by_id(self, member_id: uuid.UUID) -> Optional[WorkspaceMember]:
|
||||
try:
|
||||
member = self.db.query(WorkspaceMember).filter(
|
||||
@@ -271,7 +296,7 @@ class WorkspaceRepository:
|
||||
except Exception as e:
|
||||
db_logger.error(f"删除成员失败: id={member_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def update_member_role_by_id(self, id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
|
||||
try:
|
||||
member = self.db.query(WorkspaceMember).filter(
|
||||
@@ -288,12 +313,18 @@ class WorkspaceRepository:
|
||||
db_logger.error(f"更新成员角色失败: id={id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# 保持向后兼容的函数
|
||||
def get_workspace_by_id(db: Session, workspace_id: uuid.UUID) -> Workspace | None:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.get_workspace_by_id(workspace_id)
|
||||
|
||||
|
||||
def get_workspaces_by_name(db: Session, tenant_id: uuid.UUID, name: str) -> List[Workspace]:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.get_workspaces_by_name(tenant_id, name)
|
||||
|
||||
|
||||
def get_workspaces_by_user(db: Session, user_id: uuid.UUID) -> List[Workspace]:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.get_workspaces_by_user(user_id)
|
||||
@@ -315,7 +346,7 @@ def create_workspace(db: Session, workspace: WorkspaceCreate, tenant_id: uuid.UU
|
||||
|
||||
|
||||
def add_member_to_workspace(
|
||||
db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole
|
||||
db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole
|
||||
) -> WorkspaceMember:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.add_member(workspace_id, user_id, role)
|
||||
@@ -325,39 +356,43 @@ def get_members_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[Works
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.get_members_by_workspace(workspace_id)
|
||||
|
||||
|
||||
def get_member_by_id(db: Session, member_id: uuid.UUID) -> WorkspaceMember | None:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.get_member_by_id(member_id)
|
||||
|
||||
|
||||
def update_member_role_in_workspace(
|
||||
db: Session,
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
role: WorkspaceRole,
|
||||
db: Session,
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
role: WorkspaceRole,
|
||||
) -> Optional[WorkspaceMember]:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.update_member_role(workspace_id, user_id, role)
|
||||
|
||||
|
||||
def remove_member_from_workspace(
|
||||
db: Session,
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session,
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
) -> Optional[WorkspaceMember]:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.deactivate_member(workspace_id, user_id)
|
||||
|
||||
|
||||
def remove_member_from_workspace_by_id(
|
||||
db: Session,
|
||||
member_id: uuid.UUID,
|
||||
db: Session,
|
||||
member_id: uuid.UUID,
|
||||
) -> Optional[WorkspaceMember]:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.delete_member_by_id(member_id)
|
||||
|
||||
|
||||
def update_member_role_by_id(
|
||||
db: Session,
|
||||
id: uuid.UUID,
|
||||
role: WorkspaceRole,
|
||||
db: Session,
|
||||
id: uuid.UUID,
|
||||
role: WorkspaceRole,
|
||||
) -> Optional[WorkspaceMember]:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.update_member_role_by_id(id, role)
|
||||
|
||||
@@ -45,11 +45,19 @@ class FileInput(BaseModel):
|
||||
url: Optional[str] = Field(None, description="远程URL(remote_url时必填)")
|
||||
file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)")
|
||||
|
||||
_content = None
|
||||
|
||||
def __init__(self, **data):
|
||||
if "type" in data:
|
||||
data['file_type'] = data['type']
|
||||
super().__init__(**data)
|
||||
|
||||
def set_content(self, content: bytes):
|
||||
self._content = content
|
||||
|
||||
def get_content(self) -> bytes | None:
|
||||
return self._content
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def validate_type(cls, v):
|
||||
|
||||
445
api/app/services/app_dsl_service.py
Normal file
445
api/app/services/app_dsl_service.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""应用 DSL 导入导出服务"""
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.models import AgentConfig, MultiAgentConfig
|
||||
from app.models.app_model import App, AppType
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.models.models_model import ModelConfig
|
||||
from app.models.tool_model import ToolConfig as ToolConfigModel
|
||||
from app.models.workflow_model import WorkflowConfig
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
|
||||
|
||||
class AppDslService:
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
def export_dsl(self, app_id: uuid.UUID, release_id: Optional[uuid.UUID] = None) -> tuple[str, str]:
|
||||
"""构建应用 DSL yaml 字符串,返回 (yaml_str, filename)"""
|
||||
app = self.db.query(App).filter(App.id == app_id, App.is_active.is_(True)).first()
|
||||
if not app:
|
||||
raise ResourceNotFoundException("应用", str(app_id))
|
||||
|
||||
meta = {
|
||||
"version": settings.SYSTEM_VERSION,
|
||||
"platform": "MemoryBear",
|
||||
"exported_at": datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
}
|
||||
app_meta = {
|
||||
"name": app.name,
|
||||
"description": app.description,
|
||||
"icon": app.icon,
|
||||
"icon_type": app.icon_type,
|
||||
"type": app.type,
|
||||
"tags": app.tags or [],
|
||||
}
|
||||
|
||||
if release_id is not None:
|
||||
return self._export_release(app, release_id, meta, app_meta)
|
||||
|
||||
return self._export_draft(app, meta, app_meta)
|
||||
|
||||
def _export_release(self, app: App, release_id: uuid.UUID, meta: dict, app_meta: dict) -> tuple[str, str]:
|
||||
release = self.db.query(AppRelease).filter(
|
||||
AppRelease.app_id == app.id,
|
||||
AppRelease.id == release_id,
|
||||
AppRelease.is_active.is_(True)
|
||||
).first()
|
||||
if not release:
|
||||
raise ResourceNotFoundException("版本", str(release_id))
|
||||
|
||||
meta["release_version"] = release.version
|
||||
meta["release_name"] = release.version_name
|
||||
app_meta["name"] = release.name
|
||||
app_meta["description"] = release.description
|
||||
config_key = {
|
||||
AppType.AGENT: "agent_config",
|
||||
AppType.MULTI_AGENT: "multi_agent_config",
|
||||
AppType.WORKFLOW: "workflow"
|
||||
}.get(app.type, "config")
|
||||
config_data = self._enrich_release_config(app.type, release.config or {})
|
||||
dsl = {**meta, "app": app_meta, config_key: config_data}
|
||||
return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{release.name}_v{release.version_name}.yaml"
|
||||
|
||||
def _enrich_release_config(self, app_type: str, cfg: dict) -> dict:
|
||||
if app_type == AppType.AGENT:
|
||||
enriched = {**cfg}
|
||||
if "default_model_config_id" in cfg:
|
||||
enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"])
|
||||
if "knowledge_retrieval" in cfg:
|
||||
enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"])
|
||||
if "tools" in cfg:
|
||||
enriched["tools"] = self._enrich_tools(cfg["tools"])
|
||||
return enriched
|
||||
if app_type == AppType.MULTI_AGENT:
|
||||
enriched = {**cfg}
|
||||
if "default_model_config_id" in cfg:
|
||||
enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"])
|
||||
if "master_agent_id" in cfg:
|
||||
enriched["master_agent_ref"] = self._release_ref(cfg["master_agent_id"])
|
||||
if "sub_agents" in cfg:
|
||||
enriched["sub_agents"] = self._enrich_sub_agents(cfg["sub_agents"])
|
||||
if "routing_rules" in cfg:
|
||||
enriched["routing_rules"] = [
|
||||
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
|
||||
]
|
||||
return enriched
|
||||
return cfg
|
||||
|
||||
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
|
||||
if app.type == AppType.WORKFLOW:
|
||||
config = self.db.query(WorkflowConfig).filter(WorkflowConfig.app_id == app.id).first()
|
||||
config_data = {
|
||||
"variables": config.variables if config else [],
|
||||
"edges": config.edges if config else [],
|
||||
"nodes": config.nodes if config else [],
|
||||
"execution_config": config.execution_config if config else {},
|
||||
"triggers": config.triggers if config else [],
|
||||
} if config else {}
|
||||
dsl = {**meta, "app": app_meta, "workflow": config_data}
|
||||
|
||||
elif app.type == AppType.AGENT:
|
||||
config = self.db.query(AgentConfig).filter(AgentConfig.app_id == app.id).first()
|
||||
config_data = {
|
||||
"system_prompt": config.system_prompt if config else None,
|
||||
"model_parameters": self._to_dict(config.model_parameters) if config else None,
|
||||
"default_model_config_ref": self._model_ref(config.default_model_config_id) if config else None,
|
||||
"knowledge_retrieval": self._enrich_knowledge_retrieval(config.knowledge_retrieval) if config else None,
|
||||
"memory": config.memory if config else None,
|
||||
"variables": config.variables if config else [],
|
||||
"tools": self._enrich_tools(config.tools) if config else [],
|
||||
"skills": config.skills if config else {},
|
||||
} if config else {}
|
||||
dsl = {**meta, "app": app_meta, "agent_config": config_data}
|
||||
|
||||
elif app.type == AppType.MULTI_AGENT:
|
||||
config = self.db.query(MultiAgentConfig).filter(MultiAgentConfig.app_id == app.id).first()
|
||||
config_data = {
|
||||
"orchestration_mode": config.orchestration_mode if config else None,
|
||||
"master_agent_name": config.master_agent_name if config else None,
|
||||
"model_parameters": self._to_dict(config.model_parameters) if config else None,
|
||||
"default_model_config_ref": self._model_ref(config.default_model_config_id) if config else None,
|
||||
"master_agent_ref": self._release_ref(config.master_agent_id) if config else None,
|
||||
"sub_agents": self._enrich_sub_agents(config.sub_agents) if config else [],
|
||||
"routing_rules": [
|
||||
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (config.routing_rules or [])
|
||||
] if config else [],
|
||||
|
||||
"execution_config": config.execution_config if config else {},
|
||||
"aggregation_strategy": config.aggregation_strategy if config else "merge",
|
||||
} if config else {}
|
||||
dsl = {**meta, "app": app_meta, "multi_agent_config": config_data}
|
||||
|
||||
else:
|
||||
raise BusinessException(f"不支持的应用类型: {app.type}", BizCode.BAD_REQUEST)
|
||||
|
||||
return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{app.name}.yaml"
|
||||
|
||||
def _to_dict(self, value):
|
||||
"""将 Pydantic 对象转为普通 dict,供 yaml.dump 安全序列化"""
|
||||
if value is None:
|
||||
return None
|
||||
if hasattr(value, "model_dump"):
|
||||
return value.model_dump()
|
||||
return value
|
||||
|
||||
def _model_ref(self, model_config_id) -> Optional[dict]:
|
||||
if not model_config_id:
|
||||
return None
|
||||
m = self.db.query(ModelConfig).filter(ModelConfig.id == model_config_id).first()
|
||||
return {"id": str(model_config_id), "name": m.name, "provider": m.provider, "type": m.type} if m else {"id": str(model_config_id)}
|
||||
|
||||
def _kb_ref(self, kb_id) -> Optional[dict]:
|
||||
if not kb_id:
|
||||
return None
|
||||
kb = self.db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
||||
return {"id": str(kb_id), "name": kb.name} if kb else {"id": str(kb_id)}
|
||||
|
||||
def _tool_ref(self, tool_id) -> Optional[dict]:
|
||||
if not tool_id:
|
||||
return None
|
||||
t = self.db.query(ToolConfigModel).filter(ToolConfigModel.id == tool_id).first()
|
||||
return {"id": str(tool_id), "name": t.name, "tool_type": t.tool_type} if t else {"id": str(tool_id)}
|
||||
|
||||
def _enrich_knowledge_retrieval(self, kr: Optional[dict]) -> Optional[dict]:
|
||||
if not kr:
|
||||
return kr
|
||||
kbs = [{**kb, "_ref": self._kb_ref(kb.get("kb_id"))} for kb in kr.get("knowledge_bases", [])]
|
||||
return {**kr, "knowledge_bases": kbs}
|
||||
|
||||
def _enrich_tools(self, tools: list) -> list:
|
||||
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
||||
|
||||
def _agent_ref(self, agent_id) -> Optional[dict]:
|
||||
if not agent_id:
|
||||
return None
|
||||
a = self.db.query(App).filter(App.id == agent_id).first()
|
||||
return {"id": str(agent_id), "name": a.name} if a else {"id": str(agent_id)}
|
||||
|
||||
def _release_ref(self, release_id) -> Optional[dict]:
|
||||
if not release_id:
|
||||
return None
|
||||
r = self.db.query(AppRelease).filter(AppRelease.id == release_id).first()
|
||||
return {"id": str(release_id), "name": r.name, "version": r.version, "app_id": str(r.app_id)} if r else {"id": str(release_id)}
|
||||
|
||||
def _enrich_sub_agents(self, sub_agents: list) -> list:
|
||||
return [{**s, "_ref": self._agent_ref(s.get("agent_id"))} for s in (sub_agents or [])]
|
||||
|
||||
# ==================== 导入 ====================
|
||||
|
||||
def import_dsl(
|
||||
self,
|
||||
dsl: dict,
|
||||
workspace_id: uuid.UUID,
|
||||
tenant_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
) -> tuple[App, list[str]]:
|
||||
"""解析 DSL,创建应用及配置,返回 (new_app, warnings)"""
|
||||
app_meta = dsl.get("app", {})
|
||||
app_type = app_meta.get("type")
|
||||
if app_type not in (AppType.AGENT, AppType.MULTI_AGENT, AppType.WORKFLOW):
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.BAD_REQUEST)
|
||||
|
||||
warnings: list[str] = []
|
||||
now = datetime.datetime.now()
|
||||
|
||||
new_app = App(
|
||||
id=uuid.uuid4(),
|
||||
workspace_id=workspace_id,
|
||||
created_by=user_id,
|
||||
name=self._unique_app_name(app_meta.get("name", "导入应用"), workspace_id, app_type),
|
||||
description=app_meta.get("description"),
|
||||
icon=app_meta.get("icon"),
|
||||
icon_type=app_meta.get("icon_type"),
|
||||
type=app_type,
|
||||
visibility="private",
|
||||
status="draft",
|
||||
tags=app_meta.get("tags", []),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
self.db.add(new_app)
|
||||
self.db.flush()
|
||||
|
||||
if app_type == AppType.AGENT:
|
||||
cfg = dsl.get("agent_config") or {}
|
||||
self.db.add(AgentConfig(
|
||||
id=uuid.uuid4(),
|
||||
app_id=new_app.id,
|
||||
system_prompt=cfg.get("system_prompt"),
|
||||
model_parameters=cfg.get("model_parameters"),
|
||||
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
|
||||
knowledge_retrieval=self._resolve_knowledge_retrieval(cfg.get("knowledge_retrieval"), workspace_id, warnings),
|
||||
memory=self._resolve_memory(cfg.get("memory"), workspace_id, warnings),
|
||||
variables=cfg.get("variables", []),
|
||||
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
|
||||
skills=cfg.get("skills", {}),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
))
|
||||
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
cfg = dsl.get("multi_agent_config") or {}
|
||||
self.db.add(MultiAgentConfig(
|
||||
id=uuid.uuid4(),
|
||||
app_id=new_app.id,
|
||||
orchestration_mode=cfg.get("orchestration_mode", "collaboration"),
|
||||
master_agent_name=cfg.get("master_agent_name"),
|
||||
model_parameters=cfg.get("model_parameters"),
|
||||
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
|
||||
master_agent_id=self._resolve_release(cfg.get("master_agent_ref"), warnings),
|
||||
sub_agents=self._resolve_sub_agents(cfg.get("sub_agents", []), warnings),
|
||||
routing_rules=self._resolve_routing_rules(cfg.get("routing_rules"), warnings),
|
||||
execution_config=cfg.get("execution_config", {}),
|
||||
aggregation_strategy=cfg.get("aggregation_strategy", "merge"),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
))
|
||||
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
adapter = MemoryBearAdapter(dsl)
|
||||
if not adapter.validate_config():
|
||||
raise BusinessException("工作流配置格式无效", BizCode.BAD_REQUEST)
|
||||
result = adapter.parse_workflow()
|
||||
for e in result.errors:
|
||||
warnings.append(f"[节点错误] {e.node_name or e.node_id}: {e.detail}")
|
||||
for w in result.warnings:
|
||||
warnings.append(f"[节点警告] {w.node_name or w.node_id}: {w.detail}")
|
||||
wf = dsl.get("workflow") or {}
|
||||
WorkflowService(self.db).create_workflow_config(
|
||||
app_id=new_app.id,
|
||||
nodes=[n.model_dump() for n in result.nodes],
|
||||
edges=[e.model_dump() for e in result.edges],
|
||||
variables=[v.model_dump() for v in result.variables],
|
||||
execution_config=wf.get("execution_config", {}),
|
||||
triggers=wf.get("triggers", []),
|
||||
validate=False,
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(new_app)
|
||||
return new_app, warnings
|
||||
|
||||
def _unique_app_name(self, name: str, workspace_id: uuid.UUID, app_type: AppType) -> str:
|
||||
existing = {r[0] for r in self.db.query(App.name).filter(
|
||||
App.workspace_id == workspace_id,
|
||||
App.type == app_type,
|
||||
App.is_active.is_(True)
|
||||
).all()}
|
||||
if name not in existing:
|
||||
return name
|
||||
counter = 1
|
||||
while f"{name}({counter})" in existing:
|
||||
counter += 1
|
||||
return f"{name}({counter})"
|
||||
|
||||
def _resolve_model(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[uuid.UUID]:
|
||||
if not ref:
|
||||
return None
|
||||
q = self.db.query(ModelConfig).filter(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.name == ref.get("name"),
|
||||
ModelConfig.is_active.is_(True)
|
||||
)
|
||||
if ref.get("provider"):
|
||||
q = q.filter(ModelConfig.provider == ref["provider"])
|
||||
if ref.get("type"):
|
||||
q = q.filter(ModelConfig.type == ref["type"])
|
||||
m = q.first()
|
||||
if not m:
|
||||
warnings.append(f"模型 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||
return m.id if m else None
|
||||
|
||||
def _resolve_kb(self, ref: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[str]:
|
||||
if not ref:
|
||||
return None
|
||||
kb = self.db.query(Knowledge).filter(
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.name == ref.get("name")
|
||||
).first()
|
||||
if not kb:
|
||||
warnings.append(f"知识库 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||
return str(kb.id) if kb else None
|
||||
|
||||
def _resolve_tool(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]:
|
||||
if not ref:
|
||||
return None
|
||||
q = self.db.query(ToolConfigModel).filter(
|
||||
ToolConfigModel.tenant_id == tenant_id,
|
||||
ToolConfigModel.name == ref.get("name")
|
||||
)
|
||||
if ref.get("tool_type"):
|
||||
q = q.filter(ToolConfigModel.tool_type == ref["tool_type"])
|
||||
t = q.first()
|
||||
if not t:
|
||||
warnings.append(f"工具 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||
return str(t.id) if t else None
|
||||
|
||||
def _resolve_release(self, ref: Optional[dict], warnings: list) -> Optional[uuid.UUID]:
|
||||
if not ref:
|
||||
return None
|
||||
r = self.db.query(AppRelease).filter(
|
||||
AppRelease.app_id == ref.get("app_id"),
|
||||
AppRelease.version == ref.get("version"),
|
||||
AppRelease.is_active.is_(True)
|
||||
).first()
|
||||
if not r:
|
||||
warnings.append(f"主 Agent 发布版本 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||
return r.id if r else None
|
||||
|
||||
def _resolve_sub_agents(self, sub_agents: list, warnings: list) -> list:
|
||||
result = []
|
||||
for s in (sub_agents or []):
|
||||
ref = s.get("_ref")
|
||||
entry = {k: v for k, v in s.items() if k != "_ref"}
|
||||
if ref:
|
||||
a = self.db.query(App).filter(App.name == ref.get("name"), App.is_active.is_(True)).first()
|
||||
if not a:
|
||||
warnings.append(f"子 Agent '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||
entry["agent_id"] = str(a.id) if a else None
|
||||
result.append(entry)
|
||||
return result
|
||||
|
||||
def _resolve_routing_rules(self, rules: Optional[list], warnings: list) -> Optional[list]:
|
||||
if rules is None:
|
||||
return None
|
||||
result = []
|
||||
for r in rules:
|
||||
ref = r.get("_ref")
|
||||
entry = {k: v for k, v in r.items() if k != "_ref"}
|
||||
if ref:
|
||||
a = self.db.query(App).filter(App.name == ref.get("name"), App.is_active.is_(True)).first()
|
||||
if not a:
|
||||
warnings.append(f"路由目标 Agent '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||
entry["target_agent_id"] = str(a.id) if a else None
|
||||
result.append(entry)
|
||||
return result
|
||||
|
||||
def _resolve_knowledge_retrieval(self, kr: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]:
|
||||
if not kr:
|
||||
return kr
|
||||
resolved_kbs = []
|
||||
for kb in kr.get("knowledge_bases", []):
|
||||
ref = kb.get("_ref") or ({"name": kb.get("kb_id")} if kb.get("kb_id") else None)
|
||||
entry = {k: v for k, v in kb.items() if k != "_ref"}
|
||||
resolved_id = self._resolve_kb(ref, workspace_id, warnings)
|
||||
if resolved_id is None:
|
||||
continue
|
||||
entry["kb_id"] = resolved_id
|
||||
resolved_kbs.append(entry)
|
||||
return {k: v for k, v in kr.items() if k != "knowledge_bases"} | {"knowledge_bases": resolved_kbs}
|
||||
|
||||
def _resolve_memory(self, memory: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]:
|
||||
if not memory:
|
||||
return memory
|
||||
config_id = memory.get("memory_config_id") or memory.get("memory_content")
|
||||
if not config_id:
|
||||
return memory
|
||||
try:
|
||||
config_uuid = uuid.UUID(str(config_id))
|
||||
except (ValueError, AttributeError):
|
||||
exists = self.db.query(MemoryConfigModel).filter(
|
||||
MemoryConfigModel.config_id_old == int(config_id),
|
||||
MemoryConfigModel.workspace_id == workspace_id
|
||||
).first()
|
||||
if not exists:
|
||||
warnings.append(f"记忆配置 '{config_id}' 未匹配,已置空,请导入后手动配置")
|
||||
return {**memory, "memory_config_id": None, "enabled": False}
|
||||
return memory
|
||||
exists = self.db.query(MemoryConfigModel).filter(
|
||||
MemoryConfigModel.config_id == config_uuid,
|
||||
MemoryConfigModel.workspace_id == workspace_id
|
||||
).first()
|
||||
if not exists:
|
||||
warnings.append(f"记忆配置 '{config_id}' 未匹配,已置空,请导入后手动配置")
|
||||
return {**memory, "memory_config_id": None, "enabled": False}
|
||||
return memory
|
||||
|
||||
def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list:
|
||||
result = []
|
||||
for t in (tools or []):
|
||||
ref = t.get("_ref") or ({"name": t.get("tool_id")} if t.get("tool_id") else None)
|
||||
entry = {k: v for k, v in t.items() if k != "_ref"}
|
||||
resolved_id = self._resolve_tool(ref, tenant_id, warnings)
|
||||
if resolved_id is None:
|
||||
continue
|
||||
entry["tool_id"] = resolved_id
|
||||
result.append(entry)
|
||||
return result
|
||||
@@ -33,7 +33,7 @@ from app.models import (
|
||||
Workspace,
|
||||
)
|
||||
from app.models.app_model import AppStatus, AppType
|
||||
from app.repositories.app_repository import get_apps_by_id
|
||||
from app.repositories.app_repository import get_apps_by_id, AppRepository
|
||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||
from app.schemas import app_schema
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||
@@ -59,6 +59,7 @@ class AppService:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
self.app_repo = AppRepository(self.db)
|
||||
|
||||
# ==================== 私有辅助方法 ====================
|
||||
|
||||
@@ -521,6 +522,9 @@ class AppService:
|
||||
"创建应用",
|
||||
extra={"app_name": data.name, "type": data.type, "workspace_id": str(workspace_id)}
|
||||
)
|
||||
apps = self.app_repo.get_apps_by_name(data.name, data.type, workspace_id)
|
||||
if apps:
|
||||
raise BusinessException(message="已存在同名应用", code=BizCode.RESOURCE_ALREADY_EXISTS)
|
||||
|
||||
try:
|
||||
now = datetime.datetime.now()
|
||||
@@ -1368,6 +1372,15 @@ class AppService:
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 应用缺少配置,无法发布", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
miss_params = []
|
||||
if agent_cfg.default_model_config_id is None:
|
||||
miss_params.append("model config")
|
||||
|
||||
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
|
||||
miss_params.append("memory config")
|
||||
if miss_params:
|
||||
raise BusinessException(f"{', '.join(miss_params)} is required")
|
||||
|
||||
config = {
|
||||
"system_prompt": agent_cfg.system_prompt,
|
||||
"model_parameters": model_parameters_to_dict(agent_cfg.model_parameters),
|
||||
|
||||
@@ -1165,6 +1165,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
|
||||
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
# TODO: check sources for enduserid, should be one of these three: chat, draft, apikey
|
||||
# 1. 获取 end_user 及其 app_id
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if not end_user:
|
||||
@@ -1179,10 +1180,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
if not app:
|
||||
logger.warning(f"App not found: {app_id}")
|
||||
raise ValueError(f"应用不存在: {app_id}")
|
||||
|
||||
if not app.current_release_id:
|
||||
logger.warning(f"No current release for app: {app_id}")
|
||||
raise ValueError(f"应用未发布: {app_id}")
|
||||
# TODO: temp fix for draft run
|
||||
# if not app.current_release_id:
|
||||
# logger.warning(f"No current release for app: {app_id}")
|
||||
# raise ValueError(f"应用未发布: {app_id}")
|
||||
|
||||
# 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填
|
||||
memory_config_id_to_use = end_user.memory_config_id
|
||||
@@ -1223,7 +1224,9 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
|
||||
if legacy_config_id:
|
||||
# 验证提取的 config_id 是否存在于数据库中
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
from app.models.memory_config_model import (
|
||||
MemoryConfig as MemoryConfigModel,
|
||||
)
|
||||
existing_config = db.get(MemoryConfigModel, legacy_config_id)
|
||||
|
||||
if existing_config:
|
||||
@@ -1257,7 +1260,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
result = {
|
||||
"end_user_id": str(end_user_id),
|
||||
"app_id": str(app_id),
|
||||
"release_id": str(app.current_release_id),
|
||||
"release_id": str(app.current_release_id) if app.current_release_id else None,
|
||||
"memory_config_id": memory_config_id,
|
||||
"workspace_id": str(app.workspace_id)
|
||||
}
|
||||
|
||||
@@ -107,29 +107,19 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
)
|
||||
|
||||
|
||||
# 专门场景的内置 key 集合,直接从 SceneConfigRegistry 派生,避免重复维护
|
||||
# 使用懒加载函数避免模块级循环导入
|
||||
def _get_builtin_pruning_scenes() -> set:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import SceneConfigRegistry
|
||||
return set(SceneConfigRegistry.get_all_scenes())
|
||||
|
||||
|
||||
def _load_ontology_classes(db: Session, scene_id, pruning_scene: Optional[str]) -> Optional[list]:
|
||||
"""当 pruning_scene 不是内置场景时,从 ontology_class 表加载类型名称列表。
|
||||
"""从 ontology_class 表加载场景类型名称列表,用于注入提示词。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
scene_id: 本体场景 UUID
|
||||
pruning_scene: 语义剪枝场景名称
|
||||
pruning_scene: 语义剪枝场景名称(保留参数,暂未使用)
|
||||
|
||||
Returns:
|
||||
class_name 字符串列表,或 None(内置场景 / 无数据时)
|
||||
class_name 字符串列表,或 None(无数据时)
|
||||
"""
|
||||
if not scene_id:
|
||||
return None
|
||||
# 内置场景走 SceneConfigRegistry,不需要注入类型列表
|
||||
if pruning_scene in _get_builtin_pruning_scenes():
|
||||
return None
|
||||
try:
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
repo = OntologyClassRepository(db)
|
||||
|
||||
@@ -535,7 +535,8 @@ def get_users_total_chunk_batch(
|
||||
|
||||
def get_rag_content(
|
||||
end_user_id: str,
|
||||
limit: int,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
db: Session,
|
||||
current_user: User
|
||||
) -> dict:
|
||||
@@ -543,9 +544,9 @@ def get_rag_content(
|
||||
先在documents表中查询file_name=='end_user_id'+'.txt'的id和kb_id,
|
||||
然后调用/chunks/{kb_id}/{document_id}/chunks接口的相关代码获取所有内容,
|
||||
接着对获取的内容进行提取,只要page_content的内容,
|
||||
最后返回数据
|
||||
最后返回分页数据
|
||||
"""
|
||||
business_logger.info(f"获取RAG内容: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
|
||||
business_logger.info(f"获取RAG内容: end_user_id={end_user_id}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
from app.models.document_model import Document
|
||||
@@ -562,63 +563,76 @@ def get_rag_content(
|
||||
if not documents:
|
||||
business_logger.warning(f"未找到文件: {file_name}")
|
||||
return {
|
||||
"total": 0,
|
||||
"contents": []
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": 0,
|
||||
"hasnext": False,
|
||||
},
|
||||
"items": []
|
||||
}
|
||||
|
||||
business_logger.info(f"找到 {len(documents)} 个文档记录")
|
||||
|
||||
# 3. 获取所有chunks的page_content
|
||||
all_contents = []
|
||||
total_chunks = 0
|
||||
# 3. 按全局偏移量计算当前页数据
|
||||
# 全局偏移范围:[offset_start, offset_end)
|
||||
offset_start = (page - 1) * pagesize
|
||||
offset_end = offset_start + pagesize
|
||||
|
||||
global_total = 0 # 所有文档的 chunk 总数
|
||||
page_contents = [] # 当前页的内容
|
||||
|
||||
for document in documents:
|
||||
try:
|
||||
# 获取知识库信息
|
||||
kb = knowledge_repository.get_knowledge_by_id(db, document.kb_id)
|
||||
if not kb:
|
||||
business_logger.warning(f"知识库不存在: kb_id={document.kb_id}")
|
||||
continue
|
||||
|
||||
# 初始化向量服务
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=kb)
|
||||
|
||||
# 获取该文档的所有chunks(分页获取)
|
||||
page = 1
|
||||
pagesize = 100 # 每页100条
|
||||
# 先用 pagesize=1 获取该文档的 chunk 总数
|
||||
doc_total, _ = vector_service.search_by_segment(
|
||||
document_id=str(document.id),
|
||||
query=None,
|
||||
pagesize=1,
|
||||
page=1,
|
||||
asc=True
|
||||
)
|
||||
|
||||
while True:
|
||||
total, items = vector_service.search_by_segment(
|
||||
doc_offset_start = global_total # 该文档在全局中的起始偏移
|
||||
doc_offset_end = global_total + doc_total # 该文档在全局中的结束偏移
|
||||
global_total += doc_total
|
||||
|
||||
# 当前页与该文档无交集,跳过
|
||||
if doc_offset_end <= offset_start or doc_offset_start >= offset_end:
|
||||
continue
|
||||
|
||||
# 计算需要从该文档取的局部范围
|
||||
local_start = max(offset_start - doc_offset_start, 0)
|
||||
local_end = min(offset_end - doc_offset_start, doc_total)
|
||||
need_count = local_end - local_start
|
||||
|
||||
# 换算成 ES 分页参数(ES page 从1开始)
|
||||
es_page = (local_start // pagesize) + 1
|
||||
es_offset_in_page = local_start % pagesize
|
||||
|
||||
fetched = []
|
||||
while len(fetched) < es_offset_in_page + need_count:
|
||||
_, items = vector_service.search_by_segment(
|
||||
document_id=str(document.id),
|
||||
query=None,
|
||||
pagesize=pagesize,
|
||||
page=page,
|
||||
page=es_page,
|
||||
asc=True
|
||||
)
|
||||
|
||||
if not items:
|
||||
break
|
||||
|
||||
# 提取page_content
|
||||
for item in items:
|
||||
all_contents.append(item.page_content)
|
||||
total_chunks += 1
|
||||
|
||||
# # 如果达到limit限制,直接返回
|
||||
# if limit > 0 and total_chunks >= limit:
|
||||
# business_logger.info(f"已达到limit限制: {limit}")
|
||||
# return {
|
||||
# "total": total_chunks,
|
||||
# "contents": all_contents[:limit]
|
||||
# }
|
||||
|
||||
# 检查是否还有下一页
|
||||
if page * pagesize >= total:
|
||||
break
|
||||
|
||||
page += 1
|
||||
fetched.extend(items)
|
||||
es_page += 1
|
||||
|
||||
business_logger.info(f"文档 {document.id} 获取了 {len(items)} 个chunks")
|
||||
slice_items = fetched[es_offset_in_page: es_offset_in_page + need_count]
|
||||
page_contents.extend([item.page_content for item in slice_items])
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
|
||||
@@ -626,11 +640,16 @@ def get_rag_content(
|
||||
|
||||
# 4. 返回结果
|
||||
result = {
|
||||
"total": total_chunks,
|
||||
"contents": all_contents[:limit] if limit > 0 else all_contents
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": global_total,
|
||||
"hasnext": offset_end < global_total,
|
||||
},
|
||||
"items": page_contents
|
||||
}
|
||||
|
||||
business_logger.info(f"成功获取RAG内容: total={total_chunks}, 返回={len(result['contents'])} 条")
|
||||
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)} 条")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
@@ -646,59 +665,26 @@ async def get_chunk_summary_and_tags(
|
||||
current_user: User
|
||||
) -> dict:
|
||||
"""
|
||||
获取chunk的总结、标签和人物形象
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID
|
||||
limit: 返回的chunk数量限制
|
||||
max_tags: 最大标签数量
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
包含summary、tags和personas的字典
|
||||
纯读库:从end_user表返回RAG摘要、标签和人物形象缓存。
|
||||
无数据时返回空结构,不触发LLM生成。
|
||||
"""
|
||||
business_logger.info(f"获取chunk摘要、标签和人物形象: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. 获取chunk内容
|
||||
rag_content = get_rag_content(end_user_id, limit, db, current_user)
|
||||
chunks = rag_content.get("contents", [])
|
||||
|
||||
if not chunks:
|
||||
business_logger.warning(f"未找到chunk内容: end_user_id={end_user_id}")
|
||||
return {
|
||||
"summary": "暂无内容",
|
||||
"tags": [],
|
||||
"personas": []
|
||||
}
|
||||
|
||||
# 2. 导入RAG工具函数
|
||||
from app.core.rag_utils import generate_chunk_summary, extract_chunk_tags, extract_chunk_persona
|
||||
|
||||
# 3. 并发生成摘要、提取标签和人物形象
|
||||
import asyncio
|
||||
summary_task = generate_chunk_summary(chunks, max_chunks=limit)
|
||||
tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit)
|
||||
personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit)
|
||||
|
||||
summary, tags_with_freq, personas = await asyncio.gather(summary_task, tags_task, personas_task)
|
||||
|
||||
# 4. 格式化标签数据
|
||||
tags = [{"tag": tag, "frequency": freq} for tag, freq in tags_with_freq]
|
||||
|
||||
result = {
|
||||
"summary": summary,
|
||||
"tags": tags,
|
||||
"personas": personas
|
||||
}
|
||||
|
||||
business_logger.info(f"成功获取chunk摘要、{len(tags)} 个标签和 {len(personas)} 个人物形象")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取chunk摘要、标签和人物形象失败: end_user_id={end_user_id} - {str(e)}")
|
||||
raise
|
||||
import json
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
business_logger.info(f"读取chunk摘要/标签/人物形象缓存: end_user_id={end_user_id}")
|
||||
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||
|
||||
if not end_user:
|
||||
return {"summary": "", "tags": [], "personas": [], "generated": False}
|
||||
|
||||
return {
|
||||
"summary": end_user.user_summary or "",
|
||||
"tags": json.loads(end_user.rag_tags) if end_user.rag_tags else [],
|
||||
"personas": json.loads(end_user.rag_personas) if end_user.rag_personas else [],
|
||||
"generated": bool(end_user.user_summary),
|
||||
}
|
||||
|
||||
|
||||
async def get_chunk_insight(
|
||||
@@ -708,43 +694,98 @@ async def get_chunk_insight(
|
||||
current_user: User
|
||||
) -> dict:
|
||||
"""
|
||||
获取chunk的洞察分析
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID
|
||||
limit: 返回的chunk数量限制
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
包含insight的字典
|
||||
纯读库:从end_user表返回RAG洞察缓存。
|
||||
无数据时返回空结构,不触发LLM生成。
|
||||
"""
|
||||
business_logger.info(f"获取chunk洞察: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. 获取chunk内容
|
||||
rag_content = get_rag_content(end_user_id, limit, db, current_user)
|
||||
chunks = rag_content.get("contents", [])
|
||||
|
||||
if not chunks:
|
||||
business_logger.warning(f"未找到chunk内容: end_user_id={end_user_id}")
|
||||
return {
|
||||
"insight": "暂无足够数据生成洞察报告"
|
||||
}
|
||||
|
||||
# 2. 导入RAG工具函数
|
||||
from app.core.rag_utils import generate_chunk_insight
|
||||
|
||||
# 3. 生成洞察
|
||||
insight = await generate_chunk_insight(chunks, max_chunks=limit)
|
||||
|
||||
result = {
|
||||
"insight": insight
|
||||
}
|
||||
|
||||
business_logger.info("成功获取chunk洞察")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取chunk洞察失败: end_user_id={end_user_id} - {str(e)}")
|
||||
raise
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
business_logger.info(f"读取chunk洞察缓存: end_user_id={end_user_id}")
|
||||
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||
|
||||
if not end_user:
|
||||
return {"insight": "", "behavior_pattern": "", "key_findings": "", "growth_trajectory": "", "generated": False}
|
||||
|
||||
return {
|
||||
"insight": end_user.memory_insight or "",
|
||||
"behavior_pattern": end_user.behavior_pattern or "",
|
||||
"key_findings": end_user.key_findings or "",
|
||||
"growth_trajectory": end_user.growth_trajectory or "",
|
||||
"generated": bool(end_user.memory_insight),
|
||||
}
|
||||
|
||||
|
||||
async def generate_rag_profile(
|
||||
end_user_id: str,
|
||||
limit: int,
|
||||
max_tags: int,
|
||||
db: Session,
|
||||
current_user: User,
|
||||
) -> dict:
|
||||
"""
|
||||
生产接口:为RAG存储模式的end_user全量重新生成并持久化完整画像数据。
|
||||
每次调用都会重新生成,覆盖已有数据。
|
||||
|
||||
生成内容:
|
||||
- user_summary / rag_tags / rag_personas
|
||||
- memory_insight / behavior_pattern / key_findings / growth_trajectory
|
||||
"""
|
||||
import json
|
||||
import asyncio
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.core.rag_utils import (
|
||||
generate_chunk_summary,
|
||||
extract_chunk_tags,
|
||||
extract_chunk_persona,
|
||||
generate_chunk_insight_sections,
|
||||
)
|
||||
|
||||
business_logger.info(f"开始生产RAG画像: end_user_id={end_user_id}, 操作者: {current_user.username}")
|
||||
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||
|
||||
if not end_user:
|
||||
raise ValueError(f"end_user {end_user_id} 不存在")
|
||||
|
||||
rag_content = get_rag_content(end_user_id, page=1, pagesize=limit, db=db, current_user=current_user)
|
||||
chunks = rag_content.get("items", [])
|
||||
|
||||
if not chunks:
|
||||
business_logger.warning(f"未找到chunk内容,无法生产RAG画像: end_user_id={end_user_id}")
|
||||
raise ValueError("暂无chunk内容,无法生成画像")
|
||||
|
||||
summary, tags_with_freq, personas, insight_sections = await asyncio.gather(
|
||||
generate_chunk_summary(chunks, max_chunks=limit, end_user_id=end_user_id),
|
||||
extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit, end_user_id=end_user_id),
|
||||
extract_chunk_persona(chunks, max_personas=5, max_chunks=limit, end_user_id=end_user_id),
|
||||
generate_chunk_insight_sections(chunks, max_chunks=limit, end_user_id=end_user_id),
|
||||
)
|
||||
|
||||
tags = [{"tag": tag, "frequency": freq} for tag, freq in tags_with_freq]
|
||||
|
||||
repo.update_rag_summary_tags(
|
||||
end_user_id=end_user.id,
|
||||
user_summary=summary,
|
||||
rag_tags=json.dumps(tags, ensure_ascii=False),
|
||||
rag_personas=json.dumps(personas, ensure_ascii=False),
|
||||
)
|
||||
|
||||
repo.update_memory_insight(
|
||||
end_user_id=end_user.id,
|
||||
memory_insight=insight_sections.get("memory_insight", ""),
|
||||
behavior_pattern=insight_sections.get("behavior_pattern", ""),
|
||||
key_findings=insight_sections.get("key_findings", ""),
|
||||
growth_trajectory=insight_sections.get("growth_trajectory", ""),
|
||||
)
|
||||
|
||||
business_logger.info(f"RAG画像生产完成: end_user_id={end_user_id}, tags={len(tags)}, personas={len(personas)}")
|
||||
|
||||
return {
|
||||
"end_user_id": end_user_id,
|
||||
"summary_length": len(summary),
|
||||
"tags_count": len(tags),
|
||||
"personas_count": len(personas),
|
||||
"insight_generated": bool(insight_sections.get("memory_insight")),
|
||||
}
|
||||
@@ -8,32 +8,42 @@
|
||||
- Bedrock/Anthropic: 仅支持 base64 格式
|
||||
- OpenAI: 支持 URL 和 base64 格式
|
||||
"""
|
||||
import uuid
|
||||
import httpx
|
||||
import base64
|
||||
from typing import List, Dict, Any, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlalchemy.orm import Session
|
||||
from docx import Document
|
||||
import io
|
||||
import PyPDF2
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
import PyPDF2
|
||||
import httpx
|
||||
import magic
|
||||
from docx import Document
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
TEXT_MIME = ['text/plain', 'text/x-markdown']
|
||||
PDF_MIME = ['application/pdf']
|
||||
DOC_MIME = [
|
||||
'application/msword',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
||||
]
|
||||
|
||||
|
||||
class MultimodalFormatStrategy(ABC):
|
||||
"""多模态格式策略基类"""
|
||||
def __init__(self, file: FileInput):
|
||||
self.file = file
|
||||
|
||||
@abstractmethod
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""格式化图片"""
|
||||
pass
|
||||
|
||||
@@ -43,7 +53,7 @@ class MultimodalFormatStrategy(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]:
|
||||
async def format_audio(self, file_type: str, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""格式化音频"""
|
||||
pass
|
||||
|
||||
@@ -56,7 +66,7 @@ class MultimodalFormatStrategy(ABC):
|
||||
class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
"""通义千问策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""通义千问图片格式:{"type": "image", "image": "url"}"""
|
||||
return {
|
||||
"type": "image",
|
||||
@@ -70,7 +80,13 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def format_audio(
|
||||
self,
|
||||
file_type: str,
|
||||
url: str,
|
||||
content: bytes | None = None,
|
||||
transcription: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
通义千问音频格式
|
||||
- 原生支持: qwen-audio 系列
|
||||
@@ -98,44 +114,37 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
"""Bedrock/Anthropic 策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Bedrock/Anthropic 格式: base64 编码
|
||||
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||
"""
|
||||
from mimetypes import guess_type
|
||||
|
||||
logger.info(f"下载并编码图片: {url}")
|
||||
|
||||
# 下载图片
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# 获取图片数据
|
||||
image_data = response.content
|
||||
if content is None:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
self.file.set_content(content)
|
||||
|
||||
# 确定 media type
|
||||
content_type = response.headers.get("content-type")
|
||||
if content_type and content_type.startswith("image/"):
|
||||
media_type = content_type
|
||||
else:
|
||||
guessed_type, _ = guess_type(url)
|
||||
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||
content_type = magic.from_buffer(content, mime=True)
|
||||
media_type = content_type if content_type.startswith("image/") else "image/jpeg"
|
||||
base64_data = base64.b64encode(content).decode("utf-8")
|
||||
|
||||
# 转换为 base64
|
||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": base64_data
|
||||
}
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": base64_data
|
||||
}
|
||||
}
|
||||
|
||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
||||
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
|
||||
@@ -152,7 +161,12 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
}
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def format_audio(
|
||||
self, file_type: str,
|
||||
url: str,
|
||||
content: bytes | None = None,
|
||||
transcription: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Bedrock/Anthropic 音频格式
|
||||
不支持原生音频,必须转录为文本
|
||||
@@ -178,7 +192,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
"""OpenAI 策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
||||
return {
|
||||
"type": "image_url",
|
||||
@@ -194,7 +208,13 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def format_audio(
|
||||
self,
|
||||
file_type: str,
|
||||
url: str,
|
||||
content: bytes | None = None,
|
||||
transcription: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
OpenAI 音频格式
|
||||
- gpt-4o-audio 系列支持原生音频(需要 base64 编码)
|
||||
@@ -208,31 +228,35 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
|
||||
# OpenAI 音频需要 base64 编码
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
audio_data = response.content
|
||||
base64_audio = base64.b64encode(audio_data).decode('utf-8')
|
||||
# 1. 优先从 file_type (MIME) 取扩展名
|
||||
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
|
||||
# 2. 从响应头 content-type 取
|
||||
if not file_ext:
|
||||
ct = response.headers.get("content-type", "")
|
||||
file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None
|
||||
# 3. 从 URL 路径取扩展名
|
||||
if not file_ext:
|
||||
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
|
||||
# 4. 默认 wav
|
||||
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
|
||||
file_ext = "wav" if not file_ext else file_ext
|
||||
audio_data = content
|
||||
if content is None:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
audio_data = response.content
|
||||
self.file.set_content(audio_data)
|
||||
base64_audio = base64.b64encode(audio_data).decode('utf-8')
|
||||
|
||||
return {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": f"data:;base64,{base64_audio}",
|
||||
"format": file_ext
|
||||
}
|
||||
# 1. 优先从 file_type (MIME) 取扩展名
|
||||
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
|
||||
# 2. 从响应头 content-type 取
|
||||
if not file_ext:
|
||||
content_type = magic.from_buffer(audio_data, mime=True)
|
||||
file_ext = content_type.split('/')[-1].split(';')[0].strip() if '/' in content_type else None
|
||||
# 3. 从 URL 路径取扩展名
|
||||
if not file_ext:
|
||||
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
|
||||
# 4. 默认 wav
|
||||
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
|
||||
file_ext = "wav" if not file_ext else file_ext
|
||||
|
||||
return {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": f"data:;base64,{base64_audio}",
|
||||
"format": file_ext
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"下载音频失败: {e}")
|
||||
return {
|
||||
@@ -262,7 +286,8 @@ PROVIDER_STRATEGIES = {
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
|
||||
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False):
|
||||
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
|
||||
enable_audio_transcription: bool = False, is_omni: bool = False):
|
||||
"""
|
||||
初始化多模态服务
|
||||
|
||||
@@ -305,10 +330,9 @@ class MultimodalService:
|
||||
logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
|
||||
strategy_class = DashScopeFormatStrategy
|
||||
|
||||
strategy = strategy_class()
|
||||
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
strategy = strategy_class(file)
|
||||
try:
|
||||
if file.type == FileType.IMAGE:
|
||||
content = await self._process_image(file, strategy)
|
||||
@@ -355,7 +379,7 @@ class MultimodalService:
|
||||
"""
|
||||
try:
|
||||
url = await self.get_file_url(file)
|
||||
return await strategy.format_image(url)
|
||||
return await strategy.format_image(url, content=file.get_content())
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败: {e}", exc_info=True)
|
||||
return {
|
||||
@@ -415,11 +439,13 @@ class MultimodalService:
|
||||
# 远程文档暂不支持提取
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n[远程文档,暂不支持内容提取]\n</document>"
|
||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||
}
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
text = await self._extract_document_text(file.upload_file_id)
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
|
||||
text = await self._extract_document_text(file)
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file.upload_file_id
|
||||
).first()
|
||||
@@ -454,7 +480,7 @@ class MultimodalService:
|
||||
else:
|
||||
logger.warning(f"Provider {self.provider} 不支持音频转文本")
|
||||
|
||||
return await strategy.format_audio(file.file_type, url, transcription)
|
||||
return await strategy.format_audio(file.file_type, url, file.get_content(), transcription)
|
||||
except Exception as e:
|
||||
logger.error(f"处理音频失败: {e}", exc_info=True)
|
||||
return {
|
||||
@@ -500,8 +526,6 @@ class MultimodalService:
|
||||
return file.url
|
||||
else:
|
||||
file_id = file.upload_file_id
|
||||
print("="*50)
|
||||
print("file_id",file_id)
|
||||
|
||||
# 查询 FileMetadata
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
@@ -519,66 +543,44 @@ class MultimodalService:
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
return f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
|
||||
async def _extract_document_text(self, file: FileInput) -> str:
|
||||
"""
|
||||
提取文档文本内容
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
file: 文件输入
|
||||
|
||||
Returns:
|
||||
str: 提取的文本内容
|
||||
"""
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file_id,
|
||||
FileMetadata.status == "completed"
|
||||
).first()
|
||||
|
||||
if not file_metadata:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
file_ext = file_metadata.file_ext.lower()
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file_url = f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
if file_ext in ['.txt', '.md', '.markdown']:
|
||||
return await self._read_text_file(file_url)
|
||||
elif file_ext == '.pdf':
|
||||
return await self._extract_pdf_text(file_url)
|
||||
elif file_ext in ['.doc', '.docx']:
|
||||
return await self._extract_word_text(file_url)
|
||||
else:
|
||||
return f"[不支持的文档格式: {file_ext}]"
|
||||
|
||||
@staticmethod
|
||||
async def _read_text_file(file_url: str) -> str:
|
||||
"""读取纯文本文件"""
|
||||
try:
|
||||
# 下载文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
file_content = file.get_content()
|
||||
if not file_content:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file.url)
|
||||
response.raise_for_status()
|
||||
file_content = response.content
|
||||
file.set_content(file_content)
|
||||
file_mime_type = magic.from_buffer(file_content, mime=True)
|
||||
if file_mime_type in TEXT_MIME:
|
||||
return file_content.decode("utf-8")
|
||||
elif file_mime_type in PDF_MIME:
|
||||
return await self._extract_pdf_text(file_content)
|
||||
elif file_mime_type in DOC_MIME:
|
||||
return await self._extract_word_text(file_content)
|
||||
else:
|
||||
return f"[Unsupported file type: {file_mime_type}]"
|
||||
except Exception as e:
|
||||
logger.error(f"读取文本文件失败: {e}")
|
||||
return f"[文件读取失败: {str(e)}]"
|
||||
logger.error(f"Failed to load file. - {e}")
|
||||
return "[Failed to load file.]"
|
||||
|
||||
@staticmethod
|
||||
async def _extract_pdf_text(file_url: str) -> str:
|
||||
async def _extract_pdf_text(file_content: bytes) -> str:
|
||||
"""提取 PDF 文本"""
|
||||
try:
|
||||
# 下载 PDF 文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
pdf_data = response.content
|
||||
|
||||
# 使用 BytesIO 读取 PDF
|
||||
text_parts = []
|
||||
pdf_file = io.BytesIO(pdf_data)
|
||||
pdf_file = io.BytesIO(file_content)
|
||||
pdf_reader = PyPDF2.PdfReader(pdf_file)
|
||||
for page in pdf_reader.pages:
|
||||
text_parts.append(page.extract_text())
|
||||
@@ -588,17 +590,11 @@ class MultimodalService:
|
||||
return f"[PDF 提取失败: {str(e)}]"
|
||||
|
||||
@staticmethod
|
||||
async def _extract_word_text(file_url: str) -> str:
|
||||
async def _extract_word_text(file_content: bytes) -> str:
|
||||
"""提取 Word 文档文本"""
|
||||
try:
|
||||
# 下载 Word 文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
word_data = response.content
|
||||
|
||||
# 使用 BytesIO 读取 Word 文档
|
||||
word_file = io.BytesIO(word_data)
|
||||
word_file = io.BytesIO(file_content)
|
||||
doc = Document(word_file)
|
||||
text_parts = [paragraph.text for paragraph in doc.paragraphs]
|
||||
return '\n'.join(text_parts)
|
||||
|
||||
@@ -120,7 +120,8 @@ async def run_pilot_extraction(
|
||||
"pruning_switch": memory_config.pruning_enabled,
|
||||
"pruning_scene": memory_config.pruning_scene,
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
"llm_model_id": str(memory_config.llm_model_id),
|
||||
"scene_id": str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||
"ontology_classes": memory_config.ontology_classes,
|
||||
}
|
||||
config = PruningConfig(**pruning_config_dict)
|
||||
|
||||
|
||||
@@ -93,7 +93,44 @@ class ToolService:
|
||||
if query.first():
|
||||
raise BusinessException(f"工具名称 '{name}' 已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
def create_tool(
|
||||
def _check_mcp_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, config: Dict[str, Any]):
|
||||
"""检查MCP工具是否重复:市场来源按market_id+market_config_id+mcp_service_id判断(名称无关),自建按name+tool_type判断"""
|
||||
from app.models.tool_model import MCPSourceChannel
|
||||
source_channel = config.get("source_channel")
|
||||
is_market_source = (
|
||||
source_channel is not None
|
||||
and source_channel != MCPSourceChannel.SELF_HOSTED
|
||||
)
|
||||
if is_market_source:
|
||||
exists = (
|
||||
self.db.query(ToolConfig)
|
||||
.join(MCPToolConfig, MCPToolConfig.id == ToolConfig.id)
|
||||
.filter(
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.tool_type == tool_type,
|
||||
MCPToolConfig.source_channel == source_channel,
|
||||
MCPToolConfig.market_id == config.get("market_id"),
|
||||
MCPToolConfig.market_config_id == config.get("market_config_id"),
|
||||
MCPToolConfig.mcp_service_id == config.get("mcp_service_id"),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if exists:
|
||||
raise BusinessException(f"该MCP服务已添加", BizCode.DUPLICATE_NAME)
|
||||
else:
|
||||
exists = (
|
||||
self.db.query(ToolConfig)
|
||||
.filter(
|
||||
ToolConfig.name == name,
|
||||
ToolConfig.tool_type == tool_type,
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if exists:
|
||||
raise BusinessException(f"工具 '{name}' 已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
async def create_tool(
|
||||
self,
|
||||
name: str,
|
||||
tool_type: ToolType,
|
||||
@@ -106,7 +143,19 @@ class ToolService:
|
||||
"""创建工具"""
|
||||
if tool_type == ToolType.BUILTIN:
|
||||
raise ValueError("内置工具不允许创建")
|
||||
self._check_name_duplicate(name, tool_type, tenant_id)
|
||||
|
||||
cfg = config or {}
|
||||
if tool_type == ToolType.MCP:
|
||||
self._check_mcp_duplicate(name, tool_type, tenant_id, cfg)
|
||||
# 创建前测试连接
|
||||
test_result = await self._test_mcp_connection_by_config(cfg)
|
||||
if not test_result["success"]:
|
||||
raise BusinessException(f"MCP连接测试失败: {test_result['message']}", BizCode.INVALID_PARAMETER)
|
||||
# 将发现的工具列表写回 config
|
||||
if "available_tools" in test_result:
|
||||
cfg["available_tools"] = test_result["available_tools"]
|
||||
else:
|
||||
self._check_name_duplicate(name, tool_type, tenant_id)
|
||||
|
||||
try:
|
||||
# 创建基础配置
|
||||
@@ -117,19 +166,22 @@ class ToolService:
|
||||
tool_type=tool_type.value,
|
||||
tenant_id=tenant_id,
|
||||
status=ToolStatus.AVAILABLE.value,
|
||||
config_data=config or {},
|
||||
config_data=cfg,
|
||||
tags=tags
|
||||
)
|
||||
self.db.add(tool_config)
|
||||
self.db.flush()
|
||||
|
||||
# 创建类型特定配置
|
||||
self._create_type_config(tool_config, config or {})
|
||||
self._create_type_config(tool_config, cfg)
|
||||
|
||||
self.db.commit()
|
||||
logger.info(f"工具创建成功: {tool_config.id}")
|
||||
return str(tool_config.id)
|
||||
|
||||
except BusinessException:
|
||||
self.db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"创建工具失败: {e}")
|
||||
@@ -1165,6 +1217,27 @@ class ToolService:
|
||||
logger.error(f"加载内置工具配置失败: {e}")
|
||||
return {}
|
||||
|
||||
async def _test_mcp_connection_by_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""根据配置参数直接测试MCP连接(创建前调用,无需已存在的工具记录)"""
|
||||
server_url = config.get("server_url")
|
||||
if not server_url:
|
||||
return {"success": False, "message": "server_url不能为空"}
|
||||
connection_config = config.get("connection_config") or {}
|
||||
try:
|
||||
test_result = await self.mcp_tool_manager.test_tool_connection(server_url, connection_config)
|
||||
if not test_result["success"]:
|
||||
return test_result
|
||||
success_flag, tools, error = await self.mcp_tool_manager.discover_tools(server_url, connection_config)
|
||||
if not success_flag:
|
||||
return {"success": False, "message": f"获取工具列表失败: {error}"}
|
||||
tool_list = [
|
||||
{tool["name"]: {"description": tool.get("description", ""), "inputSchema": tool.get("inputSchema", {})}}
|
||||
for tool in tools if tool.get("name")
|
||||
]
|
||||
return {"success": True, "message": "MCP连接测试成功", "available_tools": tool_list}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"连接测试异常: {str(e)}"}
|
||||
|
||||
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
||||
"""测试MCP连接并自动同步工具列表"""
|
||||
try:
|
||||
|
||||
@@ -458,7 +458,7 @@ class WorkflowService:
|
||||
type=file.type,
|
||||
url=await self.multimodal_service.get_file_url(file),
|
||||
transfer_method=file.transfer_method,
|
||||
file_id=str(file.upload_file_id),
|
||||
file_id=str(file.upload_file_id) if file.upload_file_id else None,
|
||||
origin_file_type=file.file_type,
|
||||
is_file=True
|
||||
).model_dump()
|
||||
|
||||
@@ -2,11 +2,11 @@ import datetime
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from os import getenv
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config.default_ontology_initializer import DefaultOntologyInitializer
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException, PermissionDeniedException
|
||||
@@ -30,17 +30,15 @@ from app.schemas.workspace_schema import (
|
||||
WorkspaceModelsUpdate,
|
||||
WorkspaceUpdate,
|
||||
)
|
||||
from app.config.default_ontology_initializer import DefaultOntologyInitializer
|
||||
|
||||
# 获取业务逻辑专用日志器
|
||||
business_logger = get_business_logger()
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
def switch_workspace(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
):
|
||||
"""切换工作空间"""
|
||||
business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}")
|
||||
@@ -60,31 +58,32 @@ def switch_workspace(
|
||||
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
|
||||
|
||||
def delete_workspace_member(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
member_id: uuid.UUID,
|
||||
user: User,
|
||||
):
|
||||
"""删除工作空间成员"""
|
||||
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
|
||||
if not workspace_member:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND)
|
||||
def delete_workspace_member(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
member_id: uuid.UUID,
|
||||
user: User,
|
||||
):
|
||||
"""删除工作空间成员"""
|
||||
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
|
||||
if not workspace_member:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND)
|
||||
|
||||
if workspace_member.workspace_id != workspace_id:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_NOT_FOUND)
|
||||
if workspace_member.workspace_id != workspace_id:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}",
|
||||
BizCode.WORKSPACE_NOT_FOUND)
|
||||
|
||||
try:
|
||||
workspace_member.is_active = False
|
||||
workspace_member.user.current_workspace_id = None
|
||||
db.commit()
|
||||
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
||||
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
try:
|
||||
workspace_member.is_active = False
|
||||
workspace_member.user.current_workspace_id = None
|
||||
db.commit()
|
||||
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
||||
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
|
||||
|
||||
def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
|
||||
@@ -102,19 +101,19 @@ def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
|
||||
"""
|
||||
business_logger.debug(f"获取用户工作空间列表: {user.username} (ID: {user.id})")
|
||||
workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
|
||||
|
||||
|
||||
# Ensure each neo4j workspace has a default memory config
|
||||
for workspace in workspaces:
|
||||
if workspace.storage_type == 'neo4j':
|
||||
_ensure_default_memory_config(db, workspace)
|
||||
_ensure_default_ontology_scenes(db, workspace)
|
||||
|
||||
|
||||
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
|
||||
return workspaces
|
||||
|
||||
|
||||
def _create_workspace_only(
|
||||
db: Session, workspace: WorkspaceCreate, owner: User
|
||||
db: Session, workspace: WorkspaceCreate, owner: User
|
||||
) -> Workspace:
|
||||
business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}")
|
||||
|
||||
@@ -138,9 +137,14 @@ def create_workspace(
|
||||
f"创建工作空间: {workspace.name}, 创建者: {user.username}, "
|
||||
f"storage_type: {workspace.storage_type}"
|
||||
)
|
||||
llm=workspace.llm
|
||||
embedding=workspace.embedding
|
||||
rerank=workspace.rerank
|
||||
if workspace_repository.get_workspaces_by_name(db=db, name=workspace.name, tenant_id=user.tenant_id):
|
||||
raise BusinessException(
|
||||
message="同名工作空间已存在",
|
||||
code=BizCode.RESOURCE_ALREADY_EXISTS
|
||||
)
|
||||
llm = workspace.llm
|
||||
embedding = workspace.embedding
|
||||
rerank = workspace.rerank
|
||||
try:
|
||||
# Create the workspace without adding any members
|
||||
business_logger.debug(f"创建工作空间: {workspace.name}")
|
||||
@@ -159,26 +163,26 @@ def create_workspace(
|
||||
success, error_msg = initializer.initialize_default_scenes(
|
||||
db_workspace.id, language=language
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
business_logger.info(
|
||||
f"为工作空间 {db_workspace.id} 创建默认本体场景成功 (language={language})"
|
||||
)
|
||||
|
||||
# 获取默认场景ID,优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景
|
||||
|
||||
# 获取默认场景ID,优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景
|
||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||
from app.config.default_ontology_config import (
|
||||
ONLINE_EDUCATION_SCENE,
|
||||
ONLINE_EDUCATION_SCENE,
|
||||
EMOTIONAL_COMPANION_SCENE,
|
||||
get_scene_name
|
||||
)
|
||||
|
||||
|
||||
scene_repo = OntologySceneRepository(db)
|
||||
|
||||
|
||||
# 优先尝试获取教育场景
|
||||
education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language)
|
||||
education_scene = scene_repo.get_by_name(education_scene_name, db_workspace.id)
|
||||
|
||||
|
||||
if education_scene:
|
||||
default_scene_id = education_scene.scene_id
|
||||
default_scene_name = education_scene.scene_name
|
||||
@@ -189,7 +193,7 @@ def create_workspace(
|
||||
# 如果教育场景不存在,尝试获取情感陪伴场景
|
||||
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
|
||||
companion_scene = scene_repo.get_by_name(companion_scene_name, db_workspace.id)
|
||||
|
||||
|
||||
if companion_scene:
|
||||
default_scene_id = companion_scene.scene_id
|
||||
default_scene_name = companion_scene.scene_name
|
||||
@@ -256,10 +260,10 @@ def create_workspace(
|
||||
avatar='',
|
||||
type=KnowledgeType.General,
|
||||
permission_id=PermissionType.Memory,
|
||||
embedding_id=uuid.UUID(getenv('KB_embedding_id')) if None else embedding,
|
||||
reranker_id=uuid.UUID(getenv('KB_reranker_id')) if None else rerank,
|
||||
llm_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
|
||||
image2text_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
|
||||
embedding_id=embedding,
|
||||
reranker_id=rerank,
|
||||
llm_id=llm,
|
||||
image2text_id=llm,
|
||||
parser_config={
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 256,
|
||||
@@ -294,7 +298,7 @@ def create_workspace(
|
||||
business_logger.info(
|
||||
f"工作空间 {db_workspace.id} 及相关资源创建完成并已提交"
|
||||
)
|
||||
|
||||
|
||||
return db_workspace
|
||||
|
||||
except Exception as e:
|
||||
@@ -304,11 +308,11 @@ def create_workspace(
|
||||
|
||||
|
||||
def update_workspace(
|
||||
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
|
||||
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
|
||||
) -> Workspace:
|
||||
business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
|
||||
db_workspace = _check_workspace_admin_permission(db,workspace_id,user)
|
||||
db_workspace = _check_workspace_admin_permission(db, workspace_id, user)
|
||||
try:
|
||||
# 更新工作空间
|
||||
business_logger.debug(f"执行工作空间更新: {db_workspace.name} (ID: {workspace_id})")
|
||||
@@ -328,7 +332,7 @@ def update_workspace(
|
||||
|
||||
|
||||
def get_workspace_members(
|
||||
db: Session, workspace_id: uuid.UUID, user: User
|
||||
db: Session, workspace_id: uuid.UUID, user: User
|
||||
) -> List[WorkspaceMember]:
|
||||
"""获取某工作空间的成员列表(关系序列化由模型关系支持)"""
|
||||
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
@@ -372,7 +376,6 @@ def get_workspace_members(
|
||||
return members
|
||||
|
||||
|
||||
|
||||
# ==================== 邀请相关服务方法 ====================
|
||||
|
||||
def _generate_invite_token() -> tuple[str, str]:
|
||||
@@ -465,13 +468,14 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
|
||||
|
||||
|
||||
def create_workspace_invite(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
user: User
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
user: User
|
||||
) -> WorkspaceInviteResponse:
|
||||
"""创建工作空间邀请"""
|
||||
business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
|
||||
business_logger.info(
|
||||
f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
|
||||
|
||||
try:
|
||||
# 检查权限
|
||||
@@ -534,17 +538,18 @@ def create_workspace_invite(
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
|
||||
business_logger.error(
|
||||
f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_invites(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
status: Optional[InviteStatus] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
status: Optional[InviteStatus] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[WorkspaceInviteResponse]:
|
||||
"""获取工作空间邀请列表"""
|
||||
business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
@@ -605,9 +610,9 @@ def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
|
||||
|
||||
|
||||
def accept_workspace_invite(
|
||||
db: Session,
|
||||
accept_request: InviteAcceptRequest,
|
||||
user: User
|
||||
db: Session,
|
||||
accept_request: InviteAcceptRequest,
|
||||
user: User
|
||||
) -> dict:
|
||||
"""接受工作空间邀请"""
|
||||
business_logger.info(f"接受工作空间邀请: 用户 {user.username}")
|
||||
@@ -695,7 +700,8 @@ def accept_workspace_invite(
|
||||
# 获取工作空间信息
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
||||
|
||||
business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
|
||||
business_logger.info(
|
||||
f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
|
||||
|
||||
return {
|
||||
"message": "Successfully joined the workspace",
|
||||
@@ -710,13 +716,14 @@ def accept_workspace_invite(
|
||||
|
||||
|
||||
def revoke_workspace_invite(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_id: uuid.UUID,
|
||||
user: User
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_id: uuid.UUID,
|
||||
user: User
|
||||
) -> dict:
|
||||
"""撤销工作空间邀请"""
|
||||
business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
|
||||
business_logger.info(
|
||||
f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
|
||||
|
||||
try:
|
||||
# 检查权限
|
||||
@@ -745,13 +752,14 @@ def revoke_workspace_invite(
|
||||
|
||||
|
||||
def update_workspace_member_roles(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
updates: List[WorkspaceMemberUpdate],
|
||||
user: User,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
updates: List[WorkspaceMemberUpdate],
|
||||
user: User,
|
||||
) -> List[WorkspaceMember]:
|
||||
"""更新工作空间成员角色"""
|
||||
business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
|
||||
business_logger.info(
|
||||
f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
|
||||
|
||||
# 检查管理员权限
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
@@ -765,7 +773,8 @@ def update_workspace_member_roles(
|
||||
for upd in updates:
|
||||
# 检查成员是否存在
|
||||
if upd.id not in member_map:
|
||||
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}",
|
||||
BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
|
||||
member = member_map[upd.id]
|
||||
|
||||
@@ -917,10 +926,10 @@ def get_workspace_models_configs(
|
||||
|
||||
|
||||
def update_workspace_models_configs(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
models_update: WorkspaceModelsUpdate,
|
||||
user: User,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
models_update: WorkspaceModelsUpdate,
|
||||
user: User,
|
||||
) -> Workspace:
|
||||
"""更新工作空间的模型配置(llm, embedding, rerank)
|
||||
|
||||
@@ -968,8 +977,8 @@ def update_workspace_models_configs(
|
||||
|
||||
|
||||
def _fill_workspace_configs_model_defaults(
|
||||
db: Session,
|
||||
workspace: Workspace
|
||||
db: Session,
|
||||
workspace: Workspace
|
||||
) -> None:
|
||||
"""Fill empty model fields for all memory configs in a workspace.
|
||||
|
||||
@@ -981,43 +990,43 @@ def _fill_workspace_configs_model_defaults(
|
||||
workspace: The workspace containing default model settings
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
|
||||
# Get all configs for this workspace
|
||||
configs = db.query(MemoryConfig).filter(
|
||||
MemoryConfig.workspace_id == workspace.id
|
||||
).all()
|
||||
|
||||
|
||||
if not configs:
|
||||
return
|
||||
|
||||
|
||||
# Map of memory_config field -> workspace field
|
||||
model_field_mappings = [
|
||||
("llm_id", "llm"),
|
||||
("embedding_id", "embedding"),
|
||||
("rerank_id", "rerank"),
|
||||
("reflection_model_id", "llm"), # reflection uses LLM
|
||||
("emotion_model_id", "llm"), # emotion uses LLM
|
||||
("emotion_model_id", "llm"), # emotion uses LLM
|
||||
]
|
||||
|
||||
|
||||
configs_updated = 0
|
||||
|
||||
|
||||
for memory_config in configs:
|
||||
updated_fields = []
|
||||
|
||||
|
||||
for config_field, workspace_field in model_field_mappings:
|
||||
config_value = getattr(memory_config, config_field, None)
|
||||
workspace_value = getattr(workspace, workspace_field, None)
|
||||
|
||||
|
||||
if not config_value and workspace_value:
|
||||
setattr(memory_config, config_field, workspace_value)
|
||||
updated_fields.append(config_field)
|
||||
|
||||
|
||||
if updated_fields:
|
||||
configs_updated += 1
|
||||
business_logger.debug(
|
||||
f"Updated memory config {memory_config.config_id} fields: {updated_fields}"
|
||||
)
|
||||
|
||||
|
||||
if configs_updated > 0:
|
||||
try:
|
||||
db.commit()
|
||||
@@ -1032,14 +1041,14 @@ def _fill_workspace_configs_model_defaults(
|
||||
|
||||
|
||||
def _create_default_memory_config(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
workspace_name: str,
|
||||
llm_id: Optional[uuid.UUID] = None,
|
||||
embedding_id: Optional[uuid.UUID] = None,
|
||||
rerank_id: Optional[uuid.UUID] = None,
|
||||
scene_id: Optional[uuid.UUID] = None,
|
||||
pruning_scene_name: Optional[str] = None,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
workspace_name: str,
|
||||
llm_id: Optional[uuid.UUID] = None,
|
||||
embedding_id: Optional[uuid.UUID] = None,
|
||||
rerank_id: Optional[uuid.UUID] = None,
|
||||
scene_id: Optional[uuid.UUID] = None,
|
||||
pruning_scene_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Create a default memory config for a newly created workspace.
|
||||
|
||||
@@ -1054,9 +1063,9 @@ def _create_default_memory_config(
|
||||
pruning_scene_name: Optional pruning scene name,取自 ontology_scene.scene_name
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
|
||||
config_id = uuid.uuid4()
|
||||
|
||||
|
||||
default_config = MemoryConfig(
|
||||
config_id=config_id,
|
||||
config_name=f"{workspace_name} 默认配置",
|
||||
@@ -1070,10 +1079,10 @@ def _create_default_memory_config(
|
||||
state=True, # Active by default
|
||||
is_default=True, # Mark as workspace default
|
||||
)
|
||||
|
||||
|
||||
db.add(default_config)
|
||||
db.flush() # 使用 flush 而不是 commit,让调用者统一提交
|
||||
|
||||
|
||||
business_logger.info(
|
||||
"Created default memory config for workspace",
|
||||
extra={
|
||||
@@ -1084,6 +1093,7 @@ def _create_default_memory_config(
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ==================== 检查配置相关服务 ====================
|
||||
|
||||
def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
@@ -1096,19 +1106,19 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
workspace: The workspace to check
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
|
||||
# Check if default config exists for this workspace
|
||||
existing_default = db.query(MemoryConfig).filter(
|
||||
MemoryConfig.workspace_id == workspace.id,
|
||||
MemoryConfig.is_default == True
|
||||
).first()
|
||||
|
||||
|
||||
if not existing_default:
|
||||
# No default config exists, create one
|
||||
business_logger.info(
|
||||
f"Workspace {workspace.id} missing default memory config, creating one"
|
||||
)
|
||||
|
||||
|
||||
# 尝试获取默认场景ID,优先教育场景,其次情感陪伴场景
|
||||
default_scene_id = None
|
||||
try:
|
||||
@@ -1118,7 +1128,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
EMOTIONAL_COMPANION_SCENE,
|
||||
get_scene_name
|
||||
)
|
||||
|
||||
|
||||
scene_repo = OntologySceneRepository(db)
|
||||
# 尝试中文和英文场景名称
|
||||
for language in ["zh", "en"]:
|
||||
@@ -1131,7 +1141,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
f"找到教育场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={education_scene_name}"
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
# 如果教育场景不存在,尝试情感陪伴场景
|
||||
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
|
||||
companion_scene = scene_repo.get_by_name(companion_scene_name, workspace.id)
|
||||
@@ -1145,7 +1155,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
business_logger.warning(
|
||||
f"获取默认场景失败,将创建不关联场景的记忆配置: {str(scene_error)}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
_create_default_memory_config(
|
||||
db=db,
|
||||
@@ -1160,7 +1170,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
business_logger.error(
|
||||
f"Failed to create default memory config for workspace {workspace.id}: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# Fill empty model fields for ALL configs in this workspace
|
||||
_fill_workspace_configs_model_defaults(db, workspace)
|
||||
|
||||
@@ -1209,4 +1219,3 @@ def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
|
||||
business_logger.error(
|
||||
f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
117
api/app/tasks.py
117
api/app/tasks.py
@@ -2228,6 +2228,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from app.repositories.implicit_emotions_storage_repository import (
|
||||
ImplicitEmotionsStorageRepository,
|
||||
TimeFilterUnavailableError,
|
||||
)
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
@@ -2256,7 +2257,14 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
_redis_client = get_sync_redis_client()
|
||||
|
||||
# 只处理 last_done > updated_at 的用户(有新记忆写入的用户)
|
||||
for end_user_id in repo.get_users_needing_refresh(_redis_client, batch_size=100):
|
||||
# Redis 不可用时回退到全量处理
|
||||
try:
|
||||
refresh_iter = repo.get_users_needing_refresh(_redis_client, batch_size=100)
|
||||
except TimeFilterUnavailableError as e:
|
||||
logger.warning(f"时间轴筛选不可用,回退到全量刷新: {e}")
|
||||
refresh_iter = repo.get_all_user_ids(batch_size=100)
|
||||
|
||||
for end_user_id in refresh_iter:
|
||||
logger.info(f"开始处理用户: {end_user_id}")
|
||||
user_start_time = time.time()
|
||||
|
||||
@@ -2605,3 +2613,110 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.init_interest_distribution_for_users",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
)
|
||||
def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
|
||||
"""事件触发任务:检查指定用户列表的兴趣分布缓存,无缓存则生成并写入 Redis。
|
||||
|
||||
由 /dashboard/end_users 接口触发,已有缓存的用户直接跳过。
|
||||
默认生成中文(zh)兴趣分布数据。
|
||||
|
||||
Args:
|
||||
end_user_ids: 需要检查的用户ID列表
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
failed = 0
|
||||
skipped = 0
|
||||
language = "zh"
|
||||
|
||||
service = MemoryAgentService()
|
||||
|
||||
with get_db_context() as db:
|
||||
for end_user_id in end_user_ids:
|
||||
# 存在性检查:缓存有数据则跳过
|
||||
cached = await InterestMemoryCache.get_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
)
|
||||
if cached is not None:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成")
|
||||
try:
|
||||
result = await service.get_interest_distribution_by_user(
|
||||
end_user_id=end_user_id,
|
||||
limit=5,
|
||||
language=language,
|
||||
)
|
||||
await InterestMemoryCache.set_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
data=result,
|
||||
expire=INTEREST_CACHE_EXPIRE,
|
||||
)
|
||||
initialized += 1
|
||||
logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功")
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}")
|
||||
|
||||
logger.info(f"兴趣分布按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"initialized": initialized,
|
||||
"skipped": skipped,
|
||||
"failed": failed,
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
result["task_id"] = self.request.id
|
||||
return result
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
ENABLE_SINGLE_SESSION=
|
||||
|
||||
# File Upload
|
||||
MAX_FILE_SIZE=52428800 # 50MB:10 * 1024 * 1024
|
||||
MAX_FILE_SIZE=52428800 # 50MB:50 * 1024 * 1024
|
||||
FILE_PATH=/files
|
||||
|
||||
FILE_LOCAL_SERVER_URL="http://localhost:8000/api"
|
||||
|
||||
34
api/migrations/versions/fb834419b18f_202603101453.py
Normal file
34
api/migrations/versions/fb834419b18f_202603101453.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""202603101453
|
||||
|
||||
Revision ID: fb834419b18f
|
||||
Revises: 1ac07dc7366f
|
||||
Create Date: 2026-03-10 14:46:48.038643
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'fb834419b18f'
|
||||
down_revision: Union[str, None] = '1ac07dc7366f'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('end_users', sa.Column('rag_tags', sa.Text(), nullable=True, comment='RAG模式下提取的标签列表(JSON格式)'))
|
||||
op.add_column('end_users', sa.Column('rag_personas', sa.Text(), nullable=True, comment='RAG模式下提取的人物形象列表(JSON格式)'))
|
||||
op.add_column('end_users', sa.Column('rag_summary_updated_at', sa.DateTime(), nullable=True, comment='RAG摘要/标签/人物形象最后更新时间'))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('end_users', 'rag_summary_updated_at')
|
||||
op.drop_column('end_users', 'rag_personas')
|
||||
op.drop_column('end_users', 'rag_tags')
|
||||
# ### end Alembic commands ###
|
||||
@@ -145,6 +145,8 @@ dependencies = [
|
||||
"lxml>=4.9.0",
|
||||
"httpx>=0.28.0",
|
||||
"modelscope>=1.34.0",
|
||||
"python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'",
|
||||
"python-magic-bin>=0.4.14; sys_platform=='win32'",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
Submodule redbear-mem-benchmark updated: 8494e82498...c3bbc6931c
@@ -135,4 +135,12 @@ export const getExperienceConfig = (share_token: string) => {
|
||||
'Authorization': `Bearer ${localStorage.getItem(`shareToken_${share_token}`)}`
|
||||
}
|
||||
})
|
||||
}
|
||||
// Export application
|
||||
export const appExport = (app_id: string, appName: string, data?: { release_version: string }) => {
|
||||
return request.getDownloadFile(`/apps/${app_id}/export`, `${appName}.yml`, data)
|
||||
}
|
||||
// Import application
|
||||
export const appImport = (formData: FormData) => {
|
||||
return request.uploadFile(`/apps/import`, formData)
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 14:00:06
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-04 10:58:41
|
||||
* @Last Modified time: 2026-03-12 18:25:06
|
||||
*/
|
||||
import { request } from '@/utils/request'
|
||||
import type {
|
||||
@@ -118,8 +118,9 @@ export const getChunkInsight = (end_user_id: string) => {
|
||||
return request.get(`/dashboard/chunk_insight`, { end_user_id })
|
||||
}
|
||||
// RAG User Memory - Storage content
|
||||
export const getRagContent = (end_user_id: string) => {
|
||||
return request.get(`/dashboard/rag_content`, { end_user_id, limit: 20 })
|
||||
export const getRagContentUrl = '/dashboard/rag_content'
|
||||
export const getRagContent = (end_user_id: string, page = 1, pagesize = 20) => {
|
||||
return request.get(getRagContentUrl, { end_user_id, page, pagesize })
|
||||
}
|
||||
// Emotion distribution analysis
|
||||
export const getWordCloud = (end_user_id: string) => {
|
||||
@@ -224,6 +225,10 @@ export const getConversationDetail = (end_user_id: string, conversation_id: stri
|
||||
export const forgetTrigger = (data: { max_merge_batch_size: number; min_days_since_access: number; end_user_id: string;}) => {
|
||||
return request.post(`/memory/forget-memory/trigger`, data)
|
||||
}
|
||||
// RAG type - Refresh RAG user summary and memory insight
|
||||
export const generateRagProfile = (end_user_id: string) => {
|
||||
return request.post(`/dashboard/generate_rag_profile`, { end_user_id })
|
||||
}
|
||||
/*************** end User Memory APIs ******************************/
|
||||
|
||||
/****************** Memory Management APIs *******************************/
|
||||
|
||||
@@ -41,12 +41,12 @@ export const deleteCompositeModel = (model_id: string) => {
|
||||
return request.delete(`/models/composite/${model_id}`)
|
||||
}
|
||||
// Create API keys for all matching models by provider
|
||||
export const updateProviderApiKeys = (data: KeyConfigModalForm) => {
|
||||
return request.post('/models/provider/apikeys', data)
|
||||
export const updateProviderApiKeys = (data: KeyConfigModalForm, signal?: AbortSignal) => {
|
||||
return request.post('/models/provider/apikeys', data, { signal })
|
||||
}
|
||||
// Create model API key
|
||||
export const addModelApiKey = (model_id: string, data: MultiKeyForm) => {
|
||||
return request.post(`/models/${model_id}/apikeys`, data)
|
||||
export const addModelApiKey = (model_id: string, data: MultiKeyForm, signal?: AbortSignal) => {
|
||||
return request.post(`/models/${model_id}/apikeys`, data, { signal })
|
||||
}
|
||||
// Delete model API key
|
||||
export const deleteModelApiKey = (api_key_id: string) => {
|
||||
@@ -65,10 +65,10 @@ export const addModelPlaza = (model_base_id: string) => {
|
||||
return request.post(`/models/model_plaza/${model_base_id}/add`)
|
||||
}
|
||||
// Create custom model
|
||||
export const addCustomModel = (data: CustomModelForm) => {
|
||||
return request.post('/models', data)
|
||||
export const addCustomModel = (data: CustomModelForm, signal?: AbortSignal) => {
|
||||
return request.post('/models', data, { signal })
|
||||
}
|
||||
// Update custom model
|
||||
export const updateCustomModel = (model_base_id: string, data: CustomModelForm) => {
|
||||
return request.put(`/models/${model_base_id}`, data)
|
||||
export const updateCustomModel = (model_base_id: string, data: CustomModelForm, signal?: AbortSignal) => {
|
||||
return request.put(`/models/${model_base_id}`, data, { signal })
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-02 15:18:19
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 15:44:42
|
||||
* @Last Modified time: 2026-03-12 18:36:19
|
||||
*/
|
||||
/**
|
||||
* PageScrollList Component
|
||||
@@ -60,8 +60,8 @@ interface PageScrollListProps<T, Q = Record<string, unknown>> {
|
||||
|
||||
/** Infinite scroll list component with pagination support */
|
||||
const PageScrollList = forwardRef(<T, Q = Record<string, unknown>>({
|
||||
renderItem,
|
||||
query,
|
||||
renderItem,
|
||||
query,
|
||||
url,
|
||||
column = 4,
|
||||
className = '',
|
||||
@@ -69,68 +69,70 @@ const PageScrollList = forwardRef(<T, Q = Record<string, unknown>>({
|
||||
}: PageScrollListProps<T, Q>, ref: React.Ref<PageScrollListRef>) => {
|
||||
/** Expose refresh method to parent component */
|
||||
useImperativeHandle(ref, () => ({
|
||||
refresh,
|
||||
refresh: () => {
|
||||
pageRef.current = 1;
|
||||
loadingRef.current = false;
|
||||
setHasMore(true);
|
||||
setData([]);
|
||||
loadMoreData(true);
|
||||
},
|
||||
}));
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [data, setData] = useState<T[]>([]);
|
||||
const [page, setPage] = useState(1);
|
||||
const [hasMore, setHasMore] = useState(true);
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
const pageRef = useRef(1);
|
||||
const loadingRef = useRef(false);
|
||||
const hasMoreRef = useRef(true);
|
||||
|
||||
/** Load more data from API with pagination */
|
||||
const loadMoreData = (flag?: boolean) => {
|
||||
if (!flag && (loading || !hasMore)) {
|
||||
return;
|
||||
}
|
||||
const loadMoreData = (reset?: boolean) => {
|
||||
if (loadingRef.current || (!reset && !hasMoreRef.current)) return;
|
||||
loadingRef.current = true;
|
||||
setLoading(true);
|
||||
const currentPage = reset ? 1 : pageRef.current;
|
||||
request.get(url, {
|
||||
page: page,
|
||||
page: currentPage,
|
||||
pagesize: PAGE_SIZE,
|
||||
...(query||{}),
|
||||
...(query || {}),
|
||||
})
|
||||
.then((res) => {
|
||||
const response = res as ApiResponse<T>;
|
||||
const results = Array.isArray(response.items) ? response.items : Array.isArray(response) ? response as T[] : [];
|
||||
// Replace data if flag is true, otherwise append
|
||||
if (flag) {
|
||||
setData(results);
|
||||
} else {
|
||||
setData(data.concat(results));
|
||||
}
|
||||
setPage(response.page.page + 1);
|
||||
pageRef.current = response.page.page + 1;
|
||||
setData(prev => reset ? results : [...prev, ...results]);
|
||||
hasMoreRef.current = response.page?.hasnext;
|
||||
setHasMore(response.page?.hasnext);
|
||||
setLoading(false);
|
||||
console.log(`${results.length} more items loaded!`);
|
||||
})
|
||||
.catch(() => {
|
||||
setLoading(false);
|
||||
hasMoreRef.current = false;
|
||||
setHasMore(false);
|
||||
console.error('Failed to load data');
|
||||
})
|
||||
.finally(() => {
|
||||
loadingRef.current = false;
|
||||
setLoading(false);
|
||||
// 内容不足以填满容器时,主动继续加载
|
||||
setTimeout(() => {
|
||||
const el = scrollRef.current;
|
||||
console.log(el, el?.scrollHeight, el?.clientHeight, hasMoreRef.current)
|
||||
if (el && hasMoreRef.current && el.scrollHeight <= el.clientHeight) {
|
||||
loadMoreData();
|
||||
}
|
||||
}, 0);
|
||||
});
|
||||
};
|
||||
|
||||
/** Reset list to initial state and reload data */
|
||||
const refresh = () => {
|
||||
setPage(1);
|
||||
/** Reset and reload when query parameters change */
|
||||
const queryKey = JSON.stringify(query);
|
||||
useEffect(() => {
|
||||
pageRef.current = 1;
|
||||
loadingRef.current = false;
|
||||
hasMoreRef.current = true;
|
||||
setHasMore(true);
|
||||
setData([]);
|
||||
}
|
||||
loadMoreData(true);
|
||||
}, [queryKey]);
|
||||
|
||||
/** Refresh when query parameters change */
|
||||
useEffect(() => {
|
||||
refresh()
|
||||
}, [query]);
|
||||
|
||||
/** Load initial data when list is reset */
|
||||
useEffect(() => {
|
||||
if (page === 1 && hasMore && data.length === 0) {
|
||||
loadMoreData(true);
|
||||
}
|
||||
}, [page, hasMore, data])
|
||||
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
@@ -140,7 +142,7 @@ const PageScrollList = forwardRef(<T, Q = Record<string, unknown>>({
|
||||
>
|
||||
<InfiniteScroll
|
||||
dataLength={data.length}
|
||||
next={loadMoreData}
|
||||
next={() => loadMoreData()}
|
||||
hasMore={hasMore}
|
||||
loader={loading && needLoading ? <PageLoading /> : false}
|
||||
// endMessage={<Divider plain>It is all, nothing more 🤐</Divider>}
|
||||
|
||||
@@ -1370,7 +1370,7 @@ export const en = {
|
||||
gotoList: 'Return to Application List',
|
||||
gotoDetail: 'View Details',
|
||||
dify: 'Dify',
|
||||
pleaseUploadFile: 'Please upload workflow file',
|
||||
pleaseUploadFile: 'Please upload file',
|
||||
},
|
||||
userMemory: {
|
||||
userMemory: 'User Memory',
|
||||
@@ -1820,6 +1820,10 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
marketRefresh: 'Refresh',
|
||||
marketConfigBtn: 'Configure',
|
||||
marketConfigConnection: 'Configure Connection',
|
||||
marketNoData: 'No Data',
|
||||
marketNoDataDesc: 'This market currently has no available services',
|
||||
marketNoSearchResult: 'No Search Results',
|
||||
marketNoSearchResultDesc: 'No matching services found, please try other keywords',
|
||||
marketNoServices: 'No MCP Services Available',
|
||||
marketNotConnected: 'Not Connected to This Market',
|
||||
marketNoServicesDesc: 'This market currently has no available services',
|
||||
@@ -2040,6 +2044,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
self_optimization: 'Self Optimization',
|
||||
process_evolution: 'Process Evolution',
|
||||
unknown: 'Unknown Node',
|
||||
notes: 'Sticky Note',
|
||||
|
||||
clickToConfigure: 'Click to configure node parameters',
|
||||
nodeProperties: 'Node Properties',
|
||||
@@ -2227,6 +2232,12 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
output_variables: 'Output Variables',
|
||||
refreshTip: 'Sync function signature to code',
|
||||
},
|
||||
notes: {
|
||||
showAuth: 'Show Author',
|
||||
enterLink: 'Enter Link URL',
|
||||
placeholder: 'Enter note...',
|
||||
removeLink: 'Remove Link',
|
||||
},
|
||||
name: 'Key',
|
||||
type: 'Type',
|
||||
value: 'Value',
|
||||
|
||||
@@ -754,7 +754,7 @@ export const zh = {
|
||||
gotoList: '返回应用列表',
|
||||
gotoDetail: '查看详情',
|
||||
dify: 'Dify',
|
||||
pleaseUploadFile: '请上传工作流文件',
|
||||
pleaseUploadFile: '请上传文件',
|
||||
},
|
||||
table: {
|
||||
totalRecords: '共 {{total}} 条记录'
|
||||
@@ -1816,6 +1816,10 @@ export const zh = {
|
||||
marketRefresh: '刷新',
|
||||
marketConfigBtn: '配置',
|
||||
marketConfigConnection: '配置连接',
|
||||
marketNoData: '暂无数据',
|
||||
marketNoDataDesc: '该市场暂时没有可用的服务',
|
||||
marketNoSearchResult: '无搜索结果',
|
||||
marketNoSearchResultDesc: '未找到匹配的服务,请尝试其他关键词',
|
||||
marketNoServices: '暂无可用的 MCP 服务',
|
||||
marketNotConnected: '尚未连接此市场',
|
||||
marketNoServicesDesc: '该市场暂时没有可用的服务',
|
||||
@@ -2036,6 +2040,7 @@ export const zh = {
|
||||
self_optimization: '自我优化',
|
||||
process_evolution: '流程演化',
|
||||
unknown: '未知节点',
|
||||
notes: '便签',
|
||||
|
||||
clickToConfigure: '点击配置节点参数',
|
||||
nodeProperties: '节点属性',
|
||||
@@ -2226,6 +2231,12 @@ export const zh = {
|
||||
unknown: {
|
||||
replaceNodeType: '替换节点'
|
||||
},
|
||||
notes: {
|
||||
showAuth: '显示作者',
|
||||
enterLink: '输入链接 URL',
|
||||
placeholder: '输入注释...',
|
||||
removeLink: '取消链接',
|
||||
},
|
||||
name: '键',
|
||||
type: '类型',
|
||||
value: '值',
|
||||
|
||||
@@ -347,6 +347,23 @@ export const request = {
|
||||
document.body.removeChild(link);
|
||||
callback?.()
|
||||
});
|
||||
},
|
||||
getDownloadFile(url: string, fileName: string, data?: unknown, callback?: () => void) {
|
||||
service.get(url, {
|
||||
params: paramFilter(data as Record<string, string | number | boolean | ObjectWithPush | null | undefined>),
|
||||
responseType: "blob",
|
||||
})
|
||||
.then(res => {
|
||||
const link = document.createElement("a");
|
||||
const blob = new Blob([res as unknown as BlobPart]);
|
||||
link.style.display = "none";
|
||||
link.href = URL.createObjectURL(blob);
|
||||
link.setAttribute("download", decodeURI(fileName || fileName));
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
callback?.()
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import { Button, Space, Input, Form, App } from 'antd';
|
||||
|
||||
import Tag, { type TagProps } from './components/Tag'
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import { getReleaseList, rollbackRelease } from '@/api/application'
|
||||
import { getReleaseList, rollbackRelease, appExport } from '@/api/application'
|
||||
import ReleaseModal from './components/ReleaseModal'
|
||||
import ReleaseShareModal from './components/ReleaseShareModal'
|
||||
import type { Release, ReleaseModalRef, ReleaseShareModalRef } from './types'
|
||||
@@ -67,6 +67,9 @@ const ReleasePage: FC<{data: Application; refresh: () => void}> = ({data, refres
|
||||
message.success(t('common.operateSuccess'))
|
||||
})
|
||||
}
|
||||
const handleExport = () => {
|
||||
appExport(data.id, data.name)
|
||||
}
|
||||
return (
|
||||
<div className="rb:flex rb:h-[calc(100vh-64px)]">
|
||||
<div className="rb:h-full rb:overflow-y-auto rb:w-108 rb:flex-[0_0_auto] rb:border-r rb:border-[#DFE4ED] rb:p-4">
|
||||
@@ -123,7 +126,7 @@ const ReleasePage: FC<{data: Application; refresh: () => void}> = ({data, refres
|
||||
|
||||
<Space size={10}>
|
||||
{selectedVersion && <>
|
||||
{/* <Button>{t('application.exportDSLFile')}</Button> */}
|
||||
{data?.type !== 'multi_agent' && <Button onClick={handleExport}>{t('common.export')}</Button>}
|
||||
{data.current_release_id !== selectedVersion.id && <Button onClick={handleRollback}>{t('application.willRollToThisVersion')}</Button>}
|
||||
<Button type="primary" ghost onClick={() => releaseShareModalRef.current?.handleOpen()}>{t('application.share')}</Button>
|
||||
</>}
|
||||
|
||||
@@ -19,9 +19,8 @@ import deleteIcon from '@/assets/images/delete_hover.svg'
|
||||
import type { Application, ApplicationModalRef } from '@/views/ApplicationManagement/types';
|
||||
import ApplicationModal from '@/views/ApplicationManagement/components/ApplicationModal'
|
||||
import type { CopyModalRef, AgentRef, ClusterRef, WorkflowRef } from '../types'
|
||||
import { deleteApplication } from '@/api/application'
|
||||
import { deleteApplication, appExport } from '@/api/application'
|
||||
import CopyModal from './CopyModal'
|
||||
import { exportToYaml } from '@/utils/yamlExport';
|
||||
|
||||
const { Header } = Layout;
|
||||
|
||||
@@ -85,16 +84,16 @@ const ConfigHeader: FC<ConfigHeaderProps> = ({
|
||||
* Handle menu item click
|
||||
*/
|
||||
const handleClick: MenuProps['onClick'] = ({ key }) => {
|
||||
if (!application) return
|
||||
switch (key) {
|
||||
case 'edit':
|
||||
applicationModalRef.current?.handleOpen(application as Application)
|
||||
applicationModalRef.current?.handleOpen(application)
|
||||
break;
|
||||
case 'copy':
|
||||
copyModalRef.current?.handleOpen()
|
||||
break;
|
||||
case 'export':
|
||||
console.log('export', workflowRef?.current?.config)
|
||||
exportToYaml(workflowRef?.current?.config, application?.name ?`${application?.name}.yml`: undefined)
|
||||
appExport(application.id, application.name)
|
||||
break;
|
||||
case 'delete':
|
||||
handleDelete()
|
||||
@@ -153,7 +152,7 @@ const ConfigHeader: FC<ConfigHeaderProps> = ({
|
||||
* Format dropdown menu items
|
||||
*/
|
||||
const formatMenuItems = useMemo(() => {
|
||||
const items = (application?.type === 'workflow' ? ['edit', 'copy', 'export', 'delete'] : ['edit', 'copy', 'delete']).map(key => ({
|
||||
const items = (application?.type !== 'multi_agent' ? ['edit', 'copy', 'export', 'delete'] : ['edit', 'copy', 'delete']).map(key => ({
|
||||
key,
|
||||
icon: <img src={menuIcons[key]} className="rb:w-4 rb:h-4 rb:mr-2" />,
|
||||
label: t(`common.${key}`),
|
||||
|
||||
256
web/src/views/ApplicationManagement/components/UploadModal.tsx
Normal file
256
web/src/views/ApplicationManagement/components/UploadModal.tsx
Normal file
@@ -0,0 +1,256 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-28 14:08:14
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-12 17:19:46
|
||||
*/
|
||||
/**
|
||||
* UploadModal Component
|
||||
*
|
||||
* This component provides a modal for uploading workflow files with a multi-step process:
|
||||
* 1. Upload - Select platform and file
|
||||
* 2. Complex - Show warnings and errors if any
|
||||
* 3. SureInfo - Confirm and edit workflow information
|
||||
* 4. Completed - Show success message and options
|
||||
*/
|
||||
import { forwardRef, useImperativeHandle, useState, useMemo } from 'react';
|
||||
import { Form, Steps, Flex, Alert, Button, Result, message } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { Application, UploadModalRef } from '../types'
|
||||
import RbModal from '@/components/RbModal'
|
||||
import UploadFiles from '@/components/Upload/UploadFiles'
|
||||
import { appImport } from '@/api/application'
|
||||
|
||||
/**
|
||||
* Props for UploadModal component
|
||||
*/
|
||||
interface UploadModalProps {
|
||||
/** Function to refresh the parent component after workflow import */
|
||||
refresh: () => void;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Steps definition for the upload process
|
||||
*/
|
||||
const steps = [
|
||||
'upload', // Step 1: File upload
|
||||
'complex', // Step 2: Error/warning display
|
||||
'completed' // Step 4: Success message
|
||||
]
|
||||
/**
|
||||
* UploadModal component
|
||||
*
|
||||
* @param {UploadModalProps} props - Component props
|
||||
* @param {React.Ref<UploadModalRef>} ref - Ref for imperative methods
|
||||
*/
|
||||
const UploadModal = forwardRef<UploadModalRef, UploadModalProps>(({
|
||||
refresh
|
||||
}, ref) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
// State management
|
||||
const [visible, setVisible] = useState(false); // Modal visibility
|
||||
const [form] = Form.useForm<{ file: File[] }>(); // Form instance
|
||||
const [loading, setLoading] = useState(false); // Loading state
|
||||
const [current, setCurrent] = useState<number>(0); // Current step
|
||||
const [appId, setAppId] = useState<string | null>(null); // Imported application ID
|
||||
const [warnings, setWarnings] = useState<string[]>([])
|
||||
|
||||
/**
|
||||
* Handle modal close
|
||||
* Resets all states and form fields
|
||||
*/
|
||||
const handleClose = () => {
|
||||
refresh()
|
||||
setVisible(false);
|
||||
form.resetFields();
|
||||
setCurrent(0);
|
||||
setAppId(null);
|
||||
setLoading(false);
|
||||
setWarnings([])
|
||||
};
|
||||
|
||||
/**
|
||||
* Handle modal open
|
||||
* Resets form fields and shows modal
|
||||
*/
|
||||
const handleOpen = () => {
|
||||
form.resetFields();
|
||||
setVisible(true);
|
||||
};
|
||||
|
||||
/**
|
||||
* Handle save/submit action
|
||||
* Processes different logic based on current step
|
||||
*/
|
||||
const handleSave = () => {
|
||||
const values = form.getFieldsValue();
|
||||
|
||||
switch(current) {
|
||||
case 0: // Step 1: Upload file
|
||||
if (!values.file || values.file.length === 0) {
|
||||
message.warning(t('application.pleaseUploadFile'));
|
||||
return;
|
||||
}
|
||||
const formData = new FormData();
|
||||
formData.append('file', values.file[0]);
|
||||
|
||||
setLoading(true)
|
||||
// Call import API
|
||||
appImport(formData)
|
||||
.then(res => {
|
||||
const { warnings, app } = res as { warnings: string[]; app: Application };
|
||||
|
||||
setAppId(app?.id)
|
||||
if (warnings.length) {
|
||||
setCurrent(1)
|
||||
setWarnings(warnings)
|
||||
} else {
|
||||
setCurrent(2)
|
||||
}
|
||||
})
|
||||
.finally(() => setLoading(false));
|
||||
break;
|
||||
case 2:
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
// Expose methods to parent component via ref
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
handleClose
|
||||
}));
|
||||
|
||||
/**
|
||||
* Handle navigation after successful import
|
||||
* @param {string} type - Navigation type ('detail' or 'list')
|
||||
*/
|
||||
const handleJump = (type: string) => {
|
||||
handleClose();
|
||||
refresh();
|
||||
setTimeout(() => {
|
||||
switch (type) {
|
||||
case 'detail':
|
||||
// Open application detail page in new tab
|
||||
window.open(`/#/application/config/${appId}`, '_blank');
|
||||
break;
|
||||
}
|
||||
}, 100)
|
||||
};
|
||||
|
||||
/**
|
||||
* Generate modal footer based on current step
|
||||
*/
|
||||
const getFooter = useMemo(() => {
|
||||
switch (current) {
|
||||
case 0: // Step 1: Upload
|
||||
return [
|
||||
<Button key="back" onClick={handleClose}>
|
||||
{t('common.cancel')}
|
||||
</Button>,
|
||||
<Button
|
||||
key="confirm"
|
||||
type="primary"
|
||||
loading={loading}
|
||||
onClick={handleSave}
|
||||
>
|
||||
{t('common.confirm')}
|
||||
</Button>
|
||||
];
|
||||
case 1:
|
||||
return [
|
||||
<Button key="back" onClick={() => handleJump('list')}>
|
||||
{t('application.gotoList')}
|
||||
</Button>,
|
||||
<Button
|
||||
key="submit"
|
||||
type="primary"
|
||||
loading={loading}
|
||||
onClick={() => handleJump('detail')}
|
||||
>
|
||||
{t('application.gotoDetail')}
|
||||
</Button>
|
||||
]
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}, [current, loading]);
|
||||
return (
|
||||
<RbModal
|
||||
title={t('application.import')}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t('common.confirm')}
|
||||
onOk={handleSave}
|
||||
footer={getFooter}
|
||||
>
|
||||
{/* Steps indicator */}
|
||||
<div className='rb:p-3 rb:bg-[#FBFDFF] rb:rounded-lg rb:border rb:border-[#DFE4ED] rb:mb-3'>
|
||||
<Steps
|
||||
labelPlacement="vertical"
|
||||
size="small"
|
||||
current={current}
|
||||
items={steps.map(key => ({ title: t(`application.${key}`) }))}
|
||||
/>
|
||||
</div>
|
||||
{current === 0 &&
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
>
|
||||
<Form.Item
|
||||
name="file"
|
||||
valuePropName="fileList"
|
||||
noStyle
|
||||
>
|
||||
<UploadFiles
|
||||
isAutoUpload={false}
|
||||
isCanDrag={true}
|
||||
fileSize={100}
|
||||
maxCount={1}
|
||||
fileType={['yml']}
|
||||
/>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
}
|
||||
{/* Step 2: Error/warning display */}
|
||||
{current === 1 &&
|
||||
<Flex vertical gap={12}>
|
||||
{warnings.map((vo, index) => (
|
||||
<Alert
|
||||
key={index}
|
||||
message={<div>{vo}</div>}
|
||||
type="warning"
|
||||
showIcon
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
}
|
||||
{current === 2 &&
|
||||
<Result
|
||||
status="success"
|
||||
title={t('application.importSuccess')}
|
||||
subTitle={t('application.importSuccessDesc')}
|
||||
extra={[
|
||||
<Button key="back" onClick={() => handleJump('list')}>
|
||||
{t('application.gotoList')}
|
||||
</Button>,
|
||||
<Button
|
||||
key="submit"
|
||||
type="primary"
|
||||
loading={loading}
|
||||
onClick={() => handleJump('detail')}
|
||||
>
|
||||
{t('application.gotoDetail')}
|
||||
</Button>
|
||||
]}
|
||||
/>
|
||||
}
|
||||
</RbModal>
|
||||
);
|
||||
});
|
||||
|
||||
export default UploadModal;
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-28 14:08:14
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-06 12:05:46
|
||||
* @Last Modified time: 2026-03-12 17:19:33
|
||||
*/
|
||||
/**
|
||||
* UploadWorkflowModal Component
|
||||
@@ -72,6 +72,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
||||
setFirstFormData(null);
|
||||
setAppId(null);
|
||||
setLoading(false);
|
||||
refresh()
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -25,6 +25,7 @@ import { getApplicationListUrl, deleteApplication } from '@/api/application'
|
||||
import PageScrollList, { type PageScrollListRef } from '@/components/PageScrollList'
|
||||
import { formatDateTime } from '@/utils/format';
|
||||
import UploadWorkflowModal from './components/UploadWorkflowModal'
|
||||
import UploadModal from './components/UploadModal'
|
||||
|
||||
/**
|
||||
* Application management main component
|
||||
@@ -37,6 +38,7 @@ const ApplicationManagement: React.FC = () => {
|
||||
const applicationModalRef = useRef<ApplicationModalRef>(null);
|
||||
const scrollListRef = useRef<PageScrollListRef>(null)
|
||||
const uploadWorkflowModalRef = useRef<UploadWorkflowModalRef>(null);
|
||||
const uploadModalRef = useRef<UploadWorkflowModalRef>(null);
|
||||
|
||||
useEffect(() => {
|
||||
// Convert URLSearchParams to a plain object for easier access
|
||||
@@ -91,6 +93,8 @@ const ApplicationManagement: React.FC = () => {
|
||||
case 'thirdParty':
|
||||
handleImport()
|
||||
break;
|
||||
case 'import':
|
||||
uploadModalRef.current?.handleOpen()
|
||||
}
|
||||
}
|
||||
return (
|
||||
@@ -121,6 +125,7 @@ const ApplicationManagement: React.FC = () => {
|
||||
<Dropdown
|
||||
menu={{ items: [
|
||||
{ key: 'thirdParty', label: t('application.importWorkflow') },
|
||||
{ key: 'import', label: t('application.import') },
|
||||
], onClick: handleClick }}
|
||||
placement="bottomRight"
|
||||
>
|
||||
@@ -186,6 +191,10 @@ const ApplicationManagement: React.FC = () => {
|
||||
ref={uploadWorkflowModalRef}
|
||||
refresh={refresh}
|
||||
/>
|
||||
<UploadModal
|
||||
ref={uploadModalRef}
|
||||
refresh={refresh}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -233,4 +233,12 @@ export interface UploadData extends WorkflowConfig {
|
||||
export interface UploadWorkflowModalRef {
|
||||
/** Open the upload workflow modal */
|
||||
handleOpen: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Upload app modal ref interface
|
||||
*/
|
||||
export interface UploadModalRef {
|
||||
/** Open the upload workflow modal */
|
||||
handleOpen: () => void;
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:49:28
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-04 11:31:43
|
||||
* @Last Modified time: 2026-03-11 15:08:24
|
||||
*/
|
||||
/**
|
||||
* Custom Model Modal
|
||||
@@ -11,7 +11,7 @@
|
||||
*/
|
||||
|
||||
import { forwardRef, useEffect, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, App, Checkbox } from 'antd';
|
||||
import { Form, Input, App, Checkbox, Button } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { CustomModelForm, ModelListItem, CustomModelModalRef, CustomModelModalProps } from '../types';
|
||||
@@ -35,6 +35,7 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
const [isEdit, setIsEdit] = useState(false);
|
||||
const [form] = Form.useForm<CustomModelForm>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [abortController, setAbortController] = useState<AbortController | null>(null)
|
||||
const modelType = Form.useWatch(['type'], form);
|
||||
const isOmni = Form.useWatch(['is_omni'], form);
|
||||
|
||||
@@ -46,6 +47,8 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
|
||||
/** Close modal and reset state */
|
||||
const handleClose = () => {
|
||||
abortController?.abort()
|
||||
setAbortController(null)
|
||||
setModel({} as ModelListItem);
|
||||
form.resetFields();
|
||||
setLoading(false)
|
||||
@@ -73,8 +76,10 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
/** Update or create custom model */
|
||||
const handleUpdate = (data: CustomModelForm) => {
|
||||
setLoading(true)
|
||||
const controller = new AbortController()
|
||||
setAbortController(controller)
|
||||
const { type, provider, ...rest} = data
|
||||
const res = isEdit ? updateCustomModel(model.id, rest) : addCustomModel(data)
|
||||
const res = isEdit ? updateCustomModel(model.id, rest, controller.signal) : addCustomModel(data, controller.signal)
|
||||
|
||||
res.then(() => {
|
||||
refresh?.(isEdit)
|
||||
@@ -124,15 +129,15 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
}));
|
||||
console.log('modelType', modelType)
|
||||
return (
|
||||
<RbModal
|
||||
title={isEdit ? `${model.name} - ${t('modelNew.modelConfiguration')}` : t('modelNew.createCustomModel')}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t(`common.${isEdit ? 'save' : 'create'}`)}
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
footer={[
|
||||
<Button key="cancel" onClick={handleClose}>{t('common.cancel')}</Button>,
|
||||
<Button key="confirm" type="primary" loading={loading} onClick={handleSave}>{t(`common.${isEdit ? 'save' : 'create'}`)}</Button>,
|
||||
]}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:49:40
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 16:49:40
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-11 15:12:17
|
||||
*/
|
||||
/**
|
||||
* Key Configuration Modal
|
||||
@@ -11,7 +11,7 @@
|
||||
*/
|
||||
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, App } from 'antd';
|
||||
import { Form, Input, App, Button } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { KeyConfigModalForm, ProviderModelItem, KeyConfigModalRef, KeyConfigModalProps } from '../types';
|
||||
@@ -30,9 +30,12 @@ const KeyConfigModal = forwardRef<KeyConfigModalRef, KeyConfigModalProps>(({
|
||||
const [model, setModel] = useState<ProviderModelItem>({} as ProviderModelItem);
|
||||
const [form] = Form.useForm<KeyConfigModalForm>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [abortController, setAbortController] = useState<AbortController | null>(null)
|
||||
|
||||
/** Close modal and reset state */
|
||||
const handleClose = () => {
|
||||
abortController?.abort()
|
||||
setAbortController(null)
|
||||
setModel({} as ProviderModelItem);
|
||||
form.resetFields();
|
||||
setLoading(false)
|
||||
@@ -51,10 +54,13 @@ const KeyConfigModal = forwardRef<KeyConfigModalRef, KeyConfigModalProps>(({
|
||||
.then((values) => {
|
||||
setLoading(true)
|
||||
|
||||
const controller = new AbortController()
|
||||
setAbortController(controller)
|
||||
|
||||
updateProviderApiKeys({
|
||||
...values,
|
||||
provider: model.provider
|
||||
}).then((res) => {
|
||||
}, controller.signal).then((res) => {
|
||||
if (refresh) {
|
||||
refresh();
|
||||
}
|
||||
@@ -81,9 +87,10 @@ const KeyConfigModal = forwardRef<KeyConfigModalRef, KeyConfigModalProps>(({
|
||||
title={`${model.provider} - ${t('modelNew.keyConfig')}`}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t(`common.save`)}
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
footer={[
|
||||
<Button key="cancel" onClick={handleClose}>{t('common.cancel')}</Button>,
|
||||
<Button key="confirm" type="primary" loading={loading} onClick={handleSave}>{t(`common.save`)}</Button>,
|
||||
]}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:49:55
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 16:49:55
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-11 15:11:06
|
||||
*/
|
||||
/**
|
||||
* Multi-Key Configuration Modal
|
||||
@@ -28,9 +28,12 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
|
||||
const [model, setModel] = useState<ModelListItem>({} as ModelListItem);
|
||||
const [form] = Form.useForm<MultiKeyForm>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [abortController, setAbortController] = useState<AbortController | null>(null)
|
||||
|
||||
/** Close modal and refresh parent */
|
||||
const handleClose = () => {
|
||||
abortController?.abort()
|
||||
setAbortController(null)
|
||||
setModel({} as ModelListItem);
|
||||
refresh?.()
|
||||
|
||||
@@ -60,12 +63,14 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
|
||||
.validateFields()
|
||||
.then((values) => {
|
||||
setLoading(true)
|
||||
const controller = new AbortController()
|
||||
setAbortController(controller)
|
||||
addModelApiKey(model.id, {
|
||||
...values,
|
||||
model_config_id: model.id,
|
||||
model_name: model.name,
|
||||
provider: model.provider,
|
||||
}).then(() => {
|
||||
}, controller.signal).then(() => {
|
||||
message.success(t('common.saveSuccess'))
|
||||
form.resetFields();
|
||||
getData(model)
|
||||
@@ -98,7 +103,6 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
footer={null}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
{model.api_keys && model.api_keys.length > 0 && (
|
||||
<div className="rb:mb-4">
|
||||
|
||||
@@ -395,6 +395,7 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
|
||||
{t('tool.marketRefresh')}
|
||||
</Button>
|
||||
)}
|
||||
|
||||
<Input
|
||||
prefix={<SearchOutlined />}
|
||||
placeholder={t('tool.marketSearchPlaceholder')}
|
||||
@@ -402,7 +403,9 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
|
||||
onChange={(e) => handleSearchChange(e.target.value)}
|
||||
allowClear
|
||||
style={{ width: 200 }}
|
||||
|
||||
/>
|
||||
|
||||
</div>
|
||||
<Button icon={<SettingOutlined />} onClick={() => handleOpenConfig(selectedSource)}>
|
||||
{t('tool.marketConfigBtn')}
|
||||
@@ -559,9 +562,9 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
|
||||
<span className="rb:flex-1 rb:font-medium rb:text-[12px] rb:overflow-hidden rb:text-ellipsis rb:whitespace-nowrap">
|
||||
{source.name}
|
||||
</span>
|
||||
<span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full rb:flex-shrink-0">
|
||||
{/* <span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full rb:flex-shrink-0">
|
||||
{source.mcp_count}
|
||||
</span>
|
||||
</span> */}
|
||||
{source.connected && (
|
||||
<span className="rb:text-green-500 rb:text-[8px] rb:flex-shrink-0">●</span>
|
||||
)}
|
||||
|
||||
@@ -132,12 +132,14 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
|
||||
const request = editVo?.id ? updateTool(editVo.id, newService) : addTool(newService)
|
||||
request.then((res: any) => {
|
||||
message.success(t('common.saveSuccess'));
|
||||
testConnection(res.tool_id || editVo?.id)
|
||||
.finally(() => {
|
||||
setLoading(false);
|
||||
handleClose();
|
||||
refresh()
|
||||
})
|
||||
setLoading(false);
|
||||
handleClose();
|
||||
refresh();
|
||||
|
||||
// 在后台测试连接,不阻塞用户操作
|
||||
testConnection(res.tool_id || editVo?.id).catch((err) => {
|
||||
console.error('测试连接失败:', err);
|
||||
});
|
||||
})
|
||||
.catch(() => {
|
||||
setLoading(false);
|
||||
|
||||
@@ -146,4 +146,5 @@ export interface MarketQuery {
|
||||
mcp_market_config_id?: string;
|
||||
page?: number;
|
||||
pagesize?: number;
|
||||
keywords?: string;
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 17:57:11
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 17:57:11
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-12 18:00:11
|
||||
*/
|
||||
/**
|
||||
* RAG User Memory Detail View
|
||||
@@ -13,7 +13,8 @@
|
||||
import { type FC, useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import clsx from 'clsx'
|
||||
import { Row, Col, Skeleton } from 'antd'
|
||||
import { Row, Col, Skeleton, Spin, Flex, Tooltip } from 'antd'
|
||||
import { LoadingOutlined } from '@ant-design/icons';
|
||||
import { useParams } from 'react-router-dom'
|
||||
|
||||
import aboutUs from '@/assets/images/userMemory/aboutUs.svg'
|
||||
@@ -26,6 +27,7 @@ import {
|
||||
getUserProfile,
|
||||
getTotalRagMemoryCountByUser,
|
||||
getChunkInsight,
|
||||
generateRagProfile
|
||||
} from '@/api/memory'
|
||||
import Empty from '@/components/Empty'
|
||||
import ConversationMemory from './components/ConversationMemory'
|
||||
@@ -133,16 +135,46 @@ const Rag: FC = () => {
|
||||
})
|
||||
}
|
||||
const name = loading.detail ? '' : data?.name && data?.name !== '' ? data.name : id
|
||||
|
||||
const [refreshLoading, setRefreshLoading] = useState(false)
|
||||
const handleRefresh = () => {
|
||||
if (refreshLoading || !id) return
|
||||
setRefreshLoading(true)
|
||||
generateRagProfile(id as string)
|
||||
.then(() => {
|
||||
getSummary()
|
||||
getInsightReport()
|
||||
})
|
||||
.finally(() => {
|
||||
setRefreshLoading(false)
|
||||
})
|
||||
}
|
||||
return (
|
||||
<Row gutter={[16, 16]} className="rb:pb-6">
|
||||
<Row gutter={[16, 16]} className="rb:h-full!">
|
||||
<Col span={8}>
|
||||
<RbCard>
|
||||
<RbCard
|
||||
className="rb:h-[calc(100vh-104px)]!"
|
||||
bodyClassName="rb:overflow-y-auto! rb:h-full!"
|
||||
>
|
||||
<div className="rb:flex rb:items-center">
|
||||
<div className="rb:flex-[0_0_auto] rb:w-20 rb:h-20 rb:text-center rb:font-semibold rb:text-[28px] rb:leading-20 rb:rounded-lg rb:text-[#FBFDFF] rb:bg-[#155EEF]">{name?.[0]}</div>
|
||||
<div className="rb:text-[24px] rb:font-semibold rb:leading-8 rb:ml-4">
|
||||
{name}<br/>
|
||||
<div className="rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 rb:mt-2">{personas?.join(' | ')}</div>
|
||||
</div>
|
||||
<Flex>
|
||||
<div className="rb:text-[24px] rb:font-semibold rb:leading-8 rb:ml-4 rb:flex-1">
|
||||
{name}<br />
|
||||
<div className="rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 rb:mt-2">{personas?.join(' | ')}</div>
|
||||
</div>
|
||||
<Tooltip title={t('common.refresh')}>
|
||||
{refreshLoading
|
||||
? <Spin indicator={<LoadingOutlined spin />} />
|
||||
: (
|
||||
<div
|
||||
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/refresh.svg')] rb:hover:bg-[url('@/assets/images/refresh_hover.svg')]"
|
||||
onClick={handleRefresh}
|
||||
></div>
|
||||
)
|
||||
}
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
</div>
|
||||
|
||||
<div className="rb:flex rb:gap-2 rb:mb-2 rb:flex-wrap rb:mt-6.25">
|
||||
|
||||
@@ -1,74 +1,43 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 18:34:04
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 18:34:04
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-12 18:34:52
|
||||
*/
|
||||
/**
|
||||
* Conversation Memory Component
|
||||
* Displays RAG conversation memory content list
|
||||
*/
|
||||
|
||||
import { type FC, useEffect, useState } from 'react'
|
||||
import { type FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useParams } from 'react-router-dom'
|
||||
import { Skeleton, List } from 'antd';
|
||||
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import Empty from '@/components/Empty';
|
||||
import PageScrollList from '@/components/PageScrollList'
|
||||
import Markdown from '@/components/Markdown'
|
||||
import {
|
||||
getRagContent
|
||||
} from '@/api/memory'
|
||||
import { getRagContentUrl } from '@/api/memory'
|
||||
|
||||
const ConversationMemory:FC = () => {
|
||||
const ConversationMemory: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
const { id } = useParams()
|
||||
const [loading, setLoading] = useState<boolean>(true)
|
||||
const [list, setList] = useState<string[]>([])
|
||||
|
||||
useEffect(() => {
|
||||
if (!id) return
|
||||
getList()
|
||||
}, [id])
|
||||
/** Fetch conversation memory list */
|
||||
const getList = () => {
|
||||
if (!id) return
|
||||
setLoading(true)
|
||||
getRagContent(id).then((res) => {
|
||||
setList((res as { contents?: [] }).contents || [])
|
||||
})
|
||||
.finally(() => {
|
||||
setLoading(false)
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<RbCard
|
||||
<RbCard
|
||||
title={t('userMemory.conversationMemory')}
|
||||
headerClassName="rb:text-[18px]! rb:leading-[24px]"
|
||||
bodyClassName="rb:h-[100%]! rb:overflow-hidden rb:py-0!"
|
||||
bodyClassName="rb:h-[calc(100%-56px)]! rb:overflow-hidden"
|
||||
className="rb:h-[calc(100vh-104px)]!"
|
||||
>
|
||||
{loading
|
||||
? <Skeleton />
|
||||
: list.length > 0
|
||||
? <List
|
||||
dataSource={list}
|
||||
grid={{ gutter: 12, column: 1 }}
|
||||
renderItem={(item, index) => (
|
||||
<List.Item>
|
||||
<div
|
||||
key={index}
|
||||
className="rb:rounded-lg rb:border rb:border-[#DFE4ED] rb:px-4 rb:py-3 rb:bg-[#F0F3F8] rb:mt-2 rb:text-gray-800 rb:text-sm"
|
||||
>
|
||||
<Markdown content={item} />
|
||||
</div>
|
||||
</List.Item>
|
||||
)}
|
||||
/>
|
||||
: <Empty className="rb:h-full" />
|
||||
}
|
||||
<PageScrollList<string>
|
||||
url={getRagContentUrl}
|
||||
query={{ end_user_id: id }}
|
||||
column={1}
|
||||
renderItem={(item) => (
|
||||
<div className="rb:rounded-lg rb:border rb:border-[#DFE4ED] rb:px-4 rb:py-3 rb:bg-[#F0F3F8] rb:text-gray-800 rb:text-sm">
|
||||
<Markdown content={item} />
|
||||
</div>
|
||||
)}
|
||||
className="rb:h-full!"
|
||||
// className="rb:h-[calc(100%-24px)]!"
|
||||
/>
|
||||
</RbCard>
|
||||
)
|
||||
}
|
||||
export default ConversationMemory
|
||||
|
||||
export default ConversationMemory
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import type { FC } from 'react';
|
||||
import { Select } from 'antd';
|
||||
import { Select, Divider } from 'antd';
|
||||
// import { Node } from '@antv/x6';
|
||||
import type { GraphRef } from '../types'
|
||||
import { PlusOutlined, MinusOutlined } from '@ant-design/icons'
|
||||
import { PlusOutlined, MinusOutlined, FileAddOutlined } from '@ant-design/icons'
|
||||
|
||||
interface CanvasToolbarProps {
|
||||
miniMapRef: React.RefObject<HTMLDivElement>;
|
||||
@@ -14,6 +14,7 @@ interface CanvasToolbarProps {
|
||||
canRedo: boolean;
|
||||
onUndo: () => void;
|
||||
onRedo: () => void;
|
||||
addNotes: () => void;
|
||||
}
|
||||
|
||||
const CanvasToolbar: FC<CanvasToolbarProps> = ({
|
||||
@@ -26,6 +27,7 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
|
||||
// canRedo,
|
||||
// onUndo,
|
||||
// onRedo,
|
||||
addNotes,
|
||||
}) => {
|
||||
// 整理布局函数
|
||||
/*
|
||||
@@ -152,7 +154,7 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
|
||||
{/* 小地图 */}
|
||||
<div ref={miniMapRef} className="rb:absolute rb:bottom-15 rb:right-8 rb:z-1000 rb:rounded-lg rb:overflow-hidden"></div>
|
||||
{/* 缩放控制按钮 */}
|
||||
<div className="rb:h-8.5 rb:bg-[#FFFFFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:shadow-[0px_2px_6px_0px_rgba(33,35,50,0.15)] rb:px-3 rb:py-2 rb:absolute rb:bottom-5 rb:right-8 rb:flex rb:flex-row rb:gap-4 rb:z-1000">
|
||||
<div className="rb:h-8.5 rb:bg-[#FFFFFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:shadow-[0px_2px_6px_0px_rgba(33,35,50,0.15)] rb:px-3 rb:py-2 rb:absolute rb:bottom-5 rb:right-8 rb:flex rb:flex-row rb:items-center rb:gap-4 rb:z-1000">
|
||||
<MinusOutlined className="rb:text-[16px] rb:cursor-pointer" onClick={() => graphRef.current?.zoom(-0.1)} />
|
||||
<Select
|
||||
value={Math.round(zoomLevel * 100)}
|
||||
@@ -182,6 +184,8 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
|
||||
size="small"
|
||||
/>
|
||||
<PlusOutlined className="rb:text-[16px] rb:cursor-pointer" onClick={() => graphRef.current?.zoom(0.1)} />
|
||||
<Divider type="vertical" className="rb:h-4" />
|
||||
<FileAddOutlined onClick={addNotes} />
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
import { useEffect, useRef } from 'react';
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||
import { FORMAT_TEXT_COMMAND, $getSelection, $isRangeSelection, $setSelection, $isTextNode, type BaseSelection } from 'lexical';
|
||||
import { $patchStyleText } from '@lexical/selection';
|
||||
import { INSERT_UNORDERED_LIST_COMMAND, REMOVE_LIST_COMMAND, ListNode } from '@lexical/list';
|
||||
import { TOGGLE_LINK_COMMAND, LinkNode } from '@lexical/link';
|
||||
import { $getNearestNodeOfType } from '@lexical/utils';
|
||||
|
||||
export const NOTE_FORMAT_EVENT = 'note:format';
|
||||
|
||||
export interface FormatState {
|
||||
bold: boolean;
|
||||
italic: boolean;
|
||||
strikethrough: boolean;
|
||||
list: boolean;
|
||||
fontSize?: number;
|
||||
linkUrl?: string | null;
|
||||
}
|
||||
|
||||
const NoteFormatPlugin = ({ nodeId, onFormatChange, fontSize = 12 }: { nodeId: string; fontSize?: number; onFormatChange?: (state: FormatState) => void }) => {
|
||||
const [editor] = useLexicalComposerContext();
|
||||
const savedSelection = useRef<BaseSelection | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
return editor.registerUpdateListener(({ editorState }) => {
|
||||
editorState.read(() => {
|
||||
const selection = $getSelection();
|
||||
if (!$isRangeSelection(selection)) return;
|
||||
savedSelection.current = selection.clone();
|
||||
const anchorNode = selection.anchor.getNode();
|
||||
const style = 'getStyle' in anchorNode ? (anchorNode as { getStyle(): string }).getStyle() : '';
|
||||
const match = style.match(/font-size:\s*([\d.]+)px/);
|
||||
const nodeFontSize = match ? Number(match[1]) : fontSize;
|
||||
const linkNode = $getNearestNodeOfType(anchorNode, LinkNode);
|
||||
onFormatChange?.({
|
||||
bold: selection.hasFormat('bold'),
|
||||
italic: selection.hasFormat('italic'),
|
||||
strikethrough: selection.hasFormat('strikethrough'),
|
||||
list: !!$getNearestNodeOfType(anchorNode, ListNode),
|
||||
...(nodeFontSize ? { fontSize: nodeFontSize } : {}),
|
||||
linkUrl: linkNode ? linkNode.getURL() : null,
|
||||
});
|
||||
});
|
||||
});
|
||||
}, [editor, onFormatChange]);
|
||||
|
||||
useEffect(() => {
|
||||
const handler = (e: Event) => {
|
||||
const { id, format, value } = (e as CustomEvent).detail;
|
||||
if (id !== nodeId) return;
|
||||
const sel = savedSelection.current;
|
||||
const hasSelection = $isRangeSelection(sel) && !sel.isCollapsed();
|
||||
if (format === 'link' && value === null) {
|
||||
// remove link: select the entire LinkNode first
|
||||
editor.focus(() => {
|
||||
editor.update(() => {
|
||||
const s = $getSelection();
|
||||
const anchorNode = $isRangeSelection(s)
|
||||
? s.anchor.getNode()
|
||||
: savedSelection.current && $isRangeSelection(savedSelection.current)
|
||||
? savedSelection.current.anchor.getNode()
|
||||
: null;
|
||||
const linkNode = anchorNode ? $getNearestNodeOfType(anchorNode, LinkNode) : null;
|
||||
if (linkNode) {
|
||||
const children = linkNode.getChildren();
|
||||
if (children.length > 0) {
|
||||
const first = children[0];
|
||||
const last = children[children.length - 1];
|
||||
if ($isTextNode(first) && $isTextNode(last)) {
|
||||
const range = first.select(0, 0);
|
||||
range.focus.set(last.getKey(), last.getTextContentSize(), 'text');
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
editor.dispatchCommand(TOGGLE_LINK_COMMAND, null);
|
||||
});
|
||||
} else if (format === 'list') {
|
||||
editor.focus(() => {
|
||||
if (sel) editor.update(() => $setSelection(sel));
|
||||
editor.dispatchCommand(value ? INSERT_UNORDERED_LIST_COMMAND : REMOVE_LIST_COMMAND, undefined);
|
||||
editor.update(() => $setSelection(null));
|
||||
});
|
||||
} else if (hasSelection) {
|
||||
editor.focus(() => {
|
||||
editor.update(() => $setSelection(sel));
|
||||
if (format === 'bold' || format === 'italic' || format === 'strikethrough') {
|
||||
editor.dispatchCommand(FORMAT_TEXT_COMMAND, format);
|
||||
} else if (format === 'link') {
|
||||
editor.dispatchCommand(TOGGLE_LINK_COMMAND, value as string | null);
|
||||
} else if (format === 'fontSize') {
|
||||
editor.update(() => {
|
||||
$setSelection(sel);
|
||||
$patchStyleText(sel!, { 'font-size': `${value}px` });
|
||||
});
|
||||
}
|
||||
editor.update(() => $setSelection(null));
|
||||
});
|
||||
}
|
||||
};
|
||||
window.addEventListener(NOTE_FORMAT_EVENT, handler);
|
||||
return () => window.removeEventListener(NOTE_FORMAT_EVENT, handler);
|
||||
}, [editor, nodeId]);
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export default NoteFormatPlugin;
|
||||
@@ -0,0 +1,74 @@
|
||||
import { type FC, useState } from 'react';
|
||||
import { createPortal } from 'react-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Flex, Button, Input } from 'antd';
|
||||
import { EditOutlined, DisconnectOutlined } from '@ant-design/icons';
|
||||
|
||||
const POPOVER_STYLE: React.CSSProperties = {
|
||||
position: 'fixed',
|
||||
zIndex: 1000,
|
||||
background: '#fff',
|
||||
border: '1px solid #e5e7eb',
|
||||
borderRadius: 8,
|
||||
boxShadow: '0 2px 8px rgba(0,0,0,0.12)',
|
||||
whiteSpace: 'nowrap',
|
||||
};
|
||||
|
||||
interface LinkPopoverProps {
|
||||
url: string;
|
||||
rect: DOMRect;
|
||||
onEdit: () => void;
|
||||
onRemove: () => void;
|
||||
}
|
||||
|
||||
export const LinkPopover: FC<LinkPopoverProps> = ({ url, rect, onEdit, onRemove }) => {
|
||||
const { t } = useTranslation();
|
||||
return createPortal(
|
||||
<div
|
||||
style={{ ...POPOVER_STYLE, left: rect.left, top: rect.bottom + 4, padding: '4px 10px', fontSize: 12 }}
|
||||
onMouseDown={e => e.stopPropagation()}
|
||||
>
|
||||
<Flex align="center" gap={8}>
|
||||
<a href={url} target="_blank" rel="noreferrer" style={{ color: '#2563eb', maxWidth: 160, overflow: 'hidden', textOverflow: 'ellipsis', display: 'inline-block' }}>
|
||||
{url}
|
||||
</a>
|
||||
<Button size="small" type="text" icon={<EditOutlined />} onClick={onEdit}>{t('common.edit')}</Button>
|
||||
<Button size="small" type="text" icon={<DisconnectOutlined />} onClick={onRemove}>{t('workflow.config.notes.removeLink')}</Button>
|
||||
</Flex>
|
||||
</div>,
|
||||
document.body
|
||||
);
|
||||
};
|
||||
|
||||
interface EditLinkPopoverProps {
|
||||
rect: DOMRect;
|
||||
initialUrl: string;
|
||||
onConfirm: (url: string) => void;
|
||||
}
|
||||
|
||||
export const EditLinkPopover: FC<EditLinkPopoverProps> = ({ rect, initialUrl, onConfirm }) => {
|
||||
const { t } = useTranslation();
|
||||
const [url, setUrl] = useState(initialUrl);
|
||||
const confirm = () => onConfirm(url);
|
||||
return createPortal(
|
||||
<div
|
||||
style={{ ...POPOVER_STYLE, left: rect.left, top: rect.bottom + 4, padding: '8px' }}
|
||||
onMouseDown={e => e.stopPropagation()}
|
||||
>
|
||||
<Flex gap={8}>
|
||||
<Input
|
||||
size="small"
|
||||
className="rb:w-60!"
|
||||
placeholder={t('workflow.config.notes.enterLink')}
|
||||
value={url}
|
||||
onChange={e => setUrl(e.target.value)}
|
||||
onKeyDown={e => e.stopPropagation()}
|
||||
onPressEnter={confirm}
|
||||
autoFocus
|
||||
/>
|
||||
<Button size="small" type="primary" onClick={confirm}>{t('common.confirm')}</Button>
|
||||
</Flex>
|
||||
</div>,
|
||||
document.body
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,184 @@
|
||||
import { type FC, useState, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { LexicalComposer } from '@lexical/react/LexicalComposer';
|
||||
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin';
|
||||
import { ContentEditable } from '@lexical/react/LexicalContentEditable';
|
||||
import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin';
|
||||
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary';
|
||||
import { ListPlugin } from '@lexical/react/LexicalListPlugin';
|
||||
import { LinkPlugin } from '@lexical/react/LexicalLinkPlugin';
|
||||
import { ListNode, ListItemNode } from '@lexical/list';
|
||||
import { LinkNode } from '@lexical/link';
|
||||
import { OnChangePlugin } from '@lexical/react/LexicalOnChangePlugin';
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||
import { useEffect, useRef } from 'react';
|
||||
import NoteFormatPlugin from './NoteFormatPlugin';
|
||||
import type { FormatState } from './NoteFormatPlugin';
|
||||
import { LinkPopover, EditLinkPopover } from './NoteLinkPopovers';
|
||||
|
||||
const theme = {
|
||||
paragraph: 'editor-paragraph',
|
||||
text: {
|
||||
bold: 'editor-text-bold',
|
||||
italic: 'editor-text-italic',
|
||||
strikethrough: 'note-text-strikethrough',
|
||||
},
|
||||
list: { ul: 'note-list-ul', listitem: 'note-list-item' },
|
||||
link: 'note-link',
|
||||
};
|
||||
|
||||
const NOTE_NODES = [ListNode, ListItemNode, LinkNode];
|
||||
|
||||
const NOTE_STYLES = `
|
||||
.editor-text-bold { font-weight: bold; }
|
||||
.editor-text-italic { font-style: italic; }
|
||||
.note-text-strikethrough { text-decoration: line-through; }
|
||||
.note-list-ul { list-style-type: disc; padding-left: 1.2em; margin: 0; }
|
||||
.note-list-item { margin: 2px 0; }
|
||||
.note-link { color: #2563eb; text-decoration: underline; cursor: pointer; }
|
||||
`;
|
||||
|
||||
const NoteInitPlugin: FC<{ value: string }> = ({ value }) => {
|
||||
const [editor] = useLexicalComposerContext();
|
||||
const initialized = useRef(false);
|
||||
useEffect(() => {
|
||||
if (initialized.current || !value) return;
|
||||
initialized.current = true;
|
||||
try {
|
||||
const parsed = JSON.parse(value);
|
||||
if (parsed?.root) {
|
||||
const state = editor.parseEditorState(JSON.stringify(parsed));
|
||||
editor.setEditorState(state);
|
||||
return;
|
||||
}
|
||||
} catch {}
|
||||
}, [editor, value]);
|
||||
return null;
|
||||
};
|
||||
|
||||
|
||||
interface NoteEditorProps {
|
||||
nodeId: string;
|
||||
value: string;
|
||||
fontSize?: number;
|
||||
onChange: (val: string) => void;
|
||||
onFormatChange?: (state: FormatState) => void;
|
||||
}
|
||||
|
||||
const NoteEditor: FC<NoteEditorProps> = ({ nodeId, value, fontSize = 12, onChange, onFormatChange }) => {
|
||||
const { t } = useTranslation();
|
||||
const [linkState, setLinkState] = useState<{ url: string; rect: DOMRect } | null>(null);
|
||||
const [editLinkRect, setEditLinkRect] = useState<{ url: string; rect: DOMRect } | null>(null);
|
||||
const removingLink = useRef(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!linkState) return;
|
||||
const handler = () => setLinkState(null);
|
||||
window.addEventListener('mousedown', handler);
|
||||
return () => window.removeEventListener('mousedown', handler);
|
||||
}, [!!linkState]);
|
||||
|
||||
useEffect(() => {
|
||||
const handler = (e: Event) => {
|
||||
const { id, url, rect: passedRect } = (e as CustomEvent).detail;
|
||||
if (id !== nodeId) return;
|
||||
if (passedRect) {
|
||||
setEditLinkRect({ url: url || '', rect: passedRect });
|
||||
return;
|
||||
}
|
||||
const sel = window.getSelection();
|
||||
if (sel && sel.rangeCount > 0) {
|
||||
const r = sel.getRangeAt(0).getBoundingClientRect();
|
||||
if (r.width > 0 || r.height > 0) { setEditLinkRect({ url: url || '', rect: r }); return; }
|
||||
}
|
||||
const linkEl = document.querySelector(`[data-note-id="${nodeId}"] a.note-link`) as HTMLElement;
|
||||
const rect = linkEl?.getBoundingClientRect() ?? new DOMRect(window.innerWidth / 2, 200, 0, 0);
|
||||
setEditLinkRect({ url: url || '', rect });
|
||||
};
|
||||
window.addEventListener('note:edit-link', handler);
|
||||
return () => window.removeEventListener('note:edit-link', handler);
|
||||
}, [nodeId]);
|
||||
|
||||
const handleFormatChange = useCallback((state: FormatState) => {
|
||||
onFormatChange?.(state);
|
||||
if (state.linkUrl) {
|
||||
requestAnimationFrame(() => {
|
||||
if (removingLink.current) { removingLink.current = false; return; }
|
||||
const sel = window.getSelection();
|
||||
if (sel && sel.rangeCount > 0) {
|
||||
const rect = sel.getRangeAt(0).getBoundingClientRect();
|
||||
if (rect.width > 0 || rect.height > 0) {
|
||||
setLinkState({ url: state.linkUrl!, rect });
|
||||
return;
|
||||
}
|
||||
}
|
||||
// fallback: find the link element in the correct editor
|
||||
const editorEl = document.querySelector(`[data-note-id="${nodeId}"] a.note-link`) as HTMLElement;
|
||||
if (editorEl) {
|
||||
setLinkState({ url: state.linkUrl!, rect: editorEl.getBoundingClientRect() });
|
||||
}
|
||||
});
|
||||
} else {
|
||||
setLinkState(null);
|
||||
}
|
||||
}, [onFormatChange]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<style>{NOTE_STYLES}</style>
|
||||
<LexicalComposer initialConfig={{ namespace: `note-${nodeId}`, theme, nodes: NOTE_NODES, onError: console.error }}>
|
||||
<div style={{ position: 'relative' }} data-note-id={nodeId}>
|
||||
<RichTextPlugin
|
||||
contentEditable={
|
||||
<ContentEditable
|
||||
style={{ minHeight: 60, outline: 'none', resize: 'none', fontSize: '12px', lineHeight: '18px', color: '#374151', overflow: 'auto', cursor: 'auto' }}
|
||||
/>
|
||||
}
|
||||
placeholder={
|
||||
<div style={{ position: 'absolute', top: 0, left: 0, color: '#9CA3AF', lineHeight: '18px', pointerEvents: 'none' }}>
|
||||
{t('workflow.config.notes.placeholder')}
|
||||
</div>
|
||||
}
|
||||
ErrorBoundary={LexicalErrorBoundary}
|
||||
/>
|
||||
<HistoryPlugin />
|
||||
<ListPlugin />
|
||||
<LinkPlugin />
|
||||
<OnChangePlugin onChange={(editorState) => onChange(JSON.stringify(editorState.toJSON()))} />
|
||||
<NoteInitPlugin value={value} />
|
||||
<NoteFormatPlugin nodeId={nodeId} fontSize={fontSize} onFormatChange={handleFormatChange} />
|
||||
{editLinkRect && (
|
||||
<EditLinkPopover
|
||||
rect={editLinkRect.rect}
|
||||
initialUrl={editLinkRect.url}
|
||||
onConfirm={(url) => {
|
||||
removingLink.current = true;
|
||||
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'link', value: url || null } }));
|
||||
setEditLinkRect(null);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{linkState && (
|
||||
<LinkPopover
|
||||
url={linkState.url}
|
||||
rect={linkState.rect}
|
||||
onEdit={() => {
|
||||
removingLink.current = true;
|
||||
const { rect, url } = linkState;
|
||||
setLinkState(null);
|
||||
setEditLinkRect({ url, rect });
|
||||
}}
|
||||
onRemove={() => {
|
||||
removingLink.current = true;
|
||||
setLinkState(null);
|
||||
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'link', value: null } }));
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</LexicalComposer>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default NoteEditor;
|
||||
@@ -0,0 +1,163 @@
|
||||
import { type FC } from 'react';
|
||||
import { Flex, Dropdown, type MenuProps, Switch, Button, Divider } from 'antd';
|
||||
import { UnorderedListOutlined, BoldOutlined, ItalicOutlined, StrikethroughOutlined, LinkOutlined, DashOutlined } from '@ant-design/icons';
|
||||
import { Node } from '@antv/x6';
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { THEME_MAP } from '../../../constant';
|
||||
const FONT_SIZES = [
|
||||
{ label: '小', value: 12 },
|
||||
{ label: '中', value: 14 },
|
||||
{ label: '大', value: 16 },
|
||||
];
|
||||
|
||||
interface NoteNodeToolbarProps {
|
||||
node: Node;
|
||||
onFormat: (type: string, value?: unknown) => void;
|
||||
toolConfig: Record<string, number | boolean>;
|
||||
nodeId: string;
|
||||
}
|
||||
|
||||
const NoteNodeToolbar: FC<NoteNodeToolbarProps> = ({ node, onFormat, toolConfig, nodeId }) => {
|
||||
const data = node?.getData() || {};
|
||||
const { t } = useTranslation();
|
||||
|
||||
const colorItems: MenuProps['items'] = Object.entries(THEME_MAP).map(([key, theme]) => ({
|
||||
key,
|
||||
label: (
|
||||
<div
|
||||
className="rb:w-5 rb:h-5 rb:rounded-full rb:cursor-pointer rb:border rb:border-gray-200"
|
||||
style={{ background: theme.bg }}
|
||||
onClick={() => onFormat('color', key)}
|
||||
/>
|
||||
),
|
||||
}));
|
||||
|
||||
const fontSizeItems: MenuProps['items'] = FONT_SIZES.map(({ label, value }) => ({
|
||||
key: value,
|
||||
label: <span onClick={() => onFormat('fontSize', value)}>{label}</span>,
|
||||
}));
|
||||
|
||||
const currentFontSize = FONT_SIZES.find(f => f.value === toolConfig.fontSize)?.label ?? '小';
|
||||
|
||||
const handleClick: MenuProps['onClick'] = (e) => {
|
||||
switch (e.key) {
|
||||
case 'delete':
|
||||
node.remove()
|
||||
break;
|
||||
case 'copy':
|
||||
break;
|
||||
}
|
||||
}
|
||||
const handleChange = (type: string) => {
|
||||
let show_author = data.config.show_author.defaultValue
|
||||
if(type === 'showAuth'){
|
||||
show_author = !show_author
|
||||
}
|
||||
node.setData({
|
||||
...data,
|
||||
config: {
|
||||
...data.config,
|
||||
show_author: {
|
||||
...data.config.show_author,
|
||||
defaultValue: show_author
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
align="center"
|
||||
gap={8}
|
||||
className="rb:absolute rb:-top-11 rb:left-1/2 rb:-translate-x-1/2 rb:bg-white rb:z-10 rb:whitespace-nowrap rb:rounded-lg rb:py-1! rb:px-3!"
|
||||
onClick={e => e.stopPropagation()}
|
||||
>
|
||||
{/* Color picker */}
|
||||
<Dropdown menu={{ items: colorItems }} trigger={['click']}>
|
||||
<div
|
||||
className="rb:w-5 rb:h-5 rb:rounded-full rb:cursor-pointer rb:border rb:border-gray-200"
|
||||
style={{ background: THEME_MAP[data.bgColor]?.bg || THEME_MAP.blue.bg }}
|
||||
/>
|
||||
</Dropdown>
|
||||
|
||||
<Divider type="vertical" />
|
||||
|
||||
{/* Font size */}
|
||||
<Dropdown menu={{ items: fontSizeItems }} trigger={['click']}>
|
||||
<Flex align="center" gap={4} className="rb:cursor-pointer rb:text-xs rb:text-gray-600 rb:select-none">
|
||||
<span className="rb:text-xs">Aa</span>
|
||||
<span className="rb:text-xs">{currentFontSize}</span>
|
||||
</Flex>
|
||||
</Dropdown>
|
||||
|
||||
<Divider type="vertical" />
|
||||
|
||||
{/* Bold */}
|
||||
<Button
|
||||
type={toolConfig.bold ? 'primary' : 'text'}
|
||||
icon={<BoldOutlined />}
|
||||
onClick={() => onFormat('bold')}
|
||||
/>
|
||||
|
||||
{/* Italic */}
|
||||
<Button
|
||||
type={toolConfig.italic ? 'primary' : 'text'}
|
||||
icon={<ItalicOutlined />}
|
||||
onClick={() => onFormat('italic')}
|
||||
/>
|
||||
|
||||
{/* Strikethrough */}
|
||||
<Button
|
||||
type={toolConfig.strikethrough ? 'primary' : 'text'}
|
||||
icon={<StrikethroughOutlined />}
|
||||
onClick={() => onFormat('strikethrough')}
|
||||
/>
|
||||
|
||||
{/* Link */}
|
||||
<Button
|
||||
type={toolConfig.link ? 'primary' : 'text'}
|
||||
icon={<LinkOutlined />}
|
||||
onClick={() => {
|
||||
const sel = window.getSelection();
|
||||
const rect = sel && sel.rangeCount > 0 ? sel.getRangeAt(0).getBoundingClientRect() : undefined;
|
||||
window.dispatchEvent(new CustomEvent('note:edit-link', { detail: { id: nodeId, url: '', rect } }));
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* List */}
|
||||
<Button
|
||||
type={toolConfig.list ? 'primary' : 'text'}
|
||||
icon={<UnorderedListOutlined />}
|
||||
onClick={() => onFormat('list')}
|
||||
/>
|
||||
|
||||
<Divider type="vertical" />
|
||||
|
||||
<Dropdown
|
||||
menu={{
|
||||
items: [
|
||||
// { key: 'copy', label: t('common.copy') },
|
||||
{
|
||||
key: 'showAuth',
|
||||
label: <Flex align="center" gap={24}>
|
||||
{t('workflow.config.notes.showAuth')}
|
||||
<Switch
|
||||
size="small"
|
||||
checked={data.config.show_author.defaultValue}
|
||||
onChange={() => handleChange('showAuth')}
|
||||
/>
|
||||
</Flex>
|
||||
},
|
||||
{ key: 'delete', label: <Flex>{t('common.delete')}</Flex> },
|
||||
],
|
||||
onClick: handleClick
|
||||
}}
|
||||
>
|
||||
<DashOutlined />
|
||||
</Dropdown>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default NoteNodeToolbar;
|
||||
155
web/src/views/Workflow/components/Nodes/NoteNode/index.tsx
Normal file
155
web/src/views/Workflow/components/Nodes/NoteNode/index.tsx
Normal file
@@ -0,0 +1,155 @@
|
||||
import { useRef, useState } from 'react';
|
||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||
import { Flex } from 'antd';
|
||||
|
||||
import NoteEditor from './NoteEditor';
|
||||
import NoteNodeToolbar from './NoteNodeToolbar';
|
||||
import { THEME_MAP } from '../../../constant'
|
||||
|
||||
const MIN_W = 240;
|
||||
const MIN_H = 120;
|
||||
|
||||
const NoteNode: ReactShapeConfig['component'] = ({ node }) => {
|
||||
const data = node?.getData() || {};
|
||||
const nodeId = node?.id || '';
|
||||
const startRef = useRef<{ x: number; y: number; w: number; h: number } | null>(null);
|
||||
const [toolConfig, setToolConfig] = useState({
|
||||
fontSize: 12,
|
||||
bold: false,
|
||||
italic: false,
|
||||
strikethrough: false,
|
||||
list: false,
|
||||
})
|
||||
|
||||
const handleFormat = (type: string, value?: unknown) => {
|
||||
console.log('handleFormat', type, value)
|
||||
if (type === 'color') {
|
||||
node?.setData({
|
||||
...data,
|
||||
config: {
|
||||
...data.config,
|
||||
theme: {
|
||||
...data.config.theme,
|
||||
defaultValue: value
|
||||
}
|
||||
}
|
||||
});
|
||||
} else if (type === 'fontSize') {
|
||||
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'fontSize', value } }));
|
||||
} else if (type === 'link') {
|
||||
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'link', value: value || null } }));
|
||||
} else if (type === 'list') {
|
||||
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'list', value: !toolConfig.list } }));
|
||||
} else {
|
||||
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: type } }));
|
||||
}
|
||||
|
||||
setToolConfig(prev => ({ ...prev, [type]: value || !prev[type as unknown as keyof typeof toolConfig] }))
|
||||
};
|
||||
|
||||
const onResizeMouseDown = (e: React.MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
e.preventDefault();
|
||||
const size = node?.getSize();
|
||||
if (!size) return;
|
||||
startRef.current = { x: e.clientX, y: e.clientY, w: size.width, h: size.height };
|
||||
|
||||
const onMouseMove = (ev: MouseEvent) => {
|
||||
if (!startRef.current) return;
|
||||
const w = Math.max(MIN_W, startRef.current.w + ev.clientX - startRef.current.x);
|
||||
const h = Math.max(MIN_H, startRef.current.h + ev.clientY - startRef.current.y);
|
||||
|
||||
node?.setData({
|
||||
...data,
|
||||
config: {
|
||||
...data.config,
|
||||
width: {
|
||||
...data.config.width,
|
||||
defaultValue: w
|
||||
},
|
||||
height: {
|
||||
...data.config.height,
|
||||
defaultValue: h
|
||||
}
|
||||
}
|
||||
});
|
||||
node?.prop('size', { width: w, height: h });
|
||||
};
|
||||
const onMouseUp = () => {
|
||||
startRef.current = null;
|
||||
window.removeEventListener('mousemove', onMouseMove);
|
||||
window.removeEventListener('mouseup', onMouseUp);
|
||||
};
|
||||
window.addEventListener('mousemove', onMouseMove);
|
||||
window.addEventListener('mouseup', onMouseUp);
|
||||
};
|
||||
|
||||
const updateText = (value: string) => {
|
||||
node.setData({
|
||||
...data,
|
||||
config: {
|
||||
...data.config,
|
||||
text: {
|
||||
...data.config.text,
|
||||
defaultValue: value
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const theme = THEME_MAP[data.config?.theme?.defaultValue || 'blue'] || THEME_MAP['blue']
|
||||
|
||||
return (
|
||||
<div
|
||||
className="rb:relative rb:h-full rb:w-full rb:rounded-2xl rb:border"
|
||||
style={{
|
||||
background: theme.bg,
|
||||
borderColor: data.isSelected ? theme.outer : theme.border,
|
||||
}}
|
||||
>
|
||||
<div className="rb:h-4 rb:rounded-tl-2xl rb:rounded-tr-2xl"
|
||||
style={{
|
||||
background: theme.title
|
||||
}}
|
||||
></div>
|
||||
{data.isSelected && <NoteNodeToolbar node={node!} nodeId={nodeId} toolConfig={toolConfig} onFormat={handleFormat} />}
|
||||
|
||||
<div
|
||||
className="rb:w-full rb:h-[calc(100%-36px)] rb:p-2.5 rb:overflow-auto"
|
||||
onMouseDown={e => {
|
||||
e.stopPropagation()
|
||||
node?.setData({ ...node.getData(), isSelected: true })
|
||||
}}
|
||||
onWheel={e => e.stopPropagation()}
|
||||
>
|
||||
<NoteEditor
|
||||
nodeId={nodeId}
|
||||
value={data.config.text.defaultValue || ''}
|
||||
fontSize={toolConfig.fontSize}
|
||||
onChange={updateText}
|
||||
onFormatChange={(state) => setToolConfig(prev => ({ ...prev, ...state }))}
|
||||
/>
|
||||
</div>
|
||||
<Flex align="center" justify="space-between" className="rb:pl-2.5! rb:pr-1!">
|
||||
<div className="rb:text-[12px] rb:text-[#5B6167]">
|
||||
{data.config.show_author.defaultValue
|
||||
? data.config.author.defaultValue
|
||||
: undefined
|
||||
}
|
||||
</div>
|
||||
|
||||
|
||||
{/* <div className="rb:size-4 rb:border-b-[4px] rb:border-r-[4px] rb:border-[#EBEBEB] rb:rounded-2xl"></div> */}
|
||||
<div
|
||||
onMouseDown={onResizeMouseDown}
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18" fill="none">
|
||||
<path fillRule="evenodd" clipRule="evenodd" d="M12 9.75V6H13.5V9.75C13.5 11.8211 11.8211 13.5 9.75 13.5H6V12H9.75C10.9926 12 12 10.9926 12 9.75Z" fill="black" fillOpacity="0.16"></path>
|
||||
</svg>
|
||||
</div>
|
||||
</Flex>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default NoteNode;
|
||||
@@ -2,13 +2,14 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 15:06:18
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-11 12:07:20
|
||||
* @Last Modified time: 2026-03-09 13:41:19
|
||||
*/
|
||||
import LoopNode from './components/Nodes/LoopNode';
|
||||
import NormalNode from './components/Nodes/NormalNode';
|
||||
import ConditionNode from './components/Nodes/ConditionNode';
|
||||
import GroupStartNode from './components/Nodes/GroupStartNode';
|
||||
import AddNode from './components/Nodes/AddNode'
|
||||
import NoteNode from './components/Nodes/NoteNode';
|
||||
import type { PortMetadata, GroupMetadata } from '@antv/x6/lib/model/port';
|
||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||
|
||||
@@ -525,6 +526,73 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
// ]
|
||||
// },
|
||||
];
|
||||
|
||||
export const THEME_MAP: Record<string, { outer: string; title: string; bg: string; border: string }> = {
|
||||
blue: {
|
||||
outer: '#2E90FA',
|
||||
title: '#D1E9FF',
|
||||
bg: '#EFF8FF',
|
||||
border: '#84CAFF',
|
||||
},
|
||||
cyan: {
|
||||
outer: '#06AED4',
|
||||
title: '#CFF9FE',
|
||||
bg: '#ECFDFF',
|
||||
border: '#67E3F9',
|
||||
},
|
||||
green: {
|
||||
outer: '#16B364',
|
||||
title: '#D3F8DF',
|
||||
bg: '#EDFCF2',
|
||||
border: '#73E2A3',
|
||||
},
|
||||
yellow: {
|
||||
outer: '#EAAA08',
|
||||
title: '#FEF7C3',
|
||||
bg: '#FEFBE8',
|
||||
border: '#FDE272',
|
||||
},
|
||||
pink: {
|
||||
outer: '#EE46BC',
|
||||
title: '#FCE7F6',
|
||||
bg: '#FDF2FA',
|
||||
border: '#FAA7E0',
|
||||
},
|
||||
violet: {
|
||||
outer: '#875BF7',
|
||||
title: '#ECE9FE',
|
||||
bg: '#F5F3FF',
|
||||
border: '#C3B5FD',
|
||||
},
|
||||
}
|
||||
|
||||
export const notesConfig = {
|
||||
type: "notes", icon: templateRenderingIcon,
|
||||
config: {
|
||||
text: {
|
||||
type: 'define',
|
||||
},
|
||||
theme: {
|
||||
type: 'define',
|
||||
defaultValue: 'blue',
|
||||
},
|
||||
width: {
|
||||
type: 'define',
|
||||
width: 240,
|
||||
},
|
||||
height: {
|
||||
type: 'define',
|
||||
height: 120,
|
||||
},
|
||||
author: {
|
||||
type: 'define',
|
||||
},
|
||||
show_author: {
|
||||
type: 'define',
|
||||
defaultValue: true
|
||||
}
|
||||
}
|
||||
}
|
||||
export const unknownNode = {
|
||||
type: 'unknown',
|
||||
icon: unknownIcon
|
||||
@@ -576,6 +644,12 @@ export const nodeRegisterLibrary: ReactShapeConfig[] = [
|
||||
height: 44,
|
||||
component: AddNode,
|
||||
},
|
||||
{
|
||||
shape: 'notes-node',
|
||||
width: nodeWidth,
|
||||
height: 120,
|
||||
component: NoteNode,
|
||||
},
|
||||
];
|
||||
|
||||
/**
|
||||
@@ -801,6 +875,11 @@ export const graphNodeLibrary: Record<string, NodeConfig> = {
|
||||
groups: {left: { position: 'left', markup: portMarkup, attrs: portAttrs }},
|
||||
items: [{ group: 'left' }],
|
||||
},
|
||||
},
|
||||
notes: {
|
||||
width: nodeWidth,
|
||||
height: 120,
|
||||
shape: 'notes-node',
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,9 +12,10 @@ import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from '
|
||||
import { register } from '@antv/x6-react-shape';
|
||||
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
||||
|
||||
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, noteNode } from '../constant';
|
||||
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, noteNode, notesConfig } from '../constant';
|
||||
import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types';
|
||||
import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application'
|
||||
import { useUser } from '@/store/user';
|
||||
|
||||
/**
|
||||
* Props for useWorkflowGraph hook
|
||||
@@ -64,6 +65,8 @@ export interface UseWorkflowGraphReturn {
|
||||
chatVariables: ChatVariable[];
|
||||
/** Function to update chat variables */
|
||||
setChatVariables: React.Dispatch<React.SetStateAction<ChatVariable[]>>;
|
||||
|
||||
handleAddNotes: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -80,6 +83,7 @@ export const useWorkflowGraph = ({
|
||||
const { id } = useParams();
|
||||
const { message } = App.useApp();
|
||||
const { t } = useTranslation()
|
||||
const { user } = useUser();
|
||||
|
||||
// Refs
|
||||
const graphRef = useRef<Graph>();
|
||||
@@ -128,7 +132,7 @@ export const useWorkflowGraph = ({
|
||||
if (nodes.length) {
|
||||
const nodeList = nodes.map(node => {
|
||||
const { id, type, name, position, config = {} } = node
|
||||
let nodeLibraryConfig = [...nodeLibrary, { nodes: [unknownNode, noteNode] }]
|
||||
let nodeLibraryConfig = [...nodeLibrary, { nodes: [unknownNode, notesConfig] }]
|
||||
.flatMap(category => category.nodes)
|
||||
.find(n => n.type === type)
|
||||
nodeLibraryConfig = JSON.parse(JSON.stringify({ config: {}, ...nodeLibraryConfig })) as NodeProperties
|
||||
@@ -197,6 +201,13 @@ export const useWorkflowGraph = ({
|
||||
data: { ...node, ...nodeLibraryConfig},
|
||||
...position,
|
||||
}
|
||||
|
||||
if (type === 'notes') {
|
||||
const w = config.width;
|
||||
const h = config.height;
|
||||
if (w) nodeConfig.width = w as number;
|
||||
if (h) nodeConfig.height = h as number;
|
||||
}
|
||||
|
||||
// Generate ports dynamically for if-else node based on cases
|
||||
if (type === 'if-else' && config.cases && Array.isArray(config.cases)) {
|
||||
@@ -461,11 +472,12 @@ export const useWorkflowGraph = ({
|
||||
*/
|
||||
const nodeClick = ({ node }: { node: Node }) => {
|
||||
// Ignore add-node type node clicks
|
||||
if (node.getData()?.type === 'add-node' || node.getData().type === 'break' || node.getData().type === 'cycle-start') {
|
||||
const nodeData = node.getData()
|
||||
if (nodeData?.type === 'add-node' || nodeData.type === 'break' || nodeData.type === 'cycle-start') {
|
||||
setSelectedNode(null)
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const nodes = graphRef.current?.getNodes();
|
||||
|
||||
nodes?.forEach(vo => {
|
||||
@@ -478,10 +490,12 @@ export const useWorkflowGraph = ({
|
||||
}
|
||||
});
|
||||
node.setData({
|
||||
...node.getData(),
|
||||
...nodeData,
|
||||
isSelected: true,
|
||||
});
|
||||
setSelectedNode(node);
|
||||
if (nodeData.type !== 'notes') {
|
||||
setSelectedNode(node);
|
||||
}
|
||||
};
|
||||
/**
|
||||
* Handle edge click event
|
||||
@@ -859,8 +873,31 @@ export const useWorkflowGraph = ({
|
||||
init();
|
||||
|
||||
window.addEventListener('resize', handleResize);
|
||||
|
||||
const handleNoteKeydown = (e: KeyboardEvent) => {
|
||||
if (!graphRef.current) return;
|
||||
const selectedNote = graphRef.current.getNodes().find(n => n.getData()?.isSelected && n.getData()?.type === 'notes');
|
||||
if (!selectedNote) return;
|
||||
const isMeta = e.ctrlKey || e.metaKey;
|
||||
if (e.key === 'Delete' || e.key === 'Backspace') {
|
||||
// Only delete node when editor is not focused on text
|
||||
const active = document.activeElement;
|
||||
if (active && (active as HTMLElement).isContentEditable) return;
|
||||
deleteEvent();
|
||||
} else if (isMeta && e.key === 'c') {
|
||||
copyEvent();
|
||||
} else if (isMeta && e.key === 'v') {
|
||||
parseEvent();
|
||||
} else if (isMeta && e.key === 'd') {
|
||||
e.preventDefault();
|
||||
deleteEvent();
|
||||
}
|
||||
};
|
||||
window.addEventListener('keydown', handleNoteKeydown);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('resize', handleResize);
|
||||
window.removeEventListener('keydown', handleNoteKeydown);
|
||||
graphRef.current?.dispose();
|
||||
};
|
||||
}, []);
|
||||
@@ -884,7 +921,7 @@ export const useWorkflowGraph = ({
|
||||
.flatMap(category => category.nodes)
|
||||
.find(n => n.type === dragData.type);
|
||||
nodeLibraryConfig = JSON.parse(JSON.stringify({ config: {}, ...nodeLibraryConfig })) as NodeProperties
|
||||
|
||||
|
||||
// Create clean node data, only keep necessary fields
|
||||
const cleanNodeData = {
|
||||
id: `${dragData.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
|
||||
@@ -1103,6 +1140,32 @@ export const useWorkflowGraph = ({
|
||||
})
|
||||
}
|
||||
|
||||
const handleAddNotes = () => {
|
||||
if (!graphRef.current) return;
|
||||
const nodeConfig: NodeProperties = JSON.parse(JSON.stringify(notesConfig));
|
||||
nodeConfig.config = {
|
||||
...nodeConfig.config,
|
||||
author: { type: 'define', defaultValue: user?.username || '' },
|
||||
};
|
||||
const cleanNodeData = {
|
||||
id: `notes_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
|
||||
name: t('workflow.notes'),
|
||||
...nodeConfig,
|
||||
};
|
||||
const container = graphRef.current.container;
|
||||
const nodeW = graphNodeLibrary.notes?.width || nodeWidth;
|
||||
const nodeH = graphNodeLibrary.notes?.height || 100;
|
||||
const rect = container.getBoundingClientRect();
|
||||
const center = graphRef.current.clientToLocal(rect.left + rect.width / 2, rect.top + rect.height / 2);
|
||||
graphRef.current.addNode({
|
||||
...(graphNodeLibrary.notes || graphNodeLibrary.default),
|
||||
x: center.x - nodeW / 2,
|
||||
y: center.y - nodeH / 2,
|
||||
id: cleanNodeData.id,
|
||||
data: { ...cleanNodeData },
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
config,
|
||||
setConfig,
|
||||
@@ -1120,6 +1183,7 @@ export const useWorkflowGraph = ({
|
||||
parseEvent,
|
||||
handleSave,
|
||||
chatVariables,
|
||||
setChatVariables
|
||||
setChatVariables,
|
||||
handleAddNotes
|
||||
};
|
||||
};
|
||||
|
||||
@@ -38,7 +38,8 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
|
||||
parseEvent,
|
||||
handleSave,
|
||||
chatVariables,
|
||||
setChatVariables
|
||||
setChatVariables,
|
||||
handleAddNotes
|
||||
} = useWorkflowGraph({ containerRef, miniMapRef });
|
||||
|
||||
const onDragOver = (event: React.DragEvent) => {
|
||||
@@ -95,6 +96,7 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
|
||||
canRedo={canRedo}
|
||||
onUndo={onUndo}
|
||||
onRedo={onRedo}
|
||||
addNotes={handleAddNotes}
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user