Merge pull request #568 from SuanmoSuanyangTechnology/release/v0.2.7
Release/v0.2.7
This commit is contained in:
@@ -45,7 +45,8 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
||||
apt install -y libjemalloc-dev && \
|
||||
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
||||
apt install -y ghostscript
|
||||
apt install -y ghostscript && \
|
||||
apt install -y libmagic1
|
||||
|
||||
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
|
||||
|
||||
@@ -63,9 +63,9 @@ celery_app.conf.update(
|
||||
accept_content=['json'],
|
||||
result_serializer='json',
|
||||
|
||||
# 时区
|
||||
timezone='Asia/Shanghai',
|
||||
enable_utc=True,
|
||||
# # 时区
|
||||
# timezone='Asia/Shanghai',
|
||||
# enable_utc=False,
|
||||
|
||||
# 任务追踪
|
||||
task_track_started=True,
|
||||
@@ -113,6 +113,8 @@ celery_app.conf.update(
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import uuid
|
||||
import io
|
||||
from typing import Optional, Annotated
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -25,6 +27,7 @@ from app.services.app_service import AppService
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.services.workflow_import_service import WorkflowImportService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_dsl_service import AppDslService
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -1010,3 +1013,57 @@ def get_workspace_api_statistics(
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
|
||||
@cur_workspace_access_guard()
|
||||
async def export_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
release_id: Optional[uuid.UUID] = None
|
||||
):
|
||||
"""导出 agent / multi_agent / workflow 应用配置为 YAML 文件流。
|
||||
release_id: 指定发布版本id,不传则导出当前草稿配置。
|
||||
"""
|
||||
yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)
|
||||
encoded = quote(filename, safe=".")
|
||||
yaml_bytes = yaml_str.encode("utf-8")
|
||||
file_stream = io.BytesIO(yaml_bytes)
|
||||
file_stream.seek(0)
|
||||
return StreamingResponse(
|
||||
file_stream,
|
||||
media_type="application/octet-stream; charset=utf-8",
|
||||
headers={"Content-Disposition": f"attachment; filename={encoded}",
|
||||
"Content-Length": str(len(yaml_bytes))}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/import", summary="从 YAML 文件导入应用")
|
||||
@cur_workspace_access_guard()
|
||||
async def import_app(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
||||
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
||||
"""
|
||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||
return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)
|
||||
|
||||
raw = (await file.read()).decode("utf-8")
|
||||
dsl = yaml.safe_load(raw)
|
||||
if not dsl or "app" not in dsl:
|
||||
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
||||
|
||||
new_app, warnings = AppDslService(db).import_dsl(
|
||||
dsl=dsl,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
tenant_id=current_user.tenant_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
return success(
|
||||
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
|
||||
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
||||
)
|
||||
|
||||
@@ -55,6 +55,12 @@ async def get_mcp_servers(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
if page * pagesize > 100:
|
||||
api_logger.warning(f"Paging parameters exceed ModelScope limit: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The maximum number of MCP services can view is 100. Please visit the ModelScope MCP Plaza."
|
||||
)
|
||||
|
||||
# 2. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
@@ -64,14 +70,16 @@ async def get_mcp_servers(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 3. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
@@ -140,14 +148,16 @@ async def get_operational_mcp_servers(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
url = f'{api.mcp_base_url}/operational'
|
||||
@@ -198,14 +208,16 @@ async def get_mcp_server(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Get detailed information for a specific MCP Server
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
result = api.get_mcp_server(server_id=server_id)
|
||||
@@ -226,7 +238,26 @@ async def create_mcp_market_config(
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
|
||||
# 1. Check if the mcp market name already exists
|
||||
# 1. Validate token can access ModelScope MCP market
|
||||
if not create_data.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Token is required to access ModelScope MCP market"
|
||||
)
|
||||
try:
|
||||
api = MCPApi()
|
||||
api.login(create_data.token)
|
||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||
cookies = api.get_cookies(create_data.token)
|
||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
|
||||
)
|
||||
# 2. Check if the mcp market name already exists
|
||||
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
|
||||
if db_mcp_market_config_exist:
|
||||
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
|
||||
@@ -262,10 +293,7 @@ async def get_mcp_market_config(
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
@@ -295,10 +323,7 @@ async def get_mcp_market_config_by_mcp_market_id(
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
@@ -324,12 +349,25 @@ async def update_mcp_market_config(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Update fields (only update non-null fields)
|
||||
# 2. Validate new token if provided
|
||||
if update_data.token is not None:
|
||||
try:
|
||||
api = MCPApi()
|
||||
api.login(update_data.token)
|
||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||
cookies = api.get_cookies(update_data.token)
|
||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
|
||||
)
|
||||
|
||||
# 3. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
updated_fields = []
|
||||
@@ -344,7 +382,7 @@ async def update_mcp_market_config(
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 3. Save to database
|
||||
# 4. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market_config)
|
||||
@@ -381,10 +419,7 @@ async def delete_mcp_market_config(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Deleting mcp market config
|
||||
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from app.core.response_utils import success
|
||||
@@ -149,6 +150,21 @@ async def get_workspace_end_users(
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||
try:
|
||||
from app.celery_app import celery_app as _celery_app
|
||||
_celery_app.send_task(
|
||||
"app.tasks.init_implicit_emotions_for_users",
|
||||
kwargs={"end_user_ids": end_user_ids},
|
||||
)
|
||||
_celery_app.send_task(
|
||||
"app.tasks.init_interest_distribution_for_users",
|
||||
kwargs={"end_user_ids": end_user_ids},
|
||||
)
|
||||
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
|
||||
|
||||
# 并发执行配置查询和记忆数量查询
|
||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||
get_memory_configs(),
|
||||
@@ -387,14 +403,15 @@ def get_current_user_rag_total_num(
|
||||
@router.get("/rag_content", response_model=ApiResponse)
|
||||
def get_rag_content(
|
||||
end_user_id: str = Query(..., description="宿主ID"),
|
||||
limit: int = Query(15, description="返回记录数"),
|
||||
page: int = Query(1, gt=0, description="页码,从1开始"),
|
||||
pagesize: int = Query(15, gt=0, le=100, description="每页返回记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取当前宿主知识库中的chunk内容
|
||||
获取当前宿主知识库中的chunk内容(分页)
|
||||
"""
|
||||
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
|
||||
data = memory_dashboard_service.get_rag_content(end_user_id, page, pagesize, db, current_user)
|
||||
return success(data=data, msg="宿主RAGchunk数据获取成功")
|
||||
|
||||
|
||||
@@ -407,26 +424,18 @@ async def get_chunk_summary_tag(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取chunk总结、提取的标签和人物形象
|
||||
|
||||
读取RAG摘要、标签和人物形象(纯读库,不触发生成)。
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"summary": "chunk内容的总结",
|
||||
"tags": [
|
||||
{"tag": "标签1", "frequency": 5},
|
||||
{"tag": "标签2", "frequency": 3},
|
||||
...
|
||||
],
|
||||
"personas": [
|
||||
"产品设计师",
|
||||
"旅行爱好者",
|
||||
"摄影发烧友",
|
||||
...
|
||||
]
|
||||
"summary": "用户摘要",
|
||||
"tags": [{"tag": "标签1", "frequency": 5}, ...],
|
||||
"personas": ["产品设计师", ...],
|
||||
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk摘要、标签和人物形象")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG摘要/标签/人物形象")
|
||||
|
||||
data = await memory_dashboard_service.get_chunk_summary_and_tags(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -434,9 +443,8 @@ async def get_chunk_summary_tag(
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
|
||||
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
|
||||
|
||||
return success(data=data, msg="获取成功")
|
||||
|
||||
|
||||
@router.get("/chunk_insight", response_model=ApiResponse)
|
||||
@@ -447,24 +455,57 @@ async def get_chunk_insight(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取chunk的洞察内容
|
||||
|
||||
读取RAG洞察报告(纯读库,不触发生成)。
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"insight": "对chunk内容的深度洞察分析"
|
||||
"insight": "总体概述",
|
||||
"behavior_pattern": "行为模式",
|
||||
"key_findings": "关键发现",
|
||||
"growth_trajectory": "成长轨迹",
|
||||
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk洞察")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG洞察")
|
||||
|
||||
data = await memory_dashboard_service.get_chunk_insight(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info("成功获取chunk洞察")
|
||||
return success(data=data, msg="chunk洞察获取成功")
|
||||
|
||||
return success(data=data, msg="获取成功")
|
||||
|
||||
|
||||
class GenerateRagProfileRequest(BaseModel):
|
||||
end_user_id: str = Field(..., description="宿主ID")
|
||||
limit: int = Field(15, description="参与生成的chunk数量上限")
|
||||
max_tags: int = Field(10, description="最大标签数量")
|
||||
|
||||
|
||||
@router.post("/generate_rag_profile", response_model=ApiResponse)
|
||||
async def generate_rag_profile(
|
||||
body: GenerateRagProfileRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
生产接口:为RAG存储模式的宿主全量重新生成完整画像并持久化到end_user表。
|
||||
每次请求都会重新生成,覆盖已有数据。
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 触发RAG画像生产: end_user_id={body.end_user_id}")
|
||||
|
||||
data = await memory_dashboard_service.generate_rag_profile(
|
||||
end_user_id=body.end_user_id,
|
||||
limit=body.limit,
|
||||
max_tags=body.max_tags,
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
api_logger.info(f"RAG画像生产完成: {data}")
|
||||
return success(data=data, msg="RAG画像生产完成")
|
||||
|
||||
|
||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.models import User
|
||||
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
||||
from app.services.tool_service import ToolService
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.exceptions import BusinessException
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
||||
|
||||
@@ -97,7 +98,13 @@ async def create_tool(
|
||||
):
|
||||
"""创建工具"""
|
||||
try:
|
||||
tool_id = service.create_tool(
|
||||
# 将 MCP 来源字段合并进 config
|
||||
if request.tool_type == ToolType.MCP:
|
||||
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
|
||||
val = getattr(request, key, None)
|
||||
if val is not None:
|
||||
request.config[key] = val
|
||||
tool_id = await service.create_tool(
|
||||
name=request.name,
|
||||
tool_type=request.tool_type,
|
||||
tenant_id=current_user.tenant_id,
|
||||
@@ -107,6 +114,8 @@ async def create_tool(
|
||||
tags=request.tags
|
||||
)
|
||||
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
||||
except BusinessException as e:
|
||||
raise HTTPException(status_code=400, detail=e.message)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field, TypeAdapter
|
||||
@@ -115,6 +114,7 @@ class Settings:
|
||||
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
|
||||
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
|
||||
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
|
||||
S3_ENDPOINT_URL: str = os.getenv("S3_ENDPOINT_URL", "")
|
||||
|
||||
# VOLC ASR settings
|
||||
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
|
||||
|
||||
@@ -33,6 +33,7 @@ class DialogExtractionResponse(BaseModel):
|
||||
|
||||
- is_related:对话与场景的相关性判定。
|
||||
- times / ids / amounts / contacts / addresses / keywords:重要信息片段,用来在不相关对话中保留关键消息。
|
||||
- preserve_keywords:情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。
|
||||
"""
|
||||
is_related: bool = Field(...)
|
||||
times: List[str] = Field(default_factory=list)
|
||||
@@ -41,6 +42,7 @@ class DialogExtractionResponse(BaseModel):
|
||||
contacts: List[str] = Field(default_factory=list)
|
||||
addresses: List[str] = Field(default_factory=list)
|
||||
keywords: List[str] = Field(default_factory=list)
|
||||
preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留")
|
||||
|
||||
|
||||
class MessageImportanceResponse(BaseModel):
|
||||
@@ -86,26 +88,17 @@ class SemanticPruner:
|
||||
self._detailed_prune_logging = True # 是否启用详细日志
|
||||
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
|
||||
|
||||
# 加载场景特定配置(内置场景走专门规则,自定义场景 fallback 到通用规则)
|
||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
|
||||
self.config.pruning_scene,
|
||||
fallback_to_generic=True
|
||||
)
|
||||
# 加载统一填充词库
|
||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene)
|
||||
|
||||
# 判断是否为内置专门场景
|
||||
self._is_builtin_scene = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
|
||||
|
||||
# 自定义场景的本体类型列表(用于注入提示词)
|
||||
# 本体类型列表(用于注入提示词,所有场景均支持)
|
||||
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
|
||||
|
||||
if self._is_builtin_scene:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用内置专门配置")
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}")
|
||||
if self._ontology_classes:
|
||||
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 为自定义场景,使用通用规则 + 本体类型提示词注入")
|
||||
if self._ontology_classes:
|
||||
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
||||
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
||||
|
||||
# Load Jinja2 template
|
||||
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
||||
@@ -117,107 +110,27 @@ class SemanticPruner:
|
||||
# 运行日志:收集关键终端输出,便于写入 JSON
|
||||
self.run_logs: List[str] = []
|
||||
|
||||
def _is_important_message(self, message: ConversationMessage) -> bool:
|
||||
"""基于启发式规则识别重要信息消息,优先保留。
|
||||
|
||||
改进版:使用场景特定的模式进行识别
|
||||
- 根据 pruning_scene 动态加载对应的识别规则
|
||||
- 支持教育、在线服务、外呼三个场景的特定模式
|
||||
"""
|
||||
text = message.msg.strip()
|
||||
if not text:
|
||||
return False
|
||||
|
||||
# 使用场景特定的模式
|
||||
all_patterns = (
|
||||
self.scene_config.high_priority_patterns +
|
||||
self.scene_config.medium_priority_patterns +
|
||||
self.scene_config.low_priority_patterns
|
||||
)
|
||||
|
||||
for pattern, _ in all_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
return True
|
||||
|
||||
# 检查是否为问句(以问号结尾或包含疑问词)
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
return True
|
||||
|
||||
# 检查是否包含问句关键词
|
||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
||||
return True
|
||||
|
||||
# 检查是否包含决策性关键词
|
||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _importance_score(self, message: ConversationMessage) -> int:
|
||||
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
|
||||
|
||||
改进版:使用场景特定的权重体系(0-10分)
|
||||
- 根据场景动态调整不同信息类型的权重
|
||||
- 高优先级模式:4-6分
|
||||
- 中优先级模式:2-3分
|
||||
- 低优先级模式:1分
|
||||
"""
|
||||
text = message.msg.strip()
|
||||
score = 0
|
||||
|
||||
# 使用场景特定的权重
|
||||
for pattern, weight in self.scene_config.high_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
for pattern, weight in self.scene_config.medium_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
for pattern, weight in self.scene_config.low_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
# 问句加分
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
score += 2
|
||||
|
||||
# 包含问句关键词加分
|
||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
||||
score += 1
|
||||
|
||||
# 包含决策性关键词加分
|
||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
||||
score += 2
|
||||
|
||||
# 长度加分(较长的消息通常包含更多信息)
|
||||
if len(text) > 50:
|
||||
score += 1
|
||||
if len(text) > 100:
|
||||
score += 1
|
||||
|
||||
return min(score, 10) # 最高10分
|
||||
# _is_important_message 和 _importance_score 已移除:
|
||||
# 重要性判断完全由 extracat_Pruning.jinja2 提示词 + LLM 的 preserve_tokens 机制承担。
|
||||
# LLM 根据注入的本体工程类型语义识别需要保护的内容,无需硬编码正则规则。
|
||||
|
||||
def _is_filler_message(self, message: ConversationMessage) -> bool:
|
||||
"""检测典型寒暄/口头禅/确认类短消息。
|
||||
|
||||
改进版:更严格的填充消息判断,避免误删场景相关内容
|
||||
满足以下之一视为填充消息:
|
||||
- 纯标点或空白
|
||||
- 在场景特定填充词库中(精确匹配)
|
||||
- 纯表情符号
|
||||
- 常见寒暄(精确匹配短语)
|
||||
|
||||
注意:不再使用长度判断,避免误删短但重要的消息
|
||||
判断顺序:
|
||||
1. 空消息
|
||||
2. 场景特定填充词库精确匹配
|
||||
3. 常见寒暄精确匹配
|
||||
4. 纯表情/标点
|
||||
"""
|
||||
t = message.msg.strip()
|
||||
if not t:
|
||||
return True
|
||||
|
||||
|
||||
# 检查是否在场景特定填充词库中(精确匹配)
|
||||
if t in self.scene_config.filler_phrases:
|
||||
return True
|
||||
|
||||
|
||||
# 常见寒暄和问候(精确匹配,避免误删)
|
||||
common_greetings = {
|
||||
"在吗", "在不在", "在呢", "在的",
|
||||
@@ -229,25 +142,11 @@ class SemanticPruner:
|
||||
}
|
||||
if t in common_greetings:
|
||||
return True
|
||||
|
||||
|
||||
# 检查是否为纯表情符号(方括号包裹)
|
||||
if re.fullmatch(r"(\[[^\]]+\])+", t):
|
||||
return True
|
||||
|
||||
# 检查是否为纯emoji(Unicode表情)
|
||||
emoji_pattern = re.compile(
|
||||
"["
|
||||
"\U0001F600-\U0001F64F" # 表情符号
|
||||
"\U0001F300-\U0001F5FF" # 符号和象形文字
|
||||
"\U0001F680-\U0001F6FF" # 交通和地图符号
|
||||
"\U0001F1E0-\U0001F1FF" # 旗帜
|
||||
"\U00002702-\U000027B0"
|
||||
"\U000024C2-\U0001F251"
|
||||
"]+", flags=re.UNICODE
|
||||
)
|
||||
if emoji_pattern.fullmatch(t):
|
||||
return True
|
||||
|
||||
# 纯标点符号
|
||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
|
||||
return True
|
||||
@@ -432,14 +331,12 @@ class SemanticPruner:
|
||||
|
||||
rendered = self.template.render(
|
||||
pruning_scene=self.config.pruning_scene,
|
||||
is_builtin_scene=self._is_builtin_scene,
|
||||
ontology_classes=self._ontology_classes,
|
||||
dialog_text=dialog_text,
|
||||
language=self.language
|
||||
)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {
|
||||
"pruning_scene": self.config.pruning_scene,
|
||||
"is_builtin_scene": self._is_builtin_scene,
|
||||
"ontology_classes_count": len(self._ontology_classes),
|
||||
"language": self.language
|
||||
})
|
||||
@@ -504,62 +401,56 @@ class SemanticPruner:
|
||||
# 相关对话不剪枝
|
||||
return dialog
|
||||
|
||||
# 在不相关对话中,识别重要/不重要消息
|
||||
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
|
||||
# 在不相关对话中,LLM 已通过 preserve_tokens 标记需要保护的内容
|
||||
preserve_tokens = (
|
||||
extraction.times + extraction.ids + extraction.amounts +
|
||||
extraction.contacts + extraction.addresses + extraction.keywords +
|
||||
extraction.preserve_keywords
|
||||
)
|
||||
msgs = dialog.context.msgs
|
||||
imp_unrel_msgs: List[ConversationMessage] = []
|
||||
unimp_unrel_msgs: List[ConversationMessage] = []
|
||||
|
||||
# 分类:填充 / 其他可删(LLM保护消息通过不加入任何桶来隐式保护)
|
||||
filler_ids: set = set()
|
||||
deletable: List[ConversationMessage] = []
|
||||
|
||||
for m in msgs:
|
||||
if self._msg_matches_tokens(m, tokens) or self._is_important_message(m):
|
||||
imp_unrel_msgs.append(m)
|
||||
if self._msg_matches_tokens(m, preserve_tokens):
|
||||
pass # 保护消息:不加入任何桶,不会被删除
|
||||
elif self._is_filler_message(m):
|
||||
filler_ids.add(id(m))
|
||||
else:
|
||||
unimp_unrel_msgs.append(m)
|
||||
# 计算总删除目标数量
|
||||
deletable.append(m)
|
||||
|
||||
# 计算删除目标
|
||||
total_unrel = len(msgs)
|
||||
delete_target = int(total_unrel * proportion)
|
||||
if proportion > 0 and total_unrel > 0 and delete_target == 0:
|
||||
delete_target = 1
|
||||
imp_del_cap = min(int(len(imp_unrel_msgs) * proportion), len(imp_unrel_msgs))
|
||||
unimp_del_cap = len(unimp_unrel_msgs)
|
||||
max_capacity = max(0, len(msgs) - 1)
|
||||
max_deletable = min(imp_del_cap + unimp_del_cap, max_capacity)
|
||||
max_deletable = min(len(filler_ids) + len(deletable), max(0, total_unrel - 1))
|
||||
delete_target = min(delete_target, max_deletable)
|
||||
# 删除配额分配
|
||||
del_unimp = min(delete_target, unimp_del_cap)
|
||||
rem = delete_target - del_unimp
|
||||
del_imp = min(rem, imp_del_cap)
|
||||
|
||||
# 选取删除集合
|
||||
unimp_delete_ids = []
|
||||
imp_delete_ids = []
|
||||
if del_unimp > 0:
|
||||
# 按出现顺序选取前 del_unimp 条不重要消息进行删除(确定性、可复现)
|
||||
unimp_delete_ids = [id(m) for m in unimp_unrel_msgs[:del_unimp]]
|
||||
if del_imp > 0:
|
||||
imp_sorted = sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))
|
||||
imp_delete_ids = [id(m) for m in imp_sorted[:del_imp]]
|
||||
|
||||
# 统计实际删除数量(重要/不重要)
|
||||
actual_unimp_deleted = 0
|
||||
actual_imp_deleted = 0
|
||||
kept_msgs = []
|
||||
delete_targets = set(unimp_delete_ids) | set(imp_delete_ids)
|
||||
# 优先删填充,再删其他可删消息(按出现顺序)
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
mid = id(m)
|
||||
if mid in delete_targets:
|
||||
if mid in set(unimp_delete_ids) and actual_unimp_deleted < del_unimp:
|
||||
actual_unimp_deleted += 1
|
||||
continue
|
||||
if mid in set(imp_delete_ids) and actual_imp_deleted < del_imp:
|
||||
actual_imp_deleted += 1
|
||||
continue
|
||||
kept_msgs.append(m)
|
||||
if len(to_delete_ids) >= delete_target:
|
||||
break
|
||||
if id(m) in filler_ids:
|
||||
to_delete_ids.add(id(m))
|
||||
for m in deletable:
|
||||
if len(to_delete_ids) >= delete_target:
|
||||
break
|
||||
to_delete_ids.add(id(m))
|
||||
|
||||
kept_msgs = [m for m in msgs if id(m) not in to_delete_ids]
|
||||
if not kept_msgs and msgs:
|
||||
kept_msgs = [msgs[0]]
|
||||
|
||||
deleted_total = actual_unimp_deleted + actual_imp_deleted
|
||||
deleted_total = len(msgs) - len(kept_msgs)
|
||||
protected_count = len(msgs) - len(filler_ids) - len(deletable)
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} 删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
|
||||
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} "
|
||||
f"(保护={protected_count} 填充={len(filler_ids)} 可删={len(deletable)}) "
|
||||
f"删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
|
||||
)
|
||||
|
||||
dialog.context = ConversationContext(msgs=kept_msgs)
|
||||
@@ -594,51 +485,64 @@ class SemanticPruner:
|
||||
result: List[DialogData] = []
|
||||
total_original_msgs = 0
|
||||
total_deleted_msgs = 0
|
||||
|
||||
for d_idx, dd in enumerate(dialogs):
|
||||
|
||||
# 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息)
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
async def extract_with_semaphore(dd: DialogData) -> DialogExtractionResponse:
|
||||
async with semaphore:
|
||||
try:
|
||||
return await self._extract_dialog_important(dd.content)
|
||||
except Exception as e:
|
||||
self._log(f"[剪枝-LLM] 对话抽取失败,使用降级策略: {str(e)[:100]}")
|
||||
return DialogExtractionResponse(is_related=True)
|
||||
|
||||
extraction_tasks = [extract_with_semaphore(dd) for dd in dialogs]
|
||||
extraction_results: List[DialogExtractionResponse] = await asyncio.gather(*extraction_tasks)
|
||||
|
||||
for d_idx, (dd, extraction) in enumerate(zip(dialogs, extraction_results)):
|
||||
msgs = dd.context.msgs
|
||||
original_count = len(msgs)
|
||||
total_original_msgs += original_count
|
||||
|
||||
# ========== 问答对保护(已注释,暂不启用,留作观察) ==========
|
||||
# qa_pairs = self._identify_qa_pairs(msgs)
|
||||
# protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0)
|
||||
# ========================================================
|
||||
|
||||
# 消息级分类:每条消息独立判断
|
||||
important_msgs = [] # 重要消息(保留)
|
||||
unimportant_msgs = [] # 不重要消息(可删除)
|
||||
filler_msgs = [] # 填充消息(优先删除)
|
||||
|
||||
# 判断是否需要详细日志(仅对前N条消息记录)
|
||||
|
||||
# 从 LLM 抽取结果中获取所有需要保留的 token
|
||||
preserve_tokens = (
|
||||
extraction.times + extraction.ids + extraction.amounts +
|
||||
extraction.contacts + extraction.addresses + extraction.keywords +
|
||||
extraction.preserve_keywords # 情绪/兴趣/爱好关键词
|
||||
)
|
||||
|
||||
# 判断是否需要详细日志
|
||||
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
|
||||
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
|
||||
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
|
||||
|
||||
|
||||
if extraction.preserve_keywords:
|
||||
self._log(f" 对话[{d_idx}] LLM抽取到情绪/兴趣保护词: {extraction.preserve_keywords}")
|
||||
|
||||
# 消息级分类:LLM保护 / 填充 / 其他可删
|
||||
llm_protected_msgs = [] # LLM 保护消息(preserve_tokens 命中):绝对不可删除
|
||||
filler_msgs = [] # 填充消息(优先删除)
|
||||
deletable_msgs = [] # 其余消息(按比例删除)
|
||||
|
||||
for idx, m in enumerate(msgs):
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# ========== 问答对保护判断(已注释) ==========
|
||||
# if idx in protected_indices:
|
||||
# important_msgs.append((idx, m))
|
||||
# self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)")
|
||||
# ==========================================
|
||||
|
||||
# 填充消息(寒暄、表情等)
|
||||
if self._is_filler_message(m):
|
||||
|
||||
if self._msg_matches_tokens(m, preserve_tokens):
|
||||
llm_protected_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 保护(LLM,不可删)")
|
||||
elif self._is_filler_message(m):
|
||||
filler_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
|
||||
# 重要信息(学号、成绩、时间、金额等)
|
||||
elif self._is_important_message(m):
|
||||
important_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)")
|
||||
# 其他消息
|
||||
else:
|
||||
unimportant_msgs.append((idx, m))
|
||||
deletable_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要")
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 可删")
|
||||
|
||||
# important_msgs 仅用于日志统计
|
||||
important_msgs = llm_protected_msgs
|
||||
|
||||
# 计算删除配额
|
||||
delete_target = int(original_count * proportion)
|
||||
@@ -649,37 +553,23 @@ class SemanticPruner:
|
||||
max_deletable = max(0, original_count - 1)
|
||||
delete_target = min(delete_target, max_deletable)
|
||||
|
||||
# 删除策略:优先删除填充消息,再删除不重要消息
|
||||
# 删除策略:优先删填充消息,再按出现顺序删其余可删消息
|
||||
to_delete_indices = set()
|
||||
deleted_details = [] # 记录删除的消息详情
|
||||
|
||||
deleted_details = []
|
||||
|
||||
# 第一步:删除填充消息
|
||||
filler_to_delete = min(len(filler_msgs), delete_target)
|
||||
for i in range(filler_to_delete):
|
||||
idx, msg = filler_msgs[i]
|
||||
for idx, msg in filler_msgs:
|
||||
if len(to_delete_indices) >= delete_target:
|
||||
break
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
|
||||
|
||||
# 第二步:如果还需要删除,删除不重要消息
|
||||
remaining_quota = delete_target - len(to_delete_indices)
|
||||
if remaining_quota > 0:
|
||||
unimp_to_delete = min(len(unimportant_msgs), remaining_quota)
|
||||
for i in range(unimp_to_delete):
|
||||
idx, msg = unimportant_msgs[i]
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'")
|
||||
|
||||
# 第三步:如果还需要删除,按重要性分数删除重要消息
|
||||
remaining_quota = delete_target - len(to_delete_indices)
|
||||
if remaining_quota > 0 and important_msgs:
|
||||
# 按重要性分数排序(分数低的优先删除)
|
||||
imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1]))
|
||||
imp_to_delete = min(len(imp_sorted), remaining_quota)
|
||||
for i in range(imp_to_delete):
|
||||
idx, msg = imp_sorted[i]
|
||||
to_delete_indices.add(idx)
|
||||
score = self._importance_score(msg)
|
||||
deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'")
|
||||
|
||||
# 第二步:如果还需要删除,按出现顺序删可删消息
|
||||
for idx, msg in deletable_msgs:
|
||||
if len(to_delete_indices) >= delete_target:
|
||||
break
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 可删: '{msg.msg[:50]}'")
|
||||
|
||||
# 执行删除
|
||||
kept_msgs = []
|
||||
@@ -707,7 +597,7 @@ class SemanticPruner:
|
||||
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
|
||||
f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) "
|
||||
f"(保护={len(important_msgs)} 填充={len(filler_msgs)} 可删={len(deletable_msgs)}) "
|
||||
f"删除={deleted_count} 保留={len(kept_msgs)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,66 +1,25 @@
|
||||
"""
|
||||
场景特定配置 - 为不同场景提供定制化的剪枝规则
|
||||
场景特定配置 - 统一填充词库
|
||||
|
||||
功能:
|
||||
- 场景特定的重要信息识别模式
|
||||
- 场景特定的重要性评分权重
|
||||
- 场景特定的填充词库
|
||||
- 场景特定的问答对识别规则
|
||||
重要性判断已完全交由 extracat_Pruning.jinja2 提示词 + LLM preserve_tokens 机制承担。
|
||||
本模块仅保留统一填充词库(filler_phrases),用于识别无意义寒暄/表情/口头禅。
|
||||
所有场景共用同一份词库,场景差异由 LLM 语义判断处理。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from typing import List, Set
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenePatterns:
|
||||
"""场景特定的识别模式"""
|
||||
|
||||
# 重要信息的正则模式(优先级从高到低)
|
||||
high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight)
|
||||
medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
||||
low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
||||
|
||||
# 填充词库(无意义对话)
|
||||
"""场景特定的识别模式(仅保留填充词库)"""
|
||||
filler_phrases: Set[str] = field(default_factory=set)
|
||||
|
||||
# 问句关键词(用于识别问答对)
|
||||
question_keywords: Set[str] = field(default_factory=set)
|
||||
|
||||
# 决策性/承诺性关键词
|
||||
decision_keywords: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class SceneConfigRegistry:
|
||||
"""场景配置注册表 - 管理所有场景的特定配置"""
|
||||
|
||||
# 基础通用模式(所有场景共享)
|
||||
BASE_HIGH_PRIORITY = [
|
||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 5),
|
||||
(r"金额|费用|价格|¥|¥|\d+元", 5),
|
||||
(r"\d{11}", 4), # 手机号
|
||||
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱
|
||||
]
|
||||
|
||||
BASE_MEDIUM_PRIORITY = [
|
||||
(r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期
|
||||
(r"\d{4}年\d{1,2}月\d{1,2}日", 3),
|
||||
(r"电话|手机号|微信|QQ|联系方式", 3),
|
||||
(r"地址|地点|位置", 2),
|
||||
(r"时间|日期|有效期|截止", 2),
|
||||
(r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重)
|
||||
(r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3),
|
||||
(r"今年|去年|明年", 3),
|
||||
]
|
||||
|
||||
BASE_LOW_PRIORITY = [
|
||||
(r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM
|
||||
(r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点
|
||||
(r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充)
|
||||
(r"AM|PM|am|pm", 1),
|
||||
]
|
||||
|
||||
BASE_FILLERS = {
|
||||
"""场景配置注册表 - 所有场景共用统一填充词库"""
|
||||
|
||||
BASE_FILLERS: Set[str] = {
|
||||
# 基础寒暄
|
||||
"你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦",
|
||||
"好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢",
|
||||
@@ -69,7 +28,26 @@ class SceneConfigRegistry:
|
||||
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
|
||||
"额", "呃", "啊", "诶", "唉", "哎", "嗯哼",
|
||||
# 确认词
|
||||
"是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了",
|
||||
"是的", "对", "对的", "没错", "好嘞", "收到", "明白", "了解", "知道了",
|
||||
# 服务类套话
|
||||
"请问", "请稍等", "稍等", "马上", "立即",
|
||||
"正在查询", "正在处理", "正在为您", "帮您查一下",
|
||||
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
|
||||
"感谢您的耐心等待", "抱歉让您久等了",
|
||||
"已记录", "已反馈", "已转接", "已升级",
|
||||
"祝您生活愉快", "欢迎下次咨询",
|
||||
# 外呼套话
|
||||
"喂", "hello", "打扰了", "不好意思",
|
||||
"方便接电话吗", "现在方便吗", "占用您一点时间",
|
||||
"我是", "我们是", "我们公司", "我们这边",
|
||||
"了解一下", "介绍一下", "简单说一下",
|
||||
"考虑考虑", "想一想", "再说", "再看看",
|
||||
"不需要", "不感兴趣", "没兴趣", "不用了",
|
||||
"没问题", "那就这样", "再联系", "回头聊", "有需要再说",
|
||||
# 教育场景套话
|
||||
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
|
||||
"举手", "请坐", "很好", "不错", "继续",
|
||||
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
|
||||
# 标点和符号
|
||||
"。。。", "...", "???", "???", "!!!", "!!!",
|
||||
# 表情符号
|
||||
@@ -81,246 +59,8 @@ class SceneConfigRegistry:
|
||||
"hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
|
||||
"emmm", "emm", "em", "mmp", "wtf", "omg",
|
||||
}
|
||||
|
||||
BASE_QUESTION_KEYWORDS = {
|
||||
"什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗"
|
||||
}
|
||||
|
||||
BASE_DECISION_KEYWORDS = {
|
||||
"必须", "一定", "务必", "需要", "要求", "规定", "应该",
|
||||
"承诺", "保证", "确保", "负责", "同意", "答应"
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_education_config(cls) -> ScenePatterns:
|
||||
"""教育场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 成绩相关(最高优先级)
|
||||
(r"成绩|分数|得分|满分|及格|不及格", 6),
|
||||
(r"GPA|绩点|学分|平均分", 6),
|
||||
(r"\d+分|\d+\.?\d*分", 5), # 具体分数
|
||||
(r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等
|
||||
|
||||
# 学籍信息
|
||||
(r"学号|学生证|教师工号|工号", 5),
|
||||
(r"班级|年级|专业|院系", 4),
|
||||
|
||||
# 课程相关
|
||||
(r"课程|科目|学科|必修|选修", 4),
|
||||
(r"教材|课本|教科书|参考书", 4),
|
||||
(r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等
|
||||
|
||||
# 学科内容(新增)
|
||||
(r"微积分|导数|积分|函数|极限|微分", 4),
|
||||
(r"代数|几何|三角|概率|统计", 4),
|
||||
(r"物理|化学|生物|历史|地理", 4),
|
||||
(r"英语|语文|数学|政治|哲学", 4),
|
||||
(r"定义|定理|公式|概念|原理|法则", 3),
|
||||
(r"例题|解题|证明|推导|计算", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 教学活动
|
||||
(r"作业|练习|习题|题目", 3),
|
||||
(r"考试|测验|测试|考核|期中|期末", 3),
|
||||
(r"上课|下课|课堂|讲课", 2),
|
||||
(r"提问|回答|发言|讨论", 2),
|
||||
(r"问一下|请教|咨询|询问", 2), # 新增:问询相关
|
||||
(r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态
|
||||
|
||||
# 时间安排
|
||||
(r"课表|课程表|时间表", 3),
|
||||
(r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"老师|教师|同学|学生", 1),
|
||||
(r"教室|实验室|图书馆", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义)
|
||||
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
|
||||
"举手", "请坐", "很好", "不错", "继续",
|
||||
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"为啥", "咋", "咋办", "怎样", "如何做",
|
||||
"能不能", "可不可以", "行不行", "对不对", "是不是",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"必考", "重点", "考点", "难点", "关键",
|
||||
"记住", "背诵", "掌握", "理解", "复习",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_online_service_config(cls) -> ScenePatterns:
|
||||
"""在线服务场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 工单相关(最高优先级)
|
||||
(r"工单号|工单编号|ticket|TK\d+", 6),
|
||||
(r"工单状态|处理中|已解决|已关闭|待处理", 5),
|
||||
(r"优先级|紧急|高优先级|P0|P1|P2", 5),
|
||||
|
||||
# 产品信息
|
||||
(r"产品型号|型号|SKU|产品编号", 5),
|
||||
(r"序列号|SN|设备号", 5),
|
||||
(r"版本号|软件版本|固件版本", 4),
|
||||
|
||||
# 问题描述
|
||||
(r"故障|错误|异常|bug|问题", 4),
|
||||
(r"错误代码|故障代码|error code", 5),
|
||||
(r"无法|不能|失败|报错", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 服务相关
|
||||
(r"退款|退货|换货|补发", 4),
|
||||
(r"发票|收据|凭证", 3),
|
||||
(r"物流|快递|运单号", 3),
|
||||
(r"保修|质保|售后", 3),
|
||||
|
||||
# 时效相关
|
||||
(r"SLA|响应时间|处理时长", 4),
|
||||
(r"超时|延迟|等待", 2),
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"客服|工程师|技术支持", 1),
|
||||
(r"用户|客户|会员", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 在线服务特有填充词
|
||||
"您好", "请问", "请稍等", "稍等", "马上", "立即",
|
||||
"正在查询", "正在处理", "正在为您", "帮您查一下",
|
||||
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
|
||||
"感谢您的耐心等待", "抱歉让您久等了",
|
||||
"已记录", "已反馈", "已转接", "已升级",
|
||||
"祝您生活愉快", "再见", "欢迎下次咨询",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"能否", "可否", "是否", "有没有", "能不能",
|
||||
"怎么办", "如何处理", "怎么解决",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"立即处理", "马上解决", "尽快", "优先",
|
||||
"升级", "转接", "派单", "跟进",
|
||||
"补偿", "赔偿", "退款", "换货",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_outbound_config(cls) -> ScenePatterns:
|
||||
"""外呼场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 意向相关(最高优先级)
|
||||
(r"意向|意愿|兴趣|感兴趣", 6),
|
||||
(r"A类|B类|C类|D类|高意向|低意向", 6),
|
||||
(r"成交|签约|下单|购买|确认", 6),
|
||||
|
||||
# 联系信息(外呼场景中更重要)
|
||||
(r"预约|约定|安排|确定时间", 5),
|
||||
(r"下次联系|回访|跟进", 5),
|
||||
(r"方便|有空|可以|时间", 4),
|
||||
|
||||
# 通话状态
|
||||
(r"接通|未接通|占线|关机|停机", 4),
|
||||
(r"通话时长|通话时间", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 客户信息
|
||||
(r"姓名|称呼|先生|女士", 3),
|
||||
(r"公司|单位|职位|职务", 3),
|
||||
(r"需求|要求|期望", 3),
|
||||
|
||||
# 跟进状态
|
||||
(r"跟进状态|进展|进度", 3),
|
||||
(r"已联系|待联系|联系中", 2),
|
||||
(r"拒绝|不感兴趣|考虑|再说", 3),
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"销售|客户经理|业务员", 1),
|
||||
(r"产品|服务|方案", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 外呼场景特有填充词
|
||||
"您好", "喂", "hello", "打扰了", "不好意思",
|
||||
"方便接电话吗", "现在方便吗", "占用您一点时间",
|
||||
"我是", "我们是", "我们公司", "我们这边",
|
||||
"了解一下", "介绍一下", "简单说一下",
|
||||
"考虑考虑", "想一想", "再说", "再看看",
|
||||
"不需要", "不感兴趣", "没兴趣", "不用了",
|
||||
"好的", "行", "可以", "没问题", "那就这样",
|
||||
"再联系", "回头聊", "有需要再说",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"有没有", "需不需要", "要不要", "考虑不考虑",
|
||||
"了解吗", "知道吗", "听说过吗",
|
||||
"方便吗", "有空吗", "在吗",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"确定", "决定", "选择", "购买", "下单",
|
||||
"预约", "安排", "约定", "确认",
|
||||
"跟进", "回访", "联系", "沟通",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns:
|
||||
"""根据场景名称获取配置
|
||||
|
||||
Args:
|
||||
scene: 场景名称 ('education', 'online_service', 'outbound' 或其他)
|
||||
fallback_to_generic: 如果场景不存在,是否降级到通用配置
|
||||
|
||||
Returns:
|
||||
对应场景的配置,如果场景不存在:
|
||||
- fallback_to_generic=True: 返回通用配置(仅基础规则)
|
||||
- fallback_to_generic=False: 抛出异常
|
||||
"""
|
||||
scene_map = {
|
||||
'education': cls.get_education_config,
|
||||
'online_service': cls.get_online_service_config,
|
||||
'outbound': cls.get_outbound_config,
|
||||
}
|
||||
|
||||
if scene in scene_map:
|
||||
return scene_map[scene]()
|
||||
|
||||
if fallback_to_generic:
|
||||
# 返回通用配置(仅包含基础规则,不包含场景特定规则)
|
||||
return cls.get_generic_config()
|
||||
else:
|
||||
raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}")
|
||||
|
||||
@classmethod
|
||||
def get_generic_config(cls) -> ScenePatterns:
|
||||
"""通用场景配置 - 仅包含基础规则,适用于未定义的场景
|
||||
|
||||
这是一个保守的配置,只使用最通用的规则,避免误删重要信息
|
||||
"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY,
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY,
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY,
|
||||
filler_phrases=cls.BASE_FILLERS,
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS,
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_all_scenes(cls) -> List[str]:
|
||||
"""获取所有预定义场景的列表"""
|
||||
return ['education', 'online_service', 'outbound']
|
||||
|
||||
@classmethod
|
||||
def is_scene_supported(cls, scene: str) -> bool:
|
||||
"""检查场景是否有专门的配置支持
|
||||
|
||||
Args:
|
||||
scene: 场景名称
|
||||
|
||||
Returns:
|
||||
True: 有专门配置
|
||||
False: 将使用通用配置
|
||||
"""
|
||||
return scene in cls.get_all_scenes()
|
||||
def get_config(cls, scene: str = "") -> ScenePatterns:
|
||||
"""所有场景统一返回同一份填充词库"""
|
||||
return ScenePatterns(filler_phrases=cls.BASE_FILLERS)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{#
|
||||
对话级抽取与相关性判定模板(用于剪枝加速)
|
||||
输入:pruning_scene, is_builtin_scene, ontology_classes, dialog_text, language
|
||||
输入:pruning_scene, ontology_classes, dialog_text, language
|
||||
输出:严格 JSON(不要包含任何多余文本),字段:
|
||||
- is_related: bool,是否与所选场景相关
|
||||
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
||||
@@ -9,64 +9,71 @@
|
||||
- contacts: [string],联系方式(电话/手机号/邮箱/微信/QQ等)
|
||||
- addresses: [string],地址/地点相关文本
|
||||
- keywords: [string],其它有助于保留的重要关键词(与场景强相关的术语)
|
||||
- preserve_keywords: [string],必须保留的情绪/兴趣/爱好/个人偏好相关词或短语片段
|
||||
|
||||
要求:
|
||||
- 必须只输出上述 JSON,且键名一致;不得输出解释、前后缀;不得包含注释。
|
||||
- times/ids/amounts/contacts/addresses/keywords 仅抽取原文片段或规范化后的简单字符串。
|
||||
- times/ids/amounts/contacts/addresses/keywords/preserve_keywords 仅抽取原文片段或规范化后的简单字符串。
|
||||
- 仅输出上述键;避免多余解释或字段。
|
||||
#}
|
||||
|
||||
{# ── 内置场景的固定说明 ── #}
|
||||
{% set builtin_scene_instructions = {
|
||||
'education': {
|
||||
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
||||
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
|
||||
},
|
||||
'online_service': {
|
||||
'zh': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
|
||||
'en': 'Online Service Scenario: Customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.'
|
||||
},
|
||||
'outbound': {
|
||||
'zh': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。',
|
||||
'en': 'Outbound Scenario: Outbound calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up records, etc.'
|
||||
}
|
||||
} %}
|
||||
|
||||
{# ── 确定最终使用的场景说明 ── #}
|
||||
{% if is_builtin_scene %}
|
||||
{# 内置专门场景:使用固定说明 #}
|
||||
{% set scene_key = pruning_scene %}
|
||||
{% if scene_key not in builtin_scene_instructions %}{% set scene_key = 'education' %}{% endif %}
|
||||
{% set instruction = builtin_scene_instructions[scene_key][language] if language in ['zh', 'en'] else builtin_scene_instructions[scene_key]['zh'] %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% else %}
|
||||
{# 自定义场景:使用场景名称 + 本体类型列表构建说明 #}
|
||||
{% if ontology_classes and ontology_classes | length > 0 %}
|
||||
{% if language == 'en' %}
|
||||
{% set custom_types_str = ontology_classes | join(', ') %}
|
||||
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
|
||||
{% else %}
|
||||
{% set custom_types_str = ontology_classes | join('、') %}
|
||||
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
|
||||
{% endif %}
|
||||
{# ── 确定场景说明 ── #}
|
||||
{% if ontology_classes and ontology_classes | length > 0 %}
|
||||
{% if language == 'en' %}
|
||||
{% set custom_types_str = ontology_classes | join(', ') %}
|
||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
|
||||
{% else %}
|
||||
{# 无本体类型时退化为通用说明 #}
|
||||
{% if language == 'en' %}
|
||||
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||
{% else %}
|
||||
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||
{% endif %}
|
||||
{% set custom_types_str = ontology_classes | join('、') %}
|
||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
|
||||
{% endif %}
|
||||
{% else %}
|
||||
{% if language == 'en' %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||
{% else %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
|
||||
你是一个对话内容分析助手。请对下方对话全文进行一次性分析,完成两项任务:
|
||||
1. 判断对话是否与指定场景相关;
|
||||
2. 从对话中抽取所有需要保留的重要信息片段。
|
||||
|
||||
场景说明:{{ instruction }}
|
||||
{% if not is_builtin_scene and custom_types_str %}
|
||||
{% if custom_types_str %}
|
||||
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }})相关的内容,即判定为相关(is_related=true)。
|
||||
{% endif %}
|
||||
|
||||
---
|
||||
【必须保留的内容(不可删除)】
|
||||
以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段:
|
||||
- 时间信息:日期、时间点、时间段、有效期 → times 字段
|
||||
- 编号信息:学号、工号、订单号、申请号、账号、ID → ids 字段
|
||||
- 金额信息:价格、费用、金额(含货币符号或单位) → amounts 字段
|
||||
- 联系方式:电话、手机号、邮箱、微信、QQ → contacts 字段
|
||||
- 地址信息:地点、地址、位置 → addresses 字段
|
||||
- 场景关键词:与场景强相关的专业术语、事件名称 → keywords 字段
|
||||
- **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段
|
||||
- **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段
|
||||
- **个人观点与态度**:对某事物的明确看法、评价、立场 → preserve_keywords 字段
|
||||
|
||||
【可以删除的内容】
|
||||
以下类型的内容属于低价值信息,可以在剪枝时删除:
|
||||
- 纯寒暄问候:如"你好"、"在吗"、"拜拜"、"嗯"、"好的"、"哦"等无实质内容的短语
|
||||
- 纯表情/符号:如"[微笑]"、"😊"、"哈哈"等
|
||||
- 重复确认:如"对对对"、"是的是的"、"嗯嗯嗯"等无新增信息的重复
|
||||
- 无意义填充:如"啊"、"呢"、"嘛"等语气词单独成句
|
||||
|
||||
**注意:即使消息很短,只要包含情绪、兴趣、爱好、个人观点等有价值信息,就必须保留,不得删除。**
|
||||
例如:
|
||||
- "我好开心呀" → 包含情绪(开心),必须保留,preserve_keywords 中加入"开心"
|
||||
- "好喜欢打羽毛球呀" → 包含兴趣爱好(喜欢打羽毛球),必须保留,preserve_keywords 中加入"喜欢打羽毛球"
|
||||
- "我好难过" → 包含情绪(难过),必须保留,preserve_keywords 中加入"难过"
|
||||
- "太好啦!看到你开心,我也跟着心情亮起来" → 包含情绪,必须保留,preserve_keywords 中加入"开心"
|
||||
|
||||
---
|
||||
对话全文:
|
||||
"""
|
||||
{{ dialog_text }}
|
||||
@@ -80,15 +87,46 @@
|
||||
"amounts": [<string>...],
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...]
|
||||
"keywords": [<string>...],
|
||||
"preserve_keywords": [<string>...]
|
||||
}
|
||||
{% else %}
|
||||
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
|
||||
You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks:
|
||||
1. Determine whether the dialogue is relevant to the specified scene;
|
||||
2. Extract all important information fragments that must be preserved.
|
||||
|
||||
Scenario Description: {{ instruction }}
|
||||
{% if not is_builtin_scene and custom_types_str %}
|
||||
{% if custom_types_str %}
|
||||
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
|
||||
{% endif %}
|
||||
|
||||
---
|
||||
[MUST PRESERVE (cannot be deleted)]
|
||||
The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields:
|
||||
- Time information: dates, time points, durations, expiry dates → times field
|
||||
- ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field
|
||||
- Amount information: prices, fees, amounts (with currency symbols or units) → amounts field
|
||||
- Contact information: phone numbers, emails, WeChat, QQ → contacts field
|
||||
- Address information: locations, addresses, places → addresses field
|
||||
- Scene keywords: professional terms and event names strongly related to the scene → keywords field
|
||||
- **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, sadness, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field
|
||||
- **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field
|
||||
- **Personal opinions and attitudes**: clear views, evaluations, or stances on something → preserve_keywords field
|
||||
|
||||
[CAN BE DELETED]
|
||||
The following types of content are low-value and can be removed during pruning:
|
||||
- Pure greetings: e.g., "hello", "are you there", "bye", "ok", "yeah" — short phrases with no substantive content
|
||||
- Pure emojis/symbols: e.g., "[smile]", "😊", "haha"
|
||||
- Repetitive confirmations: e.g., "yes yes yes", "right right", "uh huh" — repetitions with no new information
|
||||
- Meaningless fillers: standalone interjections like "ah", "well", "hmm"
|
||||
|
||||
**Note: Even if a message is short, if it contains emotions, interests, hobbies, or personal opinions, it MUST be preserved.**
|
||||
Examples:
|
||||
- "I'm so happy!" → contains emotion (happy), must preserve; add "happy" to preserve_keywords
|
||||
- "I love playing badminton!" → contains interest (love playing badminton), must preserve; add "love playing badminton" to preserve_keywords
|
||||
- "I feel so sad" → contains emotion (sad), must preserve; add "sad" to preserve_keywords
|
||||
|
||||
---
|
||||
Full Dialogue:
|
||||
"""
|
||||
{{ dialog_text }}
|
||||
@@ -102,6 +140,7 @@ Output strict JSON only (fixed keys, order doesn't matter):
|
||||
"amounts": [<string>...],
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...]
|
||||
"keywords": [<string>...],
|
||||
"preserve_keywords": [<string>...]
|
||||
}
|
||||
{% endif %}
|
||||
|
||||
@@ -4,11 +4,12 @@ RAG chunk analysis utilities.
|
||||
|
||||
from .chunk_summary import generate_chunk_summary
|
||||
from .chunk_tags import extract_chunk_tags, extract_chunk_persona
|
||||
from .chunk_insight import generate_chunk_insight
|
||||
from .chunk_insight import generate_chunk_insight, generate_chunk_insight_sections
|
||||
|
||||
__all__ = [
|
||||
"generate_chunk_summary",
|
||||
"extract_chunk_tags",
|
||||
"extract_chunk_persona",
|
||||
"generate_chunk_insight",
|
||||
"generate_chunk_insight_sections",
|
||||
]
|
||||
|
||||
@@ -1,213 +1,207 @@
|
||||
"""
|
||||
Generate insights from RAG chunks.
|
||||
Generate memory insight report for RAG chunks using memory_insight.jinja2 prompt template.
|
||||
|
||||
This module provides functionality to analyze chunk content and generate insights using LLM.
|
||||
The memory_insight.jinja2 template produces a four-section report:
|
||||
【总体概述】 → memory_insight
|
||||
【行为模式】 → behavior_pattern
|
||||
【关键发现】 → key_findings
|
||||
【成长轨迹】 → growth_trajectory
|
||||
|
||||
generate_chunk_insight() returns the full raw text (stored in end_user.memory_insight).
|
||||
generate_chunk_insight_sections() returns a dict with all four fields for richer storage.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
|
||||
# ── LLM client helper ────────────────────────────────────────────────────────
|
||||
|
||||
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
if end_user_id:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
if config_id or workspace_id:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(memory_config.llm_model_id)
|
||||
except Exception as e:
|
||||
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
|
||||
class ChunkInsight(BaseModel):
|
||||
"""Pydantic model for chunk insight."""
|
||||
insight: str = Field(..., description="对chunk内容的深度洞察分析")
|
||||
# ── Domain analysis helpers (kept for building prompt inputs) ─────────────────
|
||||
|
||||
async def _classify_domain(chunk: str, llm_client) -> str:
|
||||
"""Classify a single chunk into a domain category."""
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class DomainClassification(BaseModel):
|
||||
"""Pydantic model for domain classification."""
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="内容所属的领域分类",
|
||||
examples=["技术", "商业", "教育", "生活", "娱乐", "健康", "其他"]
|
||||
)
|
||||
class _Domain(BaseModel):
|
||||
domain: str = Field(..., description="领域分类")
|
||||
|
||||
|
||||
async def classify_chunk_domain(chunk: str) -> str:
|
||||
"""
|
||||
Classify a chunk into a specific domain.
|
||||
|
||||
Args:
|
||||
chunk: Chunk content string
|
||||
|
||||
Returns:
|
||||
Domain name
|
||||
"""
|
||||
try:
|
||||
llm_client = _get_llm_client()
|
||||
|
||||
prompt = f"""请将以下文本内容归类到最合适的领域中。
|
||||
|
||||
可选领域及其关键词:
|
||||
- 技术:编程、软件、硬件、算法、数据、网络、系统、开发、工程等
|
||||
- 商业:市场、销售、管理、财务、投资、创业、营销、战略等
|
||||
- 教育:学习、课程、培训、教学、知识、技能、考试、研究等
|
||||
- 生活:日常、家庭、饮食、购物、旅行、休闲、娱乐等
|
||||
- 娱乐:游戏、电影、音乐、体育、艺术、文化等
|
||||
- 健康:医疗、养生、运动、心理、保健、疾病等
|
||||
- 其他:无法归入以上类别的内容
|
||||
|
||||
文本内容: {chunk[:500]}...
|
||||
|
||||
请直接返回最合适的领域名称。"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的文本分类助手。请仔细分析文本内容,选择最合适的领域分类。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
classification = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=DomainClassification
|
||||
prompt = (
|
||||
"请将以下文本归类到最合适的领域(技术/商业/教育/生活/娱乐/健康/其他)。\n\n"
|
||||
f"文本: {chunk[:500]}\n\n直接返回领域名称。"
|
||||
)
|
||||
|
||||
return classification.domain if classification else "其他"
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"分类chunk领域失败: {str(e)}")
|
||||
result = await llm_client.response_structured(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_model=_Domain,
|
||||
)
|
||||
return result.domain if result else "其他"
|
||||
except Exception:
|
||||
return "其他"
|
||||
|
||||
|
||||
async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -> Dict[str, float]:
|
||||
async def _build_insight_inputs(
|
||||
chunks: List[str],
|
||||
max_chunks: int,
|
||||
end_user_id: Optional[str],
|
||||
) -> Dict[str, Optional[str]]:
|
||||
"""
|
||||
Analyze the domain distribution of chunks.
|
||||
|
||||
Args:
|
||||
chunks: List of chunk content strings
|
||||
max_chunks: Maximum number of chunks to analyze
|
||||
|
||||
Returns:
|
||||
Dictionary of domain -> percentage
|
||||
Derive domain_distribution, active_periods, social_connections strings
|
||||
to feed into the memory_insight.jinja2 template.
|
||||
"""
|
||||
if not chunks:
|
||||
return {}
|
||||
|
||||
try:
|
||||
# 限制分析的chunk数量
|
||||
chunks_to_analyze = chunks[:max_chunks]
|
||||
|
||||
# 为每个chunk分类
|
||||
domain_counts = Counter()
|
||||
for chunk in chunks_to_analyze:
|
||||
domain = await classify_chunk_domain(chunk)
|
||||
domain_counts[domain] += 1
|
||||
|
||||
# 计算百分比
|
||||
total = sum(domain_counts.values())
|
||||
domain_distribution = {
|
||||
domain: count / total
|
||||
for domain, count in domain_counts.items()
|
||||
}
|
||||
|
||||
# 按百分比降序排序
|
||||
return dict(sorted(domain_distribution.items(), key=lambda x: x[1], reverse=True))
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"分析领域分布失败: {str(e)}")
|
||||
return {}
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
chunks_sample = chunks[:max_chunks]
|
||||
|
||||
# Domain distribution
|
||||
domain_counts: Counter = Counter()
|
||||
for chunk in chunks_sample:
|
||||
domain = await _classify_domain(chunk, llm_client)
|
||||
domain_counts[domain] += 1
|
||||
|
||||
total = sum(domain_counts.values()) or 1
|
||||
domain_distribution = ", ".join(
|
||||
f"{d}({c / total:.0%})" for d, c in domain_counts.most_common(3)
|
||||
)
|
||||
|
||||
return {
|
||||
"domain_distribution": domain_distribution,
|
||||
"active_periods": None, # RAG模式暂无时间维度数据
|
||||
"social_connections": None, # RAG模式暂无社交关联数据
|
||||
}
|
||||
|
||||
|
||||
async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str:
|
||||
# ── Section parser ────────────────────────────────────────────────────────────
|
||||
|
||||
_ZH_SECTIONS = {
|
||||
"memory_insight": r"【总体概述】(.*?)(?=【|$)",
|
||||
"behavior_pattern": r"【行为模式】(.*?)(?=【|$)",
|
||||
"key_findings": r"【关键发现】(.*?)(?=【|$)",
|
||||
"growth_trajectory": r"【成长轨迹】(.*?)(?=【|$)",
|
||||
}
|
||||
|
||||
_EN_SECTIONS = {
|
||||
"memory_insight": r"【Overview】(.*?)(?=【|$)",
|
||||
"behavior_pattern": r"【Behavior Pattern】(.*?)(?=【|$)",
|
||||
"key_findings": r"【Key Findings】(.*?)(?=【|$)",
|
||||
"growth_trajectory": r"【Growth Trajectory】(.*?)(?=【|$)",
|
||||
}
|
||||
|
||||
|
||||
def _parse_sections(text: str, language: str = "zh") -> Dict[str, str]:
|
||||
"""Extract the four sections from the LLM output."""
|
||||
patterns = _ZH_SECTIONS if language == "zh" else _EN_SECTIONS
|
||||
result = {}
|
||||
for key, pattern in patterns.items():
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
result[key] = match.group(1).strip() if match else ""
|
||||
return result
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
async def generate_chunk_insight(
|
||||
chunks: List[str],
|
||||
max_chunks: int = 15,
|
||||
end_user_id: Optional[str] = None,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Generate insights from the given chunks.
|
||||
|
||||
Args:
|
||||
chunks: List of chunk content strings
|
||||
max_chunks: Maximum number of chunks to analyze
|
||||
|
||||
Returns:
|
||||
A comprehensive insight report
|
||||
Generate a memory insight report from RAG chunks.
|
||||
|
||||
Returns the full raw report text (suitable for end_user.memory_insight).
|
||||
Use generate_chunk_insight_sections() when you need all four dimensions.
|
||||
"""
|
||||
sections = await generate_chunk_insight_sections(
|
||||
chunks=chunks,
|
||||
max_chunks=max_chunks,
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
)
|
||||
return sections.get("memory_insight") or sections.get("_raw", "洞察生成失败")
|
||||
|
||||
|
||||
async def generate_chunk_insight_sections(
|
||||
chunks: List[str],
|
||||
max_chunks: int = 15,
|
||||
end_user_id: Optional[str] = None,
|
||||
language: str = "zh",
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Generate a four-section memory insight report from RAG chunks.
|
||||
|
||||
Returns a dict with keys:
|
||||
memory_insight, behavior_pattern, key_findings, growth_trajectory
|
||||
(plus '_raw' containing the full LLM output for debugging)
|
||||
"""
|
||||
if not chunks:
|
||||
business_logger.warning("没有提供chunk内容用于生成洞察")
|
||||
return "暂无足够数据生成洞察报告"
|
||||
|
||||
empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
|
||||
empty["_raw"] = "暂无足够数据生成洞察报告"
|
||||
return empty
|
||||
|
||||
try:
|
||||
# 1. 分析领域分布
|
||||
domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks)
|
||||
|
||||
# 2. 统计基本信息
|
||||
total_chunks = len(chunks)
|
||||
avg_length = sum(len(chunk) for chunk in chunks) / total_chunks if total_chunks > 0 else 0
|
||||
|
||||
# 3. 构建洞察prompt
|
||||
prompt_parts = []
|
||||
|
||||
if domain_dist:
|
||||
top_domains = ", ".join([f"{k}({v:.0%})" for k, v in list(domain_dist.items())[:3]])
|
||||
prompt_parts.append(f"- 内容领域分布: {top_domains}")
|
||||
|
||||
prompt_parts.append(f"- 内容规模: 共{total_chunks}个知识片段,平均长度{avg_length:.0f}字")
|
||||
|
||||
# 添加部分chunk内容作为参考
|
||||
sample_chunks = chunks[:5]
|
||||
sample_content = "\n".join([f"示例{i+1}: {chunk[:200]}..." for i, chunk in enumerate(sample_chunks)])
|
||||
prompt_parts.append(f"\n内容示例:\n{sample_content}")
|
||||
|
||||
system_prompt = """你是一位专业的知识内容分析师。你的任务是根据提供的信息,生成一段简洁、有洞察力的分析报告。
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
|
||||
|
||||
重要规则:
|
||||
1. 报告需要将所有要点流畅地串联成一个段落
|
||||
2. 语言风格要专业、客观,同时易于理解
|
||||
3. 不要添加任何额外的解释或标题,直接输出报告内容
|
||||
4. 基于提供的数据和示例内容进行分析,不要编造信息
|
||||
5. 重点关注内容的主题、特点和价值
|
||||
6. 报告长度控制在150-200字
|
||||
# Build template inputs from chunk analysis
|
||||
inputs = await _build_insight_inputs(chunks, max_chunks, end_user_id)
|
||||
|
||||
例如,如果输入是:
|
||||
- 内容领域分布: 技术(60%), 商业(25%), 教育(15%)
|
||||
- 内容规模: 共50个知识片段,平均长度320字
|
||||
内容示例: [示例内容...]
|
||||
rendered_prompt = await render_memory_insight_prompt(
|
||||
domain_distribution=inputs["domain_distribution"],
|
||||
active_periods=inputs["active_periods"],
|
||||
social_connections=inputs["social_connections"],
|
||||
language=language,
|
||||
)
|
||||
|
||||
你的输出应该类似:
|
||||
"该知识库主要聚焦于技术领域(60%),涵盖商业(25%)和教育(15%)相关内容。共包含50个知识片段,平均每个片段约320字,内容详实。从示例来看,内容涉及[具体主题],体现了[特点],对[目标用户]具有较高的参考价值。"
|
||||
"""
|
||||
|
||||
user_prompt = "\n".join(prompt_parts)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
# 调用LLM生成洞察
|
||||
llm_client = _get_llm_client()
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
response = await llm_client.chat(messages=messages)
|
||||
|
||||
insight = response.content.strip()
|
||||
business_logger.info(f"成功生成chunk洞察,分析了 {min(len(chunks), max_chunks)} 个片段")
|
||||
|
||||
return insight
|
||||
|
||||
raw_text = response.content.strip() if response and response.content else ""
|
||||
|
||||
sections = _parse_sections(raw_text, language=language)
|
||||
sections["_raw"] = raw_text
|
||||
|
||||
business_logger.info(
|
||||
f"成功生成chunk洞察四维度,分析了 {min(len(chunks), max_chunks)} 个片段"
|
||||
)
|
||||
return sections
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"生成chunk洞察失败: {str(e)}")
|
||||
return "洞察生成失败"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
test_chunks = [
|
||||
"Python是一种高级编程语言,以其简洁的语法和强大的功能而闻名。它广泛应用于Web开发、数据分析、人工智能等领域。",
|
||||
"机器学习算法可以从数据中自动学习模式,无需显式编程。常见的算法包括决策树、随机森林、神经网络等。",
|
||||
"深度学习是机器学习的一个分支,使用多层神经网络来学习数据的层次化表示。它在图像识别、语音识别等任务中表现出色。",
|
||||
"自然语言处理技术使计算机能够理解和生成人类语言。应用包括机器翻译、情感分析、文本摘要等。",
|
||||
"数据科学结合了统计学、计算机科学和领域知识,用于从数据中提取有价值的洞察。"
|
||||
]
|
||||
|
||||
print("开始生成chunk洞察...")
|
||||
insight = asyncio.run(generate_chunk_insight(test_chunks))
|
||||
print(f"\n生成的洞察:\n{insight}")
|
||||
empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
|
||||
empty["_raw"] = "洞察生成失败"
|
||||
return empty
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""
|
||||
Generate summary for RAG chunks.
|
||||
|
||||
This module provides functionality to summarize chunk content using LLM.
|
||||
Generate summary for RAG chunks using memory_summary.jinja2 prompt template.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
@@ -14,94 +13,135 @@ from pydantic import BaseModel, Field
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
|
||||
|
||||
class ChunkSummary(BaseModel):
|
||||
"""Pydantic model for chunk summary."""
|
||||
summary: str = Field(..., description="简洁的chunk内容摘要")
|
||||
# ── Schema ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class MemorySummaryStatement(BaseModel):
|
||||
"""Single labelled statement extracted by memory_summary.jinja2."""
|
||||
statement: str = Field(..., description="提取的陈述内容")
|
||||
label: Optional[str] = Field(None, description="陈述标签")
|
||||
|
||||
|
||||
async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str:
|
||||
class MemorySummaryResponse(BaseModel):
|
||||
"""
|
||||
Generate a summary for the given chunks.
|
||||
|
||||
Structured output expected from memory_summary.jinja2.
|
||||
The template asks for a JSON array of labelled statements;
|
||||
we wrap it in an object so response_structured can parse it.
|
||||
"""
|
||||
statements: List[MemorySummaryStatement] = Field(
|
||||
default_factory=list,
|
||||
description="从chunk中提取的陈述列表"
|
||||
)
|
||||
summary: Optional[str] = Field(None, description="整体摘要文本(可选)")
|
||||
|
||||
|
||||
# ── LLM client helper ────────────────────────────────────────────────────────
|
||||
|
||||
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
if end_user_id:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
if config_id or workspace_id:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(memory_config.llm_model_id)
|
||||
except Exception as e:
|
||||
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
|
||||
# ── Core function ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def generate_chunk_summary(
|
||||
chunks: List[str],
|
||||
max_chunks: int = 10,
|
||||
end_user_id: Optional[str] = None,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Generate a user summary from RAG chunks using the memory_summary.jinja2 template.
|
||||
|
||||
The template extracts labelled statements from the chunks; we then join them
|
||||
into a coherent summary string that can be stored in end_user.user_summary.
|
||||
|
||||
Args:
|
||||
chunks: List of chunk content strings
|
||||
max_chunks: Maximum number of chunks to process (default: 10)
|
||||
|
||||
max_chunks: Maximum number of chunks to process
|
||||
end_user_id: Optional end-user ID for model selection
|
||||
language: Output language ("zh" or "en")
|
||||
|
||||
Returns:
|
||||
A concise summary of the chunks
|
||||
Summary string (joined statements or fallback text)
|
||||
"""
|
||||
if not chunks:
|
||||
business_logger.warning("没有提供chunk内容用于生成摘要")
|
||||
return "暂无内容"
|
||||
|
||||
|
||||
try:
|
||||
# 限制处理的chunk数量,避免token过多
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||
|
||||
chunks_to_process = chunks[:max_chunks]
|
||||
|
||||
# 合并chunk内容
|
||||
combined_content = "\n\n".join([f"片段{i+1}: {chunk}" for i, chunk in enumerate(chunks_to_process)])
|
||||
|
||||
# 构建prompt
|
||||
system_prompt = (
|
||||
"你是一位专业的文本摘要助手。请基于提供的文本片段,生成简洁的摘要。要求:\n"
|
||||
"- 摘要长度控制在100-150字;\n"
|
||||
"- 提取核心信息和关键要点;\n"
|
||||
"- 使用客观、清晰的语言;\n"
|
||||
"- 避免冗余和重复;\n"
|
||||
"- 如果内容涉及多个主题,按重要性排序呈现。"
|
||||
chunk_texts = "\n\n".join(
|
||||
[f"片段{i + 1}: {chunk}" for i, chunk in enumerate(chunks_to_process)]
|
||||
)
|
||||
|
||||
json_schema = MemorySummaryResponse.model_json_schema()
|
||||
|
||||
rendered_prompt = await render_memory_summary_prompt(
|
||||
chunk_texts=chunk_texts,
|
||||
json_schema=json_schema,
|
||||
max_words=200,
|
||||
language=language,
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
|
||||
# Try structured output; fall back to plain chat only for LLMClientException
|
||||
# (indicates the model/provider doesn't support structured output).
|
||||
# All other exceptions are re-raised so config/schema errors stay visible.
|
||||
try:
|
||||
response: MemorySummaryResponse = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=MemorySummaryResponse,
|
||||
)
|
||||
if response.summary:
|
||||
summary = response.summary.strip()
|
||||
elif response.statements:
|
||||
summary = ";".join(s.statement for s in response.statements)
|
||||
else:
|
||||
summary = "暂无内容"
|
||||
except Exception as e:
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
if isinstance(e, LLMClientException):
|
||||
business_logger.warning(
|
||||
f"结构化输出不可用,降级为普通对话: end_user_id={end_user_id}, reason={e}"
|
||||
)
|
||||
raw = await llm_client.chat(messages=messages)
|
||||
summary = raw.content.strip() if raw and raw.content else "暂无内容"
|
||||
else:
|
||||
business_logger.error(f"生成摘要时发生非预期异常: {e}")
|
||||
raise
|
||||
|
||||
business_logger.info(
|
||||
f"成功生成chunk摘要,处理了 {len(chunks_to_process)} 个片段"
|
||||
)
|
||||
|
||||
user_prompt = f"请为以下文本片段生成摘要:\n\n{combined_content}"
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
# 调用LLM生成摘要
|
||||
llm_client = _get_llm_client()
|
||||
response = await llm_client.chat(messages=messages)
|
||||
|
||||
summary = response.content.strip()
|
||||
business_logger.info(f"成功生成chunk摘要,处理了 {len(chunks_to_process)} 个片段")
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"生成chunk摘要失败: {str(e)}")
|
||||
return "摘要生成失败"
|
||||
|
||||
|
||||
async def generate_chunk_summary_batch(chunks_list: List[List[str]]) -> List[str]:
|
||||
"""
|
||||
Generate summaries for multiple chunk lists in batch.
|
||||
|
||||
Args:
|
||||
chunks_list: List of chunk lists
|
||||
|
||||
Returns:
|
||||
List of summaries
|
||||
"""
|
||||
tasks = [generate_chunk_summary(chunks) for chunks in chunks_list]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
test_chunks = [
|
||||
"这是第一段测试内容,讲述了关于机器学习的基础知识。",
|
||||
"第二段内容介绍了深度学习的应用场景和发展历史。",
|
||||
"第三段讨论了自然语言处理技术的最新进展。"
|
||||
]
|
||||
|
||||
print("开始生成chunk摘要...")
|
||||
summary = asyncio.run(generate_chunk_summary(test_chunks))
|
||||
print(f"\n生成的摘要:\n{summary}")
|
||||
|
||||
@@ -5,8 +5,9 @@ This module provides functionality to extract meaningful tags from chunk content
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections import Counter
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
|
||||
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
if end_user_id:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
if config_id or workspace_id:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(memory_config.llm_model_id)
|
||||
except Exception as e:
|
||||
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
|
||||
class ExtractedTags(BaseModel):
|
||||
@@ -33,7 +53,7 @@ class ExtractedPersona(BaseModel):
|
||||
personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师'、'旅行爱好者'等")
|
||||
|
||||
|
||||
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10) -> List[Tuple[str, int]]:
|
||||
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10, end_user_id: Optional[str] = None) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
Extract meaningful tags from the given chunks.
|
||||
|
||||
@@ -64,7 +84,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
|
||||
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
|
||||
)
|
||||
|
||||
llm_client = _get_llm_client()
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
|
||||
# 为每个chunk单独提取标签,然后统计频率
|
||||
all_tags = []
|
||||
@@ -116,7 +136,7 @@ async def extract_chunk_tags_with_frequency(chunks: List[str], max_tags: int = 1
|
||||
return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks))
|
||||
|
||||
|
||||
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20) -> List[str]:
|
||||
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20, end_user_id: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Extract persona (人物形象) from the given chunks.
|
||||
|
||||
@@ -159,7 +179,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
|
||||
]
|
||||
|
||||
# 调用LLM提取人物形象
|
||||
llm_client = _get_llm_client()
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
structured_response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=ExtractedPersona
|
||||
|
||||
@@ -85,6 +85,7 @@ class StorageFactory:
|
||||
access_key_id=settings.S3_ACCESS_KEY_ID,
|
||||
secret_access_key=settings.S3_SECRET_ACCESS_KEY,
|
||||
bucket_name=settings.S3_BUCKET_NAME,
|
||||
endpoint_url=settings.S3_ENDPOINT_URL,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -35,6 +35,19 @@ class S3Storage(StorageBackend):
|
||||
bucket_name: The name of the S3 bucket.
|
||||
region: The AWS region.
|
||||
"""
|
||||
AMAZON_S3_ENDPOINT_MAP = {
|
||||
"us-east-1": "https://s3.us-east-1.amazonaws.com", # 特殊:无地域后缀
|
||||
"us-east-2": "https://s3.us-east-2.amazonaws.com",
|
||||
"us-west-1": "https://s3.us-west-1.amazonaws.com",
|
||||
"us-west-2": "https://s3.us-west-2.amazonaws.com",
|
||||
"ap-east-1": "https://s3.ap-east-1.amazonaws.com", # 香港
|
||||
"ap-southeast-1": "https://s3.ap-southeast-1.amazonaws.com", # 新加坡
|
||||
"ap-southeast-2": "https://s3.ap-southeast-2.amazonaws.com", # 悉尼
|
||||
"ap-northeast-1": "https://s3.ap-northeast-1.amazonaws.com", # 东京
|
||||
"eu-central-1": "https://s3.eu-central-1.amazonaws.com", # 法兰克福
|
||||
"eu-west-1": "https://s3.eu-west-1.amazonaws.com", # 爱尔兰
|
||||
# 可根据需要扩展其他地域
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -42,6 +55,7 @@ class S3Storage(StorageBackend):
|
||||
access_key_id: str,
|
||||
secret_access_key: str,
|
||||
bucket_name: str,
|
||||
endpoint_url: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize the S3Storage backend.
|
||||
@@ -51,6 +65,7 @@ class S3Storage(StorageBackend):
|
||||
access_key_id: The AWS access key ID.
|
||||
secret_access_key: The AWS secret access key.
|
||||
bucket_name: The name of the S3 bucket.
|
||||
endpoint_url: The complete URL to use for the constructed client.
|
||||
|
||||
Raises:
|
||||
StorageConfigError: If any required configuration is missing.
|
||||
@@ -69,10 +84,19 @@ class S3Storage(StorageBackend):
|
||||
self.region = region
|
||||
self.bucket_name = bucket_name
|
||||
|
||||
if not endpoint_url:
|
||||
# 优先匹配内置映射表(解决特殊地域)
|
||||
if region in self.AMAZON_S3_ENDPOINT_MAP:
|
||||
endpoint_url = self.AMAZON_S3_ENDPOINT_MAP[region]
|
||||
# 兜底:通用拼接(适配未配置的新地域)
|
||||
else:
|
||||
endpoint_url = f"https://s3.{region}.amazonaws.com"
|
||||
|
||||
try:
|
||||
self.client = boto3.client(
|
||||
"s3",
|
||||
region_name=region,
|
||||
endpoint_url=endpoint_url,
|
||||
aws_access_key_id=access_key_id,
|
||||
aws_secret_access_key=secret_access_key,
|
||||
)
|
||||
|
||||
@@ -53,6 +53,7 @@ class SimpleMCPClient:
|
||||
else:
|
||||
await self._connect_http()
|
||||
except Exception as e:
|
||||
await self.disconnect()
|
||||
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
|
||||
raise MCPConnectionError(f"连接失败: {e}")
|
||||
|
||||
|
||||
@@ -8,34 +8,60 @@ from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \
|
||||
from app.core.workflow.adapters.errors import (
|
||||
UnsupportVariableType,
|
||||
UnknowModelWarning,
|
||||
ExceptionDefineition,
|
||||
ExceptionType
|
||||
from app.core.workflow.nodes.assigner import AssignerNodeConfig
|
||||
)
|
||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
|
||||
from app.core.workflow.nodes.code import CodeNodeConfig
|
||||
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
||||
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \
|
||||
from app.core.workflow.nodes.configs import (
|
||||
StartNodeConfig,
|
||||
LLMNodeConfig,
|
||||
AssignerNodeConfig,
|
||||
CodeNodeConfig,
|
||||
LoopNodeConfig,
|
||||
IterationNodeConfig,
|
||||
EndNodeConfig,
|
||||
HttpRequestNodeConfig,
|
||||
IfElseNodeConfig,
|
||||
JinjaRenderNodeConfig,
|
||||
KnowledgeRetrievalNodeConfig,
|
||||
NoteNodeConfig,
|
||||
ParameterExtractorNodeConfig,
|
||||
QuestionClassifierNodeConfig,
|
||||
VariableAggregatorNodeConfig
|
||||
)
|
||||
from app.core.workflow.nodes.cycle_graph.config import (
|
||||
ConditionDetail as LoopConditionDetail,
|
||||
ConditionsConfig,
|
||||
CycleVariable
|
||||
from app.core.workflow.nodes.end import EndNodeConfig
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \
|
||||
HttpContentType, HttpErrorHandle
|
||||
from app.core.workflow.nodes.http_request import HttpRequestNodeConfig
|
||||
from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \
|
||||
HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
)
|
||||
from app.core.workflow.nodes.enums import (
|
||||
ValueInputType,
|
||||
ComparisonOperator,
|
||||
AssignmentOperator,
|
||||
HttpAuthType,
|
||||
HttpContentType,
|
||||
HttpErrorHandle,
|
||||
NodeType
|
||||
)
|
||||
from app.core.workflow.nodes.http_request.config import (
|
||||
HttpAuthConfig,
|
||||
HttpContentTypeConfig,
|
||||
HttpFormData,
|
||||
HttpTimeOutConfig,
|
||||
HttpRetryConfig,
|
||||
HttpErrorDefaultTamplete,
|
||||
HttpErrorHandleConfig
|
||||
)
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
||||
from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig
|
||||
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
|
||||
|
||||
@@ -48,24 +74,24 @@ class DifyConverter(BaseConverter):
|
||||
|
||||
def __init__(self):
|
||||
self.CONFIG_CONVERT_MAP = {
|
||||
"start": self.convert_start_node_config,
|
||||
"llm": self.convert_llm_node_config,
|
||||
"answer": self.convert_end_node_config,
|
||||
"if-else": self.convert_if_else_node_config,
|
||||
"loop": self.convert_loop_node_config,
|
||||
"iteration": self.convert_iteration_node_config,
|
||||
"assigner": self.convert_assigner_node_config,
|
||||
"code": self.convert_code_node_config,
|
||||
"http-request": self.convert_http_node_config,
|
||||
"template-transform": self.convert_jinja_render_node_config,
|
||||
"knowledge-retrieval": self.convert_knowledge_node_config,
|
||||
"parameter-extractor": self.convert_parameter_extractor_node_config,
|
||||
"question-classifier": self.convert_question_classifier_node_config,
|
||||
"variable-aggregator": self.convert_variable_aggregator_node_config,
|
||||
"tool": self.convert_tool_node_config,
|
||||
"loop-start": lambda x: {},
|
||||
"iteration-start": lambda x: {},
|
||||
"loop-end": lambda x: {},
|
||||
NodeType.START: self.convert_start_node_config,
|
||||
NodeType.LLM: self.convert_llm_node_config,
|
||||
NodeType.END: self.convert_end_node_config,
|
||||
NodeType.IF_ELSE: self.convert_if_else_node_config,
|
||||
NodeType.LOOP: self.convert_loop_node_config,
|
||||
NodeType.ITERATION: self.convert_iteration_node_config,
|
||||
NodeType.ASSIGNER: self.convert_assigner_node_config,
|
||||
NodeType.CODE: self.convert_code_node_config,
|
||||
NodeType.HTTP_REQUEST: self.convert_http_node_config,
|
||||
NodeType.JINJARENDER: self.convert_jinja_render_node_config,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: self.convert_knowledge_node_config,
|
||||
NodeType.PARAMETER_EXTRACTOR: self.convert_parameter_extractor_node_config,
|
||||
NodeType.QUESTION_CLASSIFIER: self.convert_question_classifier_node_config,
|
||||
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
|
||||
NodeType.TOOL: self.convert_tool_node_config,
|
||||
NodeType.NOTES: self.convert_notes_config,
|
||||
NodeType.CYCLE_START: lambda x: {},
|
||||
NodeType.BREAK: lambda x: {},
|
||||
}
|
||||
|
||||
def get_node_convert(self, node_type):
|
||||
@@ -732,3 +758,16 @@ class DifyConverter(BaseConverter):
|
||||
detail=f"Please reconfigure the tool node.",
|
||||
))
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def convert_notes_config(node: dict):
|
||||
node_data = node["data"]
|
||||
result = NoteNodeConfig.model_construct(
|
||||
author=node_data.get("author", ""),
|
||||
text=node_data.get("text", ""),
|
||||
width=node_data.get("width", 80),
|
||||
height=node_data.get("height", 80),
|
||||
theme=node_data.get("theme", "blue"),
|
||||
show_author=node_data.get("showAuthor", True)
|
||||
).model_dump()
|
||||
return result
|
||||
|
||||
@@ -50,7 +50,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
DifyConverter.__init__(self)
|
||||
BasePlatformAdapter.__init__(self, config)
|
||||
BasePlatformAdapter.__init__(self, config)
|
||||
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
@@ -59,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||
)
|
||||
|
||||
def map_node_type(self, platform_node_type) -> str:
|
||||
def map_node_type(self, platform_node_type) -> NodeType:
|
||||
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
||||
|
||||
@property
|
||||
@@ -84,7 +84,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
if self.config.get("app",{}).get("mode") == "workflow":
|
||||
if self.config.get("app", {}).get("mode") == "workflow":
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.PLATFORM,
|
||||
detail="workflow mode is not supported"
|
||||
@@ -163,13 +163,14 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||
node_data = node["data"]
|
||||
try:
|
||||
node_type = self.map_node_type(node_data["type"])
|
||||
return NodeDefinition(
|
||||
id=node["id"],
|
||||
type=self.map_node_type(node_data["type"]),
|
||||
type=node_type,
|
||||
name=node_data.get("title") or "notes",
|
||||
cycle=node.get("parentId"),
|
||||
description=None,
|
||||
config=self._convert_node_config(node),
|
||||
config=self._convert_node_config(node_type, node),
|
||||
position={
|
||||
"x": node["position"]["x"],
|
||||
"y": node["position"]["y"]
|
||||
@@ -183,17 +184,16 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
except Exception as e:
|
||||
logger.debug(f"convert node error - {e}", exc_info=True)
|
||||
|
||||
def _convert_node_config(self, node: dict):
|
||||
node_data = node["data"]
|
||||
node_type = node_data["type"]
|
||||
def _convert_node_config(self, node_type: NodeType, node: dict):
|
||||
try:
|
||||
node_data = node["data"]
|
||||
converter = self.get_node_convert(node_type)
|
||||
if node_type not in self.CONFIG_CONVERT_MAP:
|
||||
if node_type == NodeType.UNKNOWN:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node["id"],
|
||||
node_name=node["data"]["title"],
|
||||
detail=f"node type {node_type if node_type else 'notes'} is unsupported",
|
||||
detail=f"node type {node_data.get('type')} is unsupported",
|
||||
))
|
||||
return converter(node)
|
||||
except Exception as e:
|
||||
@@ -214,7 +214,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
if source in self.branch_node_cache:
|
||||
case_id = edge["sourceHandle"]
|
||||
if case_id == "false":
|
||||
label = f'CASE{len(self.branch_node_cache[source])+1}'
|
||||
label = f'CASE{len(self.branch_node_cache[source]) + 1}'
|
||||
else:
|
||||
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
|
||||
if source in self.error_branch_node_cache:
|
||||
@@ -257,5 +257,3 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
|
||||
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
|
||||
return ExecutionConfig()
|
||||
|
||||
|
||||
|
||||
@@ -4,65 +4,145 @@
|
||||
# @Time : 2026/2/25 14:11
|
||||
from typing import Any
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.adapters.base_adapter import (
|
||||
PlatformMetadata,
|
||||
PlatformType,
|
||||
BasePlatformAdapter,
|
||||
WorkflowParserResult
|
||||
)
|
||||
from app.schemas.workflow_schema import ExecutionConfig
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
|
||||
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
VALID_NODE_TYPES = frozenset(t.value for t in NodeType if t != NodeType.UNKNOWN)
|
||||
|
||||
|
||||
class MemoryBearAdapter(BasePlatformAdapter):
|
||||
NODE_TYPE_MAPPING = {}
|
||||
class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||
NODE_TYPE_MAPPING = {t.value: t for t in NodeType}
|
||||
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
MemoryBearConverter.__init__(self)
|
||||
BasePlatformAdapter.__init__(self, config)
|
||||
|
||||
@property
|
||||
def origin_nodes(self):
|
||||
return self.config.get("workflow").get("nodes")
|
||||
return self.config.get("workflow").get("nodes") or []
|
||||
|
||||
@property
|
||||
def origin_edges(self):
|
||||
return self.config.get("workflow").get("edges")
|
||||
return self.config.get("workflow").get("edges") or []
|
||||
|
||||
@property
|
||||
def origin_variables(self):
|
||||
return self.config.get("workflow").get("variables")
|
||||
return self.config.get("workflow").get("variables") or []
|
||||
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
platform_name=PlatformType.MEMORY_BEAR,
|
||||
version="0.2.5",
|
||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||
support_node_types=list(VALID_NODE_TYPES)
|
||||
)
|
||||
|
||||
def map_node_type(self, platform_node_type) -> str:
|
||||
return platform_node_type
|
||||
def map_node_type(self, platform_node_type: str) -> NodeType:
|
||||
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
||||
|
||||
@staticmethod
|
||||
def _valid_nodes(node: dict[str, Any]):
|
||||
if "type" not in node["data"]:
|
||||
return False
|
||||
def _valid_node(node: dict[str, Any]) -> bool:
|
||||
if "id" not in node or "type" not in node:
|
||||
return False
|
||||
if not isinstance(node.get("config"), dict):
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
require_fields = frozenset({'app', 'workflow'})
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
|
||||
for node in self.origin_nodes:
|
||||
if not self._valid_nodes(node):
|
||||
if not self._valid_node(node):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||
node_id = node.get("id")
|
||||
node_name = node.get("name")
|
||||
try:
|
||||
node_type = self.map_node_type(node["type"])
|
||||
if node_type == NodeType.UNKNOWN:
|
||||
self.errors.append(UnsupportNodeType(
|
||||
node_id=node_id,
|
||||
node_type=node["type"]
|
||||
))
|
||||
return None
|
||||
|
||||
config = node.get("config") or {}
|
||||
converter = self.get_node_convert(node_type)
|
||||
converter(node_id, node_name, config) # validates and appends errors if invalid
|
||||
|
||||
return NodeDefinition(**node)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
detail=f"convert node error - {e}"
|
||||
))
|
||||
logger.debug(f"MemoryBear convert node error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
|
||||
try:
|
||||
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"edge {edge.get('id')} skipped: source or target node not found"
|
||||
))
|
||||
return None
|
||||
return EdgeDefinition(**edge)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"convert edge error - {e}"
|
||||
))
|
||||
logger.debug(f"MemoryBear convert edge error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _convert_variable(self, variable: dict[str, Any]) -> VariableDefinition | None:
|
||||
try:
|
||||
return VariableDefinition(**variable)
|
||||
except Exception as e:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
name=variable.get("name"),
|
||||
detail=f"convert variable error - {e}"
|
||||
))
|
||||
logger.debug(f"MemoryBear convert variable error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def parse_workflow(self) -> WorkflowParserResult:
|
||||
self.nodes = self.origin_nodes
|
||||
self.edges = self.origin_edges
|
||||
self.conv_variables = self.origin_variables
|
||||
for node in self.origin_nodes:
|
||||
converted = self._convert_node(node)
|
||||
if converted:
|
||||
self.nodes.append(converted)
|
||||
|
||||
valid_node_ids = {n.id for n in self.nodes}
|
||||
|
||||
for edge in self.origin_edges:
|
||||
converted = self._convert_edge(edge, valid_node_ids)
|
||||
if converted:
|
||||
self.edges.append(converted)
|
||||
|
||||
for variable in self.origin_variables:
|
||||
converted = self._convert_variable(variable)
|
||||
if converted:
|
||||
self.conv_variables.append(converted)
|
||||
|
||||
return WorkflowParserResult(
|
||||
success=True,
|
||||
success=not self.errors and not self.warnings,
|
||||
platform=self.get_metadata(),
|
||||
execution_config=ExecutionConfig(),
|
||||
origin_config=self.config,
|
||||
@@ -72,5 +152,4 @@ class MemoryBearAdapter(BasePlatformAdapter):
|
||||
variables=self.conv_variables,
|
||||
warnings=self.warnings,
|
||||
errors=self.errors,
|
||||
|
||||
)
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.configs import (
|
||||
StartNodeConfig,
|
||||
EndNodeConfig,
|
||||
LLMNodeConfig,
|
||||
AgentNodeConfig,
|
||||
IfElseNodeConfig,
|
||||
KnowledgeRetrievalNodeConfig,
|
||||
AssignerNodeConfig,
|
||||
CodeNodeConfig,
|
||||
HttpRequestNodeConfig,
|
||||
JinjaRenderNodeConfig,
|
||||
VariableAggregatorNodeConfig,
|
||||
ParameterExtractorNodeConfig,
|
||||
LoopNodeConfig,
|
||||
IterationNodeConfig,
|
||||
QuestionClassifierNodeConfig,
|
||||
ToolNodeConfig,
|
||||
MemoryReadNodeConfig,
|
||||
MemoryWriteNodeConfig,
|
||||
NoteNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class MemoryBearConverter(BaseConverter):
|
||||
errors: list
|
||||
warnings: list
|
||||
|
||||
CONFIG_CLASS_MAP: dict[NodeType, type[BaseNodeConfig]] = {
|
||||
NodeType.START: StartNodeConfig,
|
||||
NodeType.END: EndNodeConfig,
|
||||
NodeType.ANSWER: EndNodeConfig,
|
||||
NodeType.LLM: LLMNodeConfig,
|
||||
NodeType.AGENT: AgentNodeConfig,
|
||||
NodeType.IF_ELSE: IfElseNodeConfig,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNodeConfig,
|
||||
NodeType.ASSIGNER: AssignerNodeConfig,
|
||||
NodeType.CODE: CodeNodeConfig,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNodeConfig,
|
||||
NodeType.JINJARENDER: JinjaRenderNodeConfig,
|
||||
NodeType.VAR_AGGREGATOR: VariableAggregatorNodeConfig,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNodeConfig,
|
||||
NodeType.LOOP: LoopNodeConfig,
|
||||
NodeType.ITERATION: IterationNodeConfig,
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNodeConfig,
|
||||
NodeType.TOOL: ToolNodeConfig,
|
||||
NodeType.MEMORY_READ: MemoryReadNodeConfig,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
|
||||
NodeType.NOTES: NoteNodeConfig,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _convert_file(var):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_file(var):
|
||||
return []
|
||||
|
||||
def config_validate(self, node_id: str, node_name: str, config_cls: type[BaseNodeConfig], value: dict):
|
||||
try:
|
||||
return config_cls.model_validate(value)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
detail=str(e)
|
||||
))
|
||||
return None
|
||||
|
||||
def get_node_convert(self, node_type: NodeType):
|
||||
config_cls = self.CONFIG_CLASS_MAP.get(node_type)
|
||||
if not config_cls:
|
||||
return lambda node_id, node_name, config: config
|
||||
|
||||
def validate(node_id: str, node_name: str, config: dict):
|
||||
self.config_validate(node_id, node_name, config_cls, config)
|
||||
return config
|
||||
|
||||
return validate
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
@@ -643,15 +644,18 @@ class BaseNode(ABC):
|
||||
return content.content_cache[provider]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
||||
message = await multimodel_service.process_files(
|
||||
[FileInput.model_construct(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
transfer_method=content.transfer_method,
|
||||
file_type=content.origin_file_type,
|
||||
upload_file_id=content.file_id
|
||||
)]
|
||||
file_obj = FileInput(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
transfer_method=content.transfer_method,
|
||||
origin_file_type=content.origin_file_type,
|
||||
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
|
||||
)
|
||||
file_obj.set_content(content.get_content())
|
||||
message = await multimodel_service.process_files(
|
||||
[file_obj]
|
||||
)
|
||||
content.set_content(file_obj.get_content())
|
||||
if message:
|
||||
content.content_cache[provider] = message
|
||||
return message
|
||||
|
||||
@@ -23,6 +23,7 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.nodes.notes.config import NoteNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
@@ -47,5 +48,6 @@ __all__ = [
|
||||
"ToolNodeConfig",
|
||||
"MemoryReadNodeConfig",
|
||||
"MemoryWriteNodeConfig",
|
||||
"CodeNodeConfig"
|
||||
"CodeNodeConfig",
|
||||
"NoteNodeConfig"
|
||||
]
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import Field, BaseModel, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle
|
||||
from app.core.workflow.variable.base_variable import FileObject
|
||||
|
||||
|
||||
class HttpAuthConfig(BaseModel):
|
||||
@@ -260,6 +261,11 @@ class HttpRequestNodeOutput(BaseModel):
|
||||
description="Http response headers"
|
||||
)
|
||||
|
||||
files: list[FileObject] = Field(
|
||||
default_factory=list,
|
||||
description="List of files",
|
||||
)
|
||||
|
||||
output: str = Field(
|
||||
default="SUCCESS",
|
||||
description="HTTP response body",
|
||||
|
||||
@@ -1,24 +1,146 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import uuid
|
||||
import imghdr
|
||||
from email.message import Message
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
import httpx
|
||||
# import filetypes # TODO: File support (Feature)
|
||||
from httpx import AsyncClient, Response, Timeout
|
||||
import magic
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.utils.file_processer import mime_to_file_type
|
||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||
from app.schemas import FileType, TransferMethod
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class HttpResponse:
|
||||
def __init__(self, response: httpx.Response):
|
||||
self.response = response
|
||||
self.headers = dict(response.headers)
|
||||
|
||||
self._is_file: bool | None = None
|
||||
|
||||
@property
|
||||
def content_type(self) -> str:
|
||||
return self.headers.get("content-type", "")
|
||||
|
||||
@property
|
||||
def content_disposition(self) -> Message | None:
|
||||
content_disposition = self.headers.get("content-disposition", "")
|
||||
if content_disposition:
|
||||
msg = Message()
|
||||
msg["content-disposition"] = content_disposition
|
||||
return msg
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_file(self) -> bool:
|
||||
if self._is_file is not None:
|
||||
return self._is_file
|
||||
content_type = self.content_type.split(";")[0].strip().lower()
|
||||
|
||||
parsed_content_disposition = self.content_disposition
|
||||
if parsed_content_disposition:
|
||||
disp_type = parsed_content_disposition.get_content_disposition()
|
||||
filename = parsed_content_disposition.get_filename()
|
||||
if disp_type == "attachment" or filename:
|
||||
self._is_file = True
|
||||
return True
|
||||
|
||||
if content_type.startswith("text/") and "csv" not in content_type:
|
||||
return False
|
||||
|
||||
if content_type.startswith("application/"):
|
||||
if any(
|
||||
text_type in content_type
|
||||
for text_type in {"json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql"}
|
||||
):
|
||||
self._is_file = False
|
||||
return False
|
||||
try:
|
||||
content_sample = self.response.content[:1024]
|
||||
content_sample.decode("utf-8")
|
||||
text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ")
|
||||
if any(marker in content_sample for marker in text_markers):
|
||||
return False
|
||||
except UnicodeDecodeError:
|
||||
self._is_file = True
|
||||
return True
|
||||
|
||||
main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or ""))
|
||||
if main_type:
|
||||
self._is_file = main_type.split("/")[0] in ("application", "image", "audio", "video")
|
||||
return self._is_file
|
||||
self._is_file = any(media_type in content_type for media_type in ("image/", "audio/", "video/"))
|
||||
return self._is_file
|
||||
|
||||
@property
|
||||
def is_image(self):
|
||||
if self.is_file:
|
||||
kind = imghdr.what(None, h=self.response.content)
|
||||
return kind is not None
|
||||
return False
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return str(self.response.url)
|
||||
|
||||
@property
|
||||
def body(self) -> str:
|
||||
if self.is_file:
|
||||
return f"{'!' if self.is_image else ''}[file]({self.url})"
|
||||
return self.response.text
|
||||
|
||||
@staticmethod
|
||||
def get_file_type(file_bytes) -> tuple[FileType | None, str | None]:
|
||||
mime = magic.from_buffer(file_bytes, mime=True)
|
||||
|
||||
if mime.startswith("image"):
|
||||
return FileType.IMAGE, mime
|
||||
elif mime.startswith("video"):
|
||||
return FileType.VIDEO, mime
|
||||
elif mime.startswith("audio"):
|
||||
return FileType.AUDIO, mime
|
||||
elif mime in ["application/pdf",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"text/plain"]:
|
||||
return FileType.DOCUMENT, mime
|
||||
return None, None
|
||||
|
||||
@property
|
||||
def files(self) -> list[FileObject]:
|
||||
file_type, mime_type = self.get_file_type(self.response.content)
|
||||
origin_file_type = mime_to_file_type(mime_type)
|
||||
if self.is_file and file_type and origin_file_type:
|
||||
file_obj = FileObject(
|
||||
type=file_type,
|
||||
url=self.url,
|
||||
transfer_method=TransferMethod.REMOTE_URL.value,
|
||||
origin_file_type=origin_file_type,
|
||||
file_id=None,
|
||||
is_file=True
|
||||
)
|
||||
file_obj.set_content(self.response.content)
|
||||
return [
|
||||
file_obj
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
class HttpRequestNode(BaseNode):
|
||||
"""
|
||||
HTTP Request Workflow Node.
|
||||
@@ -44,6 +166,7 @@ class HttpRequestNode(BaseNode):
|
||||
"body": VariableType.STRING,
|
||||
"status_code": VariableType.NUMBER,
|
||||
"headers": VariableType.OBJECT,
|
||||
"files": VariableType.ARRAY_FILE,
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
@@ -232,10 +355,12 @@ class HttpRequestNode(BaseNode):
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
response = HttpResponse(resp)
|
||||
return HttpRequestNodeOutput(
|
||||
body=resp.text,
|
||||
body=response.body,
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
files=response.files
|
||||
).model_dump()
|
||||
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
||||
logger.error(f"HTTP request node exception: {e}")
|
||||
|
||||
0
api/app/core/workflow/nodes/notes/__init__.py
Normal file
0
api/app/core/workflow/nodes/notes/__init__.py
Normal file
12
api/app/core/workflow/nodes/notes/config.py
Normal file
12
api/app/core/workflow/nodes/notes/config.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class NoteNodeConfig(BaseNodeConfig):
|
||||
author: str = Field(default="", description="author")
|
||||
text: str = Field(default="", description="note content")
|
||||
width: int = Field(default=80)
|
||||
height: int = Field(default=80)
|
||||
theme: str = Field(default="blue")
|
||||
show_author: bool = Field(default=True)
|
||||
56
api/app/core/workflow/utils/file_processer.py
Normal file
56
api/app/core/workflow/utils/file_processer.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/3/10 13:36
|
||||
TRANSFORM_FILE_TYPE = {
|
||||
'text/plain': 'document/text',
|
||||
'text/markdown': 'document/markdown',
|
||||
'text/x-markdown': 'document/x-markdown',
|
||||
|
||||
'application/pdf': 'document/pdf',
|
||||
|
||||
'application/msword': 'document/doc',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx',
|
||||
|
||||
'application/vnd.ms-powerpoint': 'document/ppt',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx',
|
||||
}
|
||||
ALLOWED_FILE_TYPES = [
|
||||
'text/plain',
|
||||
'text/markdown',
|
||||
'text/x-markdown',
|
||||
'application/pdf',
|
||||
'application/msword',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'application/vnd.ms-powerpoint',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'image/jpg',
|
||||
'image/jpeg',
|
||||
'image/png',
|
||||
'image/gif',
|
||||
'image/bmp',
|
||||
'image/webp',
|
||||
'image/svg+xml',
|
||||
'video/mp4',
|
||||
'video/quicktime',
|
||||
'video/x-msvideo',
|
||||
'video/x-matroska',
|
||||
'video/webm',
|
||||
'video/x-flv',
|
||||
'video/x-ms-wmv',
|
||||
'audio/mpeg',
|
||||
'audio/wav',
|
||||
'audio/ogg',
|
||||
'audio/aac',
|
||||
'audio/flac',
|
||||
'audio/mp4',
|
||||
'audio/x-ms-wma',
|
||||
'audio/x-m4a',
|
||||
]
|
||||
|
||||
|
||||
def mime_to_file_type(mime_type):
|
||||
if mime_type not in ALLOWED_FILE_TYPES:
|
||||
return None
|
||||
|
||||
return TRANSFORM_FILE_TYPE.get(mime_type, mime_type)
|
||||
@@ -114,9 +114,16 @@ class FileObject(BaseModel):
|
||||
file_id: str | None
|
||||
|
||||
content_cache: dict = Field(default_factory=dict)
|
||||
|
||||
is_file: bool
|
||||
|
||||
_byte_content: bytes | None = None
|
||||
|
||||
def get_content(self):
|
||||
return self._byte_content
|
||||
|
||||
def set_content(self, byte_content):
|
||||
self._byte_content = byte_content
|
||||
|
||||
|
||||
class BaseVariable(ABC):
|
||||
"""Abstract base class for all workflow variables.
|
||||
|
||||
@@ -51,6 +51,12 @@ class EndUser(Base):
|
||||
growth_trajectory = Column(Text, nullable=True, comment="成长轨迹")
|
||||
memory_insight_updated_at = Column(DateTime, nullable=True, comment="洞察报告最后更新时间")
|
||||
|
||||
# RAG存储模式专用字段 - RAG Storage Mode Fields
|
||||
# storage_type = Column(String, nullable=True, default="neo4j", comment="存储模式类型: neo4j / rag")
|
||||
rag_tags = Column(Text, nullable=True, comment="RAG模式下提取的标签列表(JSON格式)")
|
||||
rag_personas = Column(Text, nullable=True, comment="RAG模式下提取的人物形象列表(JSON格式)")
|
||||
rag_summary_updated_at = Column(DateTime, nullable=True, comment="RAG摘要/标签/人物形象最后更新时间")
|
||||
|
||||
# 与 App 的反向关系
|
||||
app = relationship(
|
||||
"App",
|
||||
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean
|
||||
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -163,6 +163,17 @@ class CustomToolConfig(Base):
|
||||
return f"<CustomToolConfig(id={self.id}, auth_type={self.auth_type})>"
|
||||
|
||||
|
||||
class MCPSourceChannel(StrEnum):
|
||||
"""MCP来源渠道枚举"""
|
||||
ALIYUN_BAILIAN = "aliyun_bailian" # 阿里云百炼
|
||||
MODELSCOPE = "modelscope" # ModelScope
|
||||
TOKENFLUX = "tokenflux" # TokenFlux
|
||||
LANGENG = "langeng" # 蓝耕科技
|
||||
AI_302 = "302ai" # 302.AI
|
||||
MCP_ROUTER = "mcp_router" # MCP Router
|
||||
SELF_HOSTED = "self_hosted" # 自建
|
||||
|
||||
|
||||
class MCPToolConfig(Base):
|
||||
"""MCP工具配置模型"""
|
||||
__tablename__ = "mcp_tool_configs"
|
||||
@@ -170,6 +181,13 @@ class MCPToolConfig(Base):
|
||||
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
|
||||
server_url = Column(String(1000), nullable=False) # MCP服务器URL
|
||||
connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息)
|
||||
|
||||
# 来源渠道
|
||||
source_channel = Column(String(50), default=MCPSourceChannel.SELF_HOSTED,
|
||||
server_default=text(f"'{MCPSourceChannel.SELF_HOSTED}'"), nullable=False, comment="来源渠道")
|
||||
market_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场id")
|
||||
market_config_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场配置id")
|
||||
mcp_service_id = Column(String(255), nullable=True, comment="mcp服务id")
|
||||
|
||||
# 服务状态
|
||||
last_health_check = Column(DateTime)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from app.models.app_model import App
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.app_model import App
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
@@ -35,11 +36,27 @@ class AppRepository:
|
||||
except Exception as e:
|
||||
raise
|
||||
|
||||
def get_apps_by_name(self, app_name: str, app_type: str, workspace_id: uuid.UUID) -> List[App]:
|
||||
try:
|
||||
stmt = select(App).where(
|
||||
App.name == app_name,
|
||||
App.workspace_id == workspace_id,
|
||||
App.type == app_type,
|
||||
App.is_active.is_(True),
|
||||
)
|
||||
apps = self.db.execute(stmt).scalars().all()
|
||||
return list(apps)
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询名称 {app_name} 应用异常: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
|
||||
"""根据工作空间ID查询应用"""
|
||||
repo = AppRepository(db)
|
||||
return repo.get_apps_by_workspace_id(workspace_id)
|
||||
|
||||
|
||||
def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App:
|
||||
"""根据工作空间ID查询应用"""
|
||||
repo = AppRepository(db)
|
||||
|
||||
@@ -220,6 +220,88 @@ class EndUserRepository:
|
||||
db_logger.error(f"更新终端用户 {end_user_id} 的用户摘要缓存时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_rag_summary_tags(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
user_summary: str,
|
||||
rag_tags: str,
|
||||
rag_personas: str,
|
||||
) -> bool:
|
||||
"""更新RAG模式下的用户摘要、标签和人物形象缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
user_summary: 用户摘要文本
|
||||
rag_tags: 标签列表(JSON字符串)
|
||||
rag_personas: 人物形象列表(JSON字符串)
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
updated_count = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.id == end_user_id)
|
||||
.update(
|
||||
{
|
||||
EndUser.user_summary: user_summary,
|
||||
EndUser.rag_tags: rag_tags,
|
||||
EndUser.rag_personas: rag_personas,
|
||||
EndUser.rag_summary_updated_at: datetime.datetime.now(),
|
||||
},
|
||||
synchronize_session=False
|
||||
)
|
||||
)
|
||||
self.db.commit()
|
||||
if updated_count > 0:
|
||||
db_logger.info(f"成功更新终端用户 {end_user_id} 的RAG摘要/标签/人物形象缓存")
|
||||
return True
|
||||
else:
|
||||
db_logger.warning(f"未找到终端用户 {end_user_id},无法更新RAG摘要缓存")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"更新终端用户 {end_user_id} 的RAG摘要缓存时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_rag_insight(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
memory_insight: str,
|
||||
) -> bool:
|
||||
"""更新RAG模式下的记忆洞察缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_insight: 洞察文本
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
updated_count = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.id == end_user_id)
|
||||
.update(
|
||||
{
|
||||
EndUser.memory_insight: memory_insight,
|
||||
EndUser.memory_insight_updated_at: datetime.datetime.now(),
|
||||
},
|
||||
synchronize_session=False
|
||||
)
|
||||
)
|
||||
self.db.commit()
|
||||
if updated_count > 0:
|
||||
db_logger.info(f"成功更新终端用户 {end_user_id} 的RAG洞察缓存")
|
||||
return True
|
||||
else:
|
||||
db_logger.warning(f"未找到终端用户 {end_user_id},无法更新RAG洞察缓存")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"更新终端用户 {end_user_id} 的RAG洞察缓存时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all_by_workspace(self, workspace_id: uuid.UUID) -> List[EndUser]:
|
||||
"""获取工作空间的所有终端用户
|
||||
|
||||
|
||||
@@ -5,13 +5,22 @@ Implicit Emotions Storage Repository
|
||||
事务由调用方控制,仓储层只使用 flush/refresh
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, date, timezone, timedelta
|
||||
from typing import Optional, Generator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, not_, exists
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from typing import Generator, Optional
|
||||
|
||||
|
||||
class TimeFilterUnavailableError(Exception):
|
||||
"""redis_client 不可用,无法执行时间轴筛选。
|
||||
|
||||
调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。
|
||||
"""
|
||||
|
||||
import redis
|
||||
from sqlalchemy import exists, not_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -111,6 +120,88 @@ class ImplicitEmotionsStorageRepository:
|
||||
logger.error(f"分批获取用户ID失败: offset={offset}, error={e}")
|
||||
break
|
||||
|
||||
def get_users_needing_refresh(self, redis_client: redis.StrictRedis, batch_size: int = 100) -> Generator[str, None, None]:
|
||||
"""分批次获取需要刷新隐性记忆/情绪数据的存量用户ID。
|
||||
|
||||
筛选逻辑:
|
||||
- 查询 implicit_emotions_storage 中所有用户的 end_user_id 和 updated_at
|
||||
- 从 Redis 读取 write_message:last_done:{end_user_id} 的时间戳
|
||||
- 若 Redis 中无记录(该用户从未写入过记忆),跳过
|
||||
- 若 last_done > updated_at,说明上次刷新后又有新记忆写入,需要刷新
|
||||
- 若 last_done <= updated_at,说明已是最新,跳过
|
||||
|
||||
Args:
|
||||
redis_client: 同步 redis.StrictRedis 实例(连接 CELERY_BACKEND DB)
|
||||
batch_size: 每批次加载的数量
|
||||
|
||||
Raises:
|
||||
TimeFilterUnavailableError: redis_client 为 None 时抛出,调用方可捕获并回退到 get_all_user_ids
|
||||
|
||||
Yields:
|
||||
需要刷新的用户ID字符串
|
||||
"""
|
||||
if redis_client is None:
|
||||
raise TimeFilterUnavailableError("redis_client 不可用,无法执行时间轴筛选")
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
offset = 0
|
||||
while True:
|
||||
try:
|
||||
stmt = (
|
||||
select(ImplicitEmotionsStorage.end_user_id, ImplicitEmotionsStorage.updated_at)
|
||||
.order_by(ImplicitEmotionsStorage.end_user_id)
|
||||
.limit(batch_size)
|
||||
.offset(offset)
|
||||
)
|
||||
batch = self.db.execute(stmt).all()
|
||||
if not batch:
|
||||
break
|
||||
|
||||
# 批量获取当前批次所有用户的 last_done 时间戳(一次网络往返)
|
||||
keys = [f"write_message:last_done:{end_user_id}" for end_user_id, _ in batch]
|
||||
|
||||
try:
|
||||
raw_values = redis_client.mget(keys)
|
||||
except RedisError as e:
|
||||
logger.error(
|
||||
f"Redis mget 操作失败: {e},当前批次降级为处理所有用户",
|
||||
extra={"offset": offset, "batch_size": len(batch)}
|
||||
)
|
||||
# Redis 操作失败,降级为返回当前批次所有用户
|
||||
yield from (end_user_id for end_user_id, _ in batch)
|
||||
offset += batch_size
|
||||
continue
|
||||
|
||||
for (end_user_id, updated_at), raw in zip(batch, raw_values):
|
||||
if raw is None:
|
||||
continue
|
||||
try:
|
||||
CST = timezone(timedelta(hours=8))
|
||||
last_done = datetime.fromisoformat(raw)
|
||||
# last_done 写入时已是 CST naive,直接使用,无需转换
|
||||
if last_done.tzinfo is not None:
|
||||
last_done = last_done.astimezone(CST).replace(tzinfo=None)
|
||||
|
||||
if updated_at is None:
|
||||
yield end_user_id
|
||||
continue
|
||||
# updated_at 数据库存的是 UTC naive,转为 CST naive 再比较
|
||||
if updated_at.tzinfo is None:
|
||||
updated_at_cst = updated_at.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None)
|
||||
else:
|
||||
updated_at_cst = updated_at.astimezone(CST).replace(tzinfo=None)
|
||||
|
||||
if last_done > updated_at_cst:
|
||||
yield end_user_id
|
||||
except Exception as e:
|
||||
logger.warning(f"解析 last_done 时间戳失败: end_user_id={end_user_id}, raw={raw}, error={e}")
|
||||
|
||||
offset += batch_size
|
||||
except Exception as e:
|
||||
logger.error(f"get_users_needing_refresh 分批查询失败: offset={offset}, error={e}")
|
||||
break
|
||||
|
||||
def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]:
|
||||
"""分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID
|
||||
|
||||
@@ -124,7 +215,8 @@ class ImplicitEmotionsStorageRepository:
|
||||
Yields:
|
||||
用户ID字符串
|
||||
"""
|
||||
from sqlalchemy import cast, String as SAString
|
||||
from sqlalchemy import String as SAString
|
||||
from sqlalchemy import cast
|
||||
CST = timezone(timedelta(hours=8))
|
||||
now_cst = datetime.now(CST)
|
||||
today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from app.models.user_model import User
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole
|
||||
from app.schemas.workspace_schema import WorkspaceCreate, WorkspaceUpdate
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.user_model import User
|
||||
from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole
|
||||
from app.schemas.workspace_schema import WorkspaceCreate
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
@@ -19,7 +22,7 @@ class WorkspaceRepository:
|
||||
def create_workspace(self, workspace_data: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace:
|
||||
"""创建工作空间"""
|
||||
db_logger.debug(f"创建工作空间记录: name={workspace_data.name}, tenant_id={tenant_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_workspace = Workspace(
|
||||
name=workspace_data.name,
|
||||
@@ -34,7 +37,8 @@ class WorkspaceRepository:
|
||||
)
|
||||
self.db.add(db_workspace)
|
||||
self.db.flush()
|
||||
db_logger.info(f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}")
|
||||
db_logger.info(
|
||||
f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}")
|
||||
return db_workspace
|
||||
except Exception as e:
|
||||
db_logger.error(f"创建工作空间记录失败: name={workspace_data.name} - {str(e)}")
|
||||
@@ -43,7 +47,7 @@ class WorkspaceRepository:
|
||||
def get_workspace_by_id(self, workspace_id: uuid.UUID) -> Optional[Workspace]:
|
||||
"""根据ID获取工作空间"""
|
||||
db_logger.debug(f"根据ID查询工作空间: workspace_id={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if workspace:
|
||||
@@ -65,7 +69,7 @@ class WorkspaceRepository:
|
||||
包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None
|
||||
"""
|
||||
db_logger.debug(f"查询工作空间模型配置: workspace_id={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if workspace:
|
||||
@@ -89,7 +93,7 @@ class WorkspaceRepository:
|
||||
def get_workspaces_by_user(self, user_id: uuid.UUID) -> List[Workspace]:
|
||||
"""获取用户参与的所有工作空间(包括用户创建的和作为成员的)"""
|
||||
db_logger.debug(f"查询用户参与的工作空间: user_id={user_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 首先获取用户信息以获取 tenant_id
|
||||
from app.models.user_model import User
|
||||
@@ -97,7 +101,7 @@ class WorkspaceRepository:
|
||||
if not user:
|
||||
db_logger.warning(f"用户不存在: user_id={user_id}")
|
||||
return []
|
||||
|
||||
|
||||
if user.is_superuser:
|
||||
# 超级用户获取对应tenantid所有工作空间
|
||||
workspaces = (
|
||||
@@ -109,7 +113,7 @@ class WorkspaceRepository:
|
||||
)
|
||||
db_logger.debug(f"超用户查询所有工作空间: user_id={user_id}, 数量={len(workspaces)}")
|
||||
return workspaces
|
||||
|
||||
|
||||
# 获取用户作为成员的工作空间
|
||||
member_workspaces = (
|
||||
self.db.query(Workspace)
|
||||
@@ -120,7 +124,7 @@ class WorkspaceRepository:
|
||||
.order_by(Workspace.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
db_logger.debug(f"用户工作空间查询成功: user_id={user_id}, 数量={len(member_workspaces)}")
|
||||
return member_workspaces
|
||||
except Exception as e:
|
||||
@@ -130,7 +134,7 @@ class WorkspaceRepository:
|
||||
def get_workspaces_by_tenant(self, tenant_id: uuid.UUID) -> List[Workspace]:
|
||||
"""获取租户的所有工作空间"""
|
||||
db_logger.debug(f"查询租户的工作空间: tenant_id={tenant_id}")
|
||||
|
||||
|
||||
try:
|
||||
workspaces = (
|
||||
self.db.query(Workspace)
|
||||
@@ -144,14 +148,32 @@ class WorkspaceRepository:
|
||||
db_logger.error(f"查询租户工作空间失败: tenant_id={tenant_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember:
|
||||
def get_workspaces_by_name(self, tenant_id: uuid.UUID, workspace_name: str) -> List[Workspace]:
|
||||
try:
|
||||
stmt = (
|
||||
select(Workspace)
|
||||
.where(
|
||||
Workspace.tenant_id == tenant_id,
|
||||
Workspace.name == workspace_name,
|
||||
Workspace.is_active.is_(True)
|
||||
)
|
||||
)
|
||||
|
||||
workspaces = self.db.execute(stmt).scalars().all()
|
||||
return list(workspaces)
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询工作空间失败: workspace_name={workspace_name} - {str(e)}")
|
||||
raise
|
||||
|
||||
def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID,
|
||||
role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember:
|
||||
"""添加工作空间成员"""
|
||||
db_logger.debug(f"添加工作空间成员: user_id={user_id}, workspace_id={workspace_id}, role={role}")
|
||||
|
||||
|
||||
try:
|
||||
db_member = WorkspaceMember(
|
||||
user_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
role=role
|
||||
)
|
||||
self.db.add(db_member)
|
||||
@@ -165,7 +187,7 @@ class WorkspaceRepository:
|
||||
def get_member(self, user_id: uuid.UUID, workspace_id: uuid.UUID) -> Optional[WorkspaceMember]:
|
||||
"""获取工作空间成员"""
|
||||
db_logger.debug(f"查询工作空间成员: user_id={user_id}, workspace_id={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
member = self.db.query(WorkspaceMember).filter(
|
||||
WorkspaceMember.user_id == user_id,
|
||||
@@ -173,7 +195,8 @@ class WorkspaceRepository:
|
||||
WorkspaceMember.is_active.is_(True),
|
||||
).first()
|
||||
if member:
|
||||
db_logger.debug(f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}")
|
||||
db_logger.debug(
|
||||
f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}")
|
||||
else:
|
||||
db_logger.debug(f"工作空间成员不存在: user_id={user_id}, workspace_id={workspace_id}")
|
||||
return member
|
||||
@@ -199,7 +222,7 @@ class WorkspaceRepository:
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询成员列表失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_member_by_id(self, member_id: uuid.UUID) -> WorkspaceMember:
|
||||
"""按成员ID获取工作空间成员,并预加载 user 与 workspace 关系"""
|
||||
db_logger.debug(f"查询成员的工作空间: member_id={member_id}")
|
||||
@@ -214,7 +237,8 @@ class WorkspaceRepository:
|
||||
.first()
|
||||
)
|
||||
if member:
|
||||
db_logger.debug(f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}")
|
||||
db_logger.debug(
|
||||
f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}")
|
||||
else:
|
||||
db_logger.debug(f"成员不存在: member_id={member_id}")
|
||||
return member
|
||||
@@ -222,7 +246,8 @@ class WorkspaceRepository:
|
||||
db_logger.error(f"查询成员列表失败: member_id={member_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
|
||||
def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[
|
||||
WorkspaceMember]:
|
||||
try:
|
||||
member = self.db.query(WorkspaceMember).filter(
|
||||
WorkspaceMember.workspace_id == workspace_id,
|
||||
@@ -255,7 +280,7 @@ class WorkspaceRepository:
|
||||
except Exception as e:
|
||||
db_logger.error(f"删除成员失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_member_by_id(self, member_id: uuid.UUID) -> Optional[WorkspaceMember]:
|
||||
try:
|
||||
member = self.db.query(WorkspaceMember).filter(
|
||||
@@ -271,7 +296,7 @@ class WorkspaceRepository:
|
||||
except Exception as e:
|
||||
db_logger.error(f"删除成员失败: id={member_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def update_member_role_by_id(self, id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
|
||||
try:
|
||||
member = self.db.query(WorkspaceMember).filter(
|
||||
@@ -288,12 +313,18 @@ class WorkspaceRepository:
|
||||
db_logger.error(f"更新成员角色失败: id={id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# 保持向后兼容的函数
|
||||
def get_workspace_by_id(db: Session, workspace_id: uuid.UUID) -> Workspace | None:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.get_workspace_by_id(workspace_id)
|
||||
|
||||
|
||||
def get_workspaces_by_name(db: Session, tenant_id: uuid.UUID, name: str) -> List[Workspace]:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.get_workspaces_by_name(tenant_id, name)
|
||||
|
||||
|
||||
def get_workspaces_by_user(db: Session, user_id: uuid.UUID) -> List[Workspace]:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.get_workspaces_by_user(user_id)
|
||||
@@ -315,7 +346,7 @@ def create_workspace(db: Session, workspace: WorkspaceCreate, tenant_id: uuid.UU
|
||||
|
||||
|
||||
def add_member_to_workspace(
|
||||
db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole
|
||||
db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole
|
||||
) -> WorkspaceMember:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.add_member(workspace_id, user_id, role)
|
||||
@@ -325,39 +356,43 @@ def get_members_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[Works
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.get_members_by_workspace(workspace_id)
|
||||
|
||||
|
||||
def get_member_by_id(db: Session, member_id: uuid.UUID) -> WorkspaceMember | None:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.get_member_by_id(member_id)
|
||||
|
||||
|
||||
def update_member_role_in_workspace(
|
||||
db: Session,
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
role: WorkspaceRole,
|
||||
db: Session,
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
role: WorkspaceRole,
|
||||
) -> Optional[WorkspaceMember]:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.update_member_role(workspace_id, user_id, role)
|
||||
|
||||
|
||||
def remove_member_from_workspace(
|
||||
db: Session,
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session,
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
) -> Optional[WorkspaceMember]:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.deactivate_member(workspace_id, user_id)
|
||||
|
||||
|
||||
def remove_member_from_workspace_by_id(
|
||||
db: Session,
|
||||
member_id: uuid.UUID,
|
||||
db: Session,
|
||||
member_id: uuid.UUID,
|
||||
) -> Optional[WorkspaceMember]:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.delete_member_by_id(member_id)
|
||||
|
||||
|
||||
def update_member_role_by_id(
|
||||
db: Session,
|
||||
id: uuid.UUID,
|
||||
role: WorkspaceRole,
|
||||
db: Session,
|
||||
id: uuid.UUID,
|
||||
role: WorkspaceRole,
|
||||
) -> Optional[WorkspaceMember]:
|
||||
repo = WorkspaceRepository(db)
|
||||
return repo.update_member_role_by_id(id, role)
|
||||
|
||||
@@ -45,11 +45,19 @@ class FileInput(BaseModel):
|
||||
url: Optional[str] = Field(None, description="远程URL(remote_url时必填)")
|
||||
file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)")
|
||||
|
||||
_content = None
|
||||
|
||||
def __init__(self, **data):
|
||||
if "type" in data:
|
||||
data['file_type'] = data['type']
|
||||
super().__init__(**data)
|
||||
|
||||
def set_content(self, content: bytes):
|
||||
self._content = content
|
||||
|
||||
def get_content(self) -> bytes | None:
|
||||
return self._content
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def validate_type(cls, v):
|
||||
|
||||
@@ -155,6 +155,10 @@ class MCPToolConfigSchema(BaseModel):
|
||||
health_status: str = "unknown"
|
||||
error_message: Optional[str] = None
|
||||
available_tools: List[Dict[str, Dict[str, Any]]] = Field(default_factory=list, description="工具列表,格式: [{'tool_name': str, 'arguments': dict}]")
|
||||
source_channel: Optional[str] = Field(None, description="来源渠道")
|
||||
market_id: Optional[str] = Field(None, description="渠道市场id")
|
||||
market_config_id: Optional[str] = Field(None, description="渠道市场配置id")
|
||||
mcp_service_id: Optional[str] = Field(None, description="mcp服务id")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@@ -192,6 +196,10 @@ class ToolCreateRequest(BaseModel):
|
||||
tool_type: ToolType
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
source_channel: Optional[str] = Field(None, description="来源渠道(仅MCP工具)")
|
||||
market_id: Optional[str] = Field(None, description="渠道市场id(仅MCP工具)")
|
||||
market_config_id: Optional[str] = Field(None, description="渠道市场配置id(仅MCP工具)")
|
||||
mcp_service_id: Optional[str] = Field(None, description="mcp服务id(仅MCP工具)")
|
||||
|
||||
|
||||
class ToolUpdateRequest(BaseModel):
|
||||
|
||||
445
api/app/services/app_dsl_service.py
Normal file
445
api/app/services/app_dsl_service.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""应用 DSL 导入导出服务"""
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.models import AgentConfig, MultiAgentConfig
|
||||
from app.models.app_model import App, AppType
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.models.models_model import ModelConfig
|
||||
from app.models.tool_model import ToolConfig as ToolConfigModel
|
||||
from app.models.workflow_model import WorkflowConfig
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
|
||||
|
||||
class AppDslService:
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
def export_dsl(self, app_id: uuid.UUID, release_id: Optional[uuid.UUID] = None) -> tuple[str, str]:
|
||||
"""构建应用 DSL yaml 字符串,返回 (yaml_str, filename)"""
|
||||
app = self.db.query(App).filter(App.id == app_id, App.is_active.is_(True)).first()
|
||||
if not app:
|
||||
raise ResourceNotFoundException("应用", str(app_id))
|
||||
|
||||
meta = {
|
||||
"version": settings.SYSTEM_VERSION,
|
||||
"platform": "MemoryBear",
|
||||
"exported_at": datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
}
|
||||
app_meta = {
|
||||
"name": app.name,
|
||||
"description": app.description,
|
||||
"icon": app.icon,
|
||||
"icon_type": app.icon_type,
|
||||
"type": app.type,
|
||||
"tags": app.tags or [],
|
||||
}
|
||||
|
||||
if release_id is not None:
|
||||
return self._export_release(app, release_id, meta, app_meta)
|
||||
|
||||
return self._export_draft(app, meta, app_meta)
|
||||
|
||||
def _export_release(self, app: App, release_id: uuid.UUID, meta: dict, app_meta: dict) -> tuple[str, str]:
|
||||
release = self.db.query(AppRelease).filter(
|
||||
AppRelease.app_id == app.id,
|
||||
AppRelease.id == release_id,
|
||||
AppRelease.is_active.is_(True)
|
||||
).first()
|
||||
if not release:
|
||||
raise ResourceNotFoundException("版本", str(release_id))
|
||||
|
||||
meta["release_version"] = release.version
|
||||
meta["release_name"] = release.version_name
|
||||
app_meta["name"] = release.name
|
||||
app_meta["description"] = release.description
|
||||
config_key = {
|
||||
AppType.AGENT: "agent_config",
|
||||
AppType.MULTI_AGENT: "multi_agent_config",
|
||||
AppType.WORKFLOW: "workflow"
|
||||
}.get(app.type, "config")
|
||||
config_data = self._enrich_release_config(app.type, release.config or {})
|
||||
dsl = {**meta, "app": app_meta, config_key: config_data}
|
||||
return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{release.name}_v{release.version_name}.yaml"
|
||||
|
||||
def _enrich_release_config(self, app_type: str, cfg: dict) -> dict:
|
||||
if app_type == AppType.AGENT:
|
||||
enriched = {**cfg}
|
||||
if "default_model_config_id" in cfg:
|
||||
enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"])
|
||||
if "knowledge_retrieval" in cfg:
|
||||
enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"])
|
||||
if "tools" in cfg:
|
||||
enriched["tools"] = self._enrich_tools(cfg["tools"])
|
||||
return enriched
|
||||
if app_type == AppType.MULTI_AGENT:
|
||||
enriched = {**cfg}
|
||||
if "default_model_config_id" in cfg:
|
||||
enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"])
|
||||
if "master_agent_id" in cfg:
|
||||
enriched["master_agent_ref"] = self._release_ref(cfg["master_agent_id"])
|
||||
if "sub_agents" in cfg:
|
||||
enriched["sub_agents"] = self._enrich_sub_agents(cfg["sub_agents"])
|
||||
if "routing_rules" in cfg:
|
||||
enriched["routing_rules"] = [
|
||||
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
|
||||
]
|
||||
return enriched
|
||||
return cfg
|
||||
|
||||
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
|
||||
if app.type == AppType.WORKFLOW:
|
||||
config = self.db.query(WorkflowConfig).filter(WorkflowConfig.app_id == app.id).first()
|
||||
config_data = {
|
||||
"variables": config.variables if config else [],
|
||||
"edges": config.edges if config else [],
|
||||
"nodes": config.nodes if config else [],
|
||||
"execution_config": config.execution_config if config else {},
|
||||
"triggers": config.triggers if config else [],
|
||||
} if config else {}
|
||||
dsl = {**meta, "app": app_meta, "workflow": config_data}
|
||||
|
||||
elif app.type == AppType.AGENT:
|
||||
config = self.db.query(AgentConfig).filter(AgentConfig.app_id == app.id).first()
|
||||
config_data = {
|
||||
"system_prompt": config.system_prompt if config else None,
|
||||
"model_parameters": self._to_dict(config.model_parameters) if config else None,
|
||||
"default_model_config_ref": self._model_ref(config.default_model_config_id) if config else None,
|
||||
"knowledge_retrieval": self._enrich_knowledge_retrieval(config.knowledge_retrieval) if config else None,
|
||||
"memory": config.memory if config else None,
|
||||
"variables": config.variables if config else [],
|
||||
"tools": self._enrich_tools(config.tools) if config else [],
|
||||
"skills": config.skills if config else {},
|
||||
} if config else {}
|
||||
dsl = {**meta, "app": app_meta, "agent_config": config_data}
|
||||
|
||||
elif app.type == AppType.MULTI_AGENT:
|
||||
config = self.db.query(MultiAgentConfig).filter(MultiAgentConfig.app_id == app.id).first()
|
||||
config_data = {
|
||||
"orchestration_mode": config.orchestration_mode if config else None,
|
||||
"master_agent_name": config.master_agent_name if config else None,
|
||||
"model_parameters": self._to_dict(config.model_parameters) if config else None,
|
||||
"default_model_config_ref": self._model_ref(config.default_model_config_id) if config else None,
|
||||
"master_agent_ref": self._release_ref(config.master_agent_id) if config else None,
|
||||
"sub_agents": self._enrich_sub_agents(config.sub_agents) if config else [],
|
||||
"routing_rules": [
|
||||
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (config.routing_rules or [])
|
||||
] if config else [],
|
||||
|
||||
"execution_config": config.execution_config if config else {},
|
||||
"aggregation_strategy": config.aggregation_strategy if config else "merge",
|
||||
} if config else {}
|
||||
dsl = {**meta, "app": app_meta, "multi_agent_config": config_data}
|
||||
|
||||
else:
|
||||
raise BusinessException(f"不支持的应用类型: {app.type}", BizCode.BAD_REQUEST)
|
||||
|
||||
return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{app.name}.yaml"
|
||||
|
||||
def _to_dict(self, value):
|
||||
"""将 Pydantic 对象转为普通 dict,供 yaml.dump 安全序列化"""
|
||||
if value is None:
|
||||
return None
|
||||
if hasattr(value, "model_dump"):
|
||||
return value.model_dump()
|
||||
return value
|
||||
|
||||
def _model_ref(self, model_config_id) -> Optional[dict]:
|
||||
if not model_config_id:
|
||||
return None
|
||||
m = self.db.query(ModelConfig).filter(ModelConfig.id == model_config_id).first()
|
||||
return {"id": str(model_config_id), "name": m.name, "provider": m.provider, "type": m.type} if m else {"id": str(model_config_id)}
|
||||
|
||||
def _kb_ref(self, kb_id) -> Optional[dict]:
|
||||
if not kb_id:
|
||||
return None
|
||||
kb = self.db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
||||
return {"id": str(kb_id), "name": kb.name} if kb else {"id": str(kb_id)}
|
||||
|
||||
def _tool_ref(self, tool_id) -> Optional[dict]:
|
||||
if not tool_id:
|
||||
return None
|
||||
t = self.db.query(ToolConfigModel).filter(ToolConfigModel.id == tool_id).first()
|
||||
return {"id": str(tool_id), "name": t.name, "tool_type": t.tool_type} if t else {"id": str(tool_id)}
|
||||
|
||||
def _enrich_knowledge_retrieval(self, kr: Optional[dict]) -> Optional[dict]:
|
||||
if not kr:
|
||||
return kr
|
||||
kbs = [{**kb, "_ref": self._kb_ref(kb.get("kb_id"))} for kb in kr.get("knowledge_bases", [])]
|
||||
return {**kr, "knowledge_bases": kbs}
|
||||
|
||||
def _enrich_tools(self, tools: list) -> list:
|
||||
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
||||
|
||||
def _agent_ref(self, agent_id) -> Optional[dict]:
|
||||
if not agent_id:
|
||||
return None
|
||||
a = self.db.query(App).filter(App.id == agent_id).first()
|
||||
return {"id": str(agent_id), "name": a.name} if a else {"id": str(agent_id)}
|
||||
|
||||
def _release_ref(self, release_id) -> Optional[dict]:
|
||||
if not release_id:
|
||||
return None
|
||||
r = self.db.query(AppRelease).filter(AppRelease.id == release_id).first()
|
||||
return {"id": str(release_id), "name": r.name, "version": r.version, "app_id": str(r.app_id)} if r else {"id": str(release_id)}
|
||||
|
||||
def _enrich_sub_agents(self, sub_agents: list) -> list:
|
||||
return [{**s, "_ref": self._agent_ref(s.get("agent_id"))} for s in (sub_agents or [])]
|
||||
|
||||
# ==================== 导入 ====================
|
||||
|
||||
def import_dsl(
|
||||
self,
|
||||
dsl: dict,
|
||||
workspace_id: uuid.UUID,
|
||||
tenant_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
) -> tuple[App, list[str]]:
|
||||
"""解析 DSL,创建应用及配置,返回 (new_app, warnings)"""
|
||||
app_meta = dsl.get("app", {})
|
||||
app_type = app_meta.get("type")
|
||||
if app_type not in (AppType.AGENT, AppType.MULTI_AGENT, AppType.WORKFLOW):
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.BAD_REQUEST)
|
||||
|
||||
warnings: list[str] = []
|
||||
now = datetime.datetime.now()
|
||||
|
||||
new_app = App(
|
||||
id=uuid.uuid4(),
|
||||
workspace_id=workspace_id,
|
||||
created_by=user_id,
|
||||
name=self._unique_app_name(app_meta.get("name", "导入应用"), workspace_id, app_type),
|
||||
description=app_meta.get("description"),
|
||||
icon=app_meta.get("icon"),
|
||||
icon_type=app_meta.get("icon_type"),
|
||||
type=app_type,
|
||||
visibility="private",
|
||||
status="draft",
|
||||
tags=app_meta.get("tags", []),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
self.db.add(new_app)
|
||||
self.db.flush()
|
||||
|
||||
if app_type == AppType.AGENT:
|
||||
cfg = dsl.get("agent_config") or {}
|
||||
self.db.add(AgentConfig(
|
||||
id=uuid.uuid4(),
|
||||
app_id=new_app.id,
|
||||
system_prompt=cfg.get("system_prompt"),
|
||||
model_parameters=cfg.get("model_parameters"),
|
||||
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
|
||||
knowledge_retrieval=self._resolve_knowledge_retrieval(cfg.get("knowledge_retrieval"), workspace_id, warnings),
|
||||
memory=self._resolve_memory(cfg.get("memory"), workspace_id, warnings),
|
||||
variables=cfg.get("variables", []),
|
||||
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
|
||||
skills=cfg.get("skills", {}),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
))
|
||||
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
cfg = dsl.get("multi_agent_config") or {}
|
||||
self.db.add(MultiAgentConfig(
|
||||
id=uuid.uuid4(),
|
||||
app_id=new_app.id,
|
||||
orchestration_mode=cfg.get("orchestration_mode", "collaboration"),
|
||||
master_agent_name=cfg.get("master_agent_name"),
|
||||
model_parameters=cfg.get("model_parameters"),
|
||||
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
|
||||
master_agent_id=self._resolve_release(cfg.get("master_agent_ref"), warnings),
|
||||
sub_agents=self._resolve_sub_agents(cfg.get("sub_agents", []), warnings),
|
||||
routing_rules=self._resolve_routing_rules(cfg.get("routing_rules"), warnings),
|
||||
execution_config=cfg.get("execution_config", {}),
|
||||
aggregation_strategy=cfg.get("aggregation_strategy", "merge"),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
))
|
||||
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
adapter = MemoryBearAdapter(dsl)
|
||||
if not adapter.validate_config():
|
||||
raise BusinessException("工作流配置格式无效", BizCode.BAD_REQUEST)
|
||||
result = adapter.parse_workflow()
|
||||
for e in result.errors:
|
||||
warnings.append(f"[节点错误] {e.node_name or e.node_id}: {e.detail}")
|
||||
for w in result.warnings:
|
||||
warnings.append(f"[节点警告] {w.node_name or w.node_id}: {w.detail}")
|
||||
wf = dsl.get("workflow") or {}
|
||||
WorkflowService(self.db).create_workflow_config(
|
||||
app_id=new_app.id,
|
||||
nodes=[n.model_dump() for n in result.nodes],
|
||||
edges=[e.model_dump() for e in result.edges],
|
||||
variables=[v.model_dump() for v in result.variables],
|
||||
execution_config=wf.get("execution_config", {}),
|
||||
triggers=wf.get("triggers", []),
|
||||
validate=False,
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(new_app)
|
||||
return new_app, warnings
|
||||
|
||||
def _unique_app_name(self, name: str, workspace_id: uuid.UUID, app_type: AppType) -> str:
|
||||
existing = {r[0] for r in self.db.query(App.name).filter(
|
||||
App.workspace_id == workspace_id,
|
||||
App.type == app_type,
|
||||
App.is_active.is_(True)
|
||||
).all()}
|
||||
if name not in existing:
|
||||
return name
|
||||
counter = 1
|
||||
while f"{name}({counter})" in existing:
|
||||
counter += 1
|
||||
return f"{name}({counter})"
|
||||
|
||||
def _resolve_model(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[uuid.UUID]:
|
||||
if not ref:
|
||||
return None
|
||||
q = self.db.query(ModelConfig).filter(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.name == ref.get("name"),
|
||||
ModelConfig.is_active.is_(True)
|
||||
)
|
||||
if ref.get("provider"):
|
||||
q = q.filter(ModelConfig.provider == ref["provider"])
|
||||
if ref.get("type"):
|
||||
q = q.filter(ModelConfig.type == ref["type"])
|
||||
m = q.first()
|
||||
if not m:
|
||||
warnings.append(f"模型 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||
return m.id if m else None
|
||||
|
||||
def _resolve_kb(self, ref: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[str]:
|
||||
if not ref:
|
||||
return None
|
||||
kb = self.db.query(Knowledge).filter(
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.name == ref.get("name")
|
||||
).first()
|
||||
if not kb:
|
||||
warnings.append(f"知识库 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||
return str(kb.id) if kb else None
|
||||
|
||||
def _resolve_tool(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]:
|
||||
if not ref:
|
||||
return None
|
||||
q = self.db.query(ToolConfigModel).filter(
|
||||
ToolConfigModel.tenant_id == tenant_id,
|
||||
ToolConfigModel.name == ref.get("name")
|
||||
)
|
||||
if ref.get("tool_type"):
|
||||
q = q.filter(ToolConfigModel.tool_type == ref["tool_type"])
|
||||
t = q.first()
|
||||
if not t:
|
||||
warnings.append(f"工具 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||
return str(t.id) if t else None
|
||||
|
||||
def _resolve_release(self, ref: Optional[dict], warnings: list) -> Optional[uuid.UUID]:
|
||||
if not ref:
|
||||
return None
|
||||
r = self.db.query(AppRelease).filter(
|
||||
AppRelease.app_id == ref.get("app_id"),
|
||||
AppRelease.version == ref.get("version"),
|
||||
AppRelease.is_active.is_(True)
|
||||
).first()
|
||||
if not r:
|
||||
warnings.append(f"主 Agent 发布版本 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||
return r.id if r else None
|
||||
|
||||
def _resolve_sub_agents(self, sub_agents: list, warnings: list) -> list:
|
||||
result = []
|
||||
for s in (sub_agents or []):
|
||||
ref = s.get("_ref")
|
||||
entry = {k: v for k, v in s.items() if k != "_ref"}
|
||||
if ref:
|
||||
a = self.db.query(App).filter(App.name == ref.get("name"), App.is_active.is_(True)).first()
|
||||
if not a:
|
||||
warnings.append(f"子 Agent '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||
entry["agent_id"] = str(a.id) if a else None
|
||||
result.append(entry)
|
||||
return result
|
||||
|
||||
def _resolve_routing_rules(self, rules: Optional[list], warnings: list) -> Optional[list]:
|
||||
if rules is None:
|
||||
return None
|
||||
result = []
|
||||
for r in rules:
|
||||
ref = r.get("_ref")
|
||||
entry = {k: v for k, v in r.items() if k != "_ref"}
|
||||
if ref:
|
||||
a = self.db.query(App).filter(App.name == ref.get("name"), App.is_active.is_(True)).first()
|
||||
if not a:
|
||||
warnings.append(f"路由目标 Agent '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||
entry["target_agent_id"] = str(a.id) if a else None
|
||||
result.append(entry)
|
||||
return result
|
||||
|
||||
def _resolve_knowledge_retrieval(self, kr: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]:
|
||||
if not kr:
|
||||
return kr
|
||||
resolved_kbs = []
|
||||
for kb in kr.get("knowledge_bases", []):
|
||||
ref = kb.get("_ref") or ({"name": kb.get("kb_id")} if kb.get("kb_id") else None)
|
||||
entry = {k: v for k, v in kb.items() if k != "_ref"}
|
||||
resolved_id = self._resolve_kb(ref, workspace_id, warnings)
|
||||
if resolved_id is None:
|
||||
continue
|
||||
entry["kb_id"] = resolved_id
|
||||
resolved_kbs.append(entry)
|
||||
return {k: v for k, v in kr.items() if k != "knowledge_bases"} | {"knowledge_bases": resolved_kbs}
|
||||
|
||||
def _resolve_memory(self, memory: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]:
|
||||
if not memory:
|
||||
return memory
|
||||
config_id = memory.get("memory_config_id") or memory.get("memory_content")
|
||||
if not config_id:
|
||||
return memory
|
||||
try:
|
||||
config_uuid = uuid.UUID(str(config_id))
|
||||
except (ValueError, AttributeError):
|
||||
exists = self.db.query(MemoryConfigModel).filter(
|
||||
MemoryConfigModel.config_id_old == int(config_id),
|
||||
MemoryConfigModel.workspace_id == workspace_id
|
||||
).first()
|
||||
if not exists:
|
||||
warnings.append(f"记忆配置 '{config_id}' 未匹配,已置空,请导入后手动配置")
|
||||
return {**memory, "memory_config_id": None, "enabled": False}
|
||||
return memory
|
||||
exists = self.db.query(MemoryConfigModel).filter(
|
||||
MemoryConfigModel.config_id == config_uuid,
|
||||
MemoryConfigModel.workspace_id == workspace_id
|
||||
).first()
|
||||
if not exists:
|
||||
warnings.append(f"记忆配置 '{config_id}' 未匹配,已置空,请导入后手动配置")
|
||||
return {**memory, "memory_config_id": None, "enabled": False}
|
||||
return memory
|
||||
|
||||
def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list:
|
||||
result = []
|
||||
for t in (tools or []):
|
||||
ref = t.get("_ref") or ({"name": t.get("tool_id")} if t.get("tool_id") else None)
|
||||
entry = {k: v for k, v in t.items() if k != "_ref"}
|
||||
resolved_id = self._resolve_tool(ref, tenant_id, warnings)
|
||||
if resolved_id is None:
|
||||
continue
|
||||
entry["tool_id"] = resolved_id
|
||||
result.append(entry)
|
||||
return result
|
||||
@@ -33,7 +33,7 @@ from app.models import (
|
||||
Workspace,
|
||||
)
|
||||
from app.models.app_model import AppStatus, AppType
|
||||
from app.repositories.app_repository import get_apps_by_id
|
||||
from app.repositories.app_repository import get_apps_by_id, AppRepository
|
||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||
from app.schemas import app_schema
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||
@@ -59,6 +59,7 @@ class AppService:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
self.app_repo = AppRepository(self.db)
|
||||
|
||||
# ==================== 私有辅助方法 ====================
|
||||
|
||||
@@ -521,6 +522,9 @@ class AppService:
|
||||
"创建应用",
|
||||
extra={"app_name": data.name, "type": data.type, "workspace_id": str(workspace_id)}
|
||||
)
|
||||
apps = self.app_repo.get_apps_by_name(data.name, data.type, workspace_id)
|
||||
if apps:
|
||||
raise BusinessException(message="已存在同名应用", code=BizCode.RESOURCE_ALREADY_EXISTS)
|
||||
|
||||
try:
|
||||
now = datetime.datetime.now()
|
||||
@@ -1368,6 +1372,15 @@ class AppService:
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 应用缺少配置,无法发布", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
miss_params = []
|
||||
if agent_cfg.default_model_config_id is None:
|
||||
miss_params.append("model config")
|
||||
|
||||
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
|
||||
miss_params.append("memory config")
|
||||
if miss_params:
|
||||
raise BusinessException(f"{', '.join(miss_params)} is required")
|
||||
|
||||
config = {
|
||||
"system_prompt": agent_cfg.system_prompt,
|
||||
"model_parameters": model_parameters_to_dict(agent_cfg.model_parameters),
|
||||
|
||||
@@ -1165,6 +1165,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
|
||||
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
# TODO: check sources for enduserid, should be one of these three: chat, draft, apikey
|
||||
# 1. 获取 end_user 及其 app_id
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if not end_user:
|
||||
@@ -1179,10 +1180,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
if not app:
|
||||
logger.warning(f"App not found: {app_id}")
|
||||
raise ValueError(f"应用不存在: {app_id}")
|
||||
|
||||
if not app.current_release_id:
|
||||
logger.warning(f"No current release for app: {app_id}")
|
||||
raise ValueError(f"应用未发布: {app_id}")
|
||||
# TODO: temp fix for draft run
|
||||
# if not app.current_release_id:
|
||||
# logger.warning(f"No current release for app: {app_id}")
|
||||
# raise ValueError(f"应用未发布: {app_id}")
|
||||
|
||||
# 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填
|
||||
memory_config_id_to_use = end_user.memory_config_id
|
||||
@@ -1223,7 +1224,9 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
|
||||
if legacy_config_id:
|
||||
# 验证提取的 config_id 是否存在于数据库中
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
from app.models.memory_config_model import (
|
||||
MemoryConfig as MemoryConfigModel,
|
||||
)
|
||||
existing_config = db.get(MemoryConfigModel, legacy_config_id)
|
||||
|
||||
if existing_config:
|
||||
@@ -1257,7 +1260,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
result = {
|
||||
"end_user_id": str(end_user_id),
|
||||
"app_id": str(app_id),
|
||||
"release_id": str(app.current_release_id),
|
||||
"release_id": str(app.current_release_id) if app.current_release_id else None,
|
||||
"memory_config_id": memory_config_id,
|
||||
"workspace_id": str(app.workspace_id)
|
||||
}
|
||||
|
||||
@@ -107,29 +107,19 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
)
|
||||
|
||||
|
||||
# 专门场景的内置 key 集合,直接从 SceneConfigRegistry 派生,避免重复维护
|
||||
# 使用懒加载函数避免模块级循环导入
|
||||
def _get_builtin_pruning_scenes() -> set:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import SceneConfigRegistry
|
||||
return set(SceneConfigRegistry.get_all_scenes())
|
||||
|
||||
|
||||
def _load_ontology_classes(db: Session, scene_id, pruning_scene: Optional[str]) -> Optional[list]:
|
||||
"""当 pruning_scene 不是内置场景时,从 ontology_class 表加载类型名称列表。
|
||||
"""从 ontology_class 表加载场景类型名称列表,用于注入提示词。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
scene_id: 本体场景 UUID
|
||||
pruning_scene: 语义剪枝场景名称
|
||||
pruning_scene: 语义剪枝场景名称(保留参数,暂未使用)
|
||||
|
||||
Returns:
|
||||
class_name 字符串列表,或 None(内置场景 / 无数据时)
|
||||
class_name 字符串列表,或 None(无数据时)
|
||||
"""
|
||||
if not scene_id:
|
||||
return None
|
||||
# 内置场景走 SceneConfigRegistry,不需要注入类型列表
|
||||
if pruning_scene in _get_builtin_pruning_scenes():
|
||||
return None
|
||||
try:
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
repo = OntologyClassRepository(db)
|
||||
|
||||
@@ -535,7 +535,8 @@ def get_users_total_chunk_batch(
|
||||
|
||||
def get_rag_content(
|
||||
end_user_id: str,
|
||||
limit: int,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
db: Session,
|
||||
current_user: User
|
||||
) -> dict:
|
||||
@@ -543,9 +544,9 @@ def get_rag_content(
|
||||
先在documents表中查询file_name=='end_user_id'+'.txt'的id和kb_id,
|
||||
然后调用/chunks/{kb_id}/{document_id}/chunks接口的相关代码获取所有内容,
|
||||
接着对获取的内容进行提取,只要page_content的内容,
|
||||
最后返回数据
|
||||
最后返回分页数据
|
||||
"""
|
||||
business_logger.info(f"获取RAG内容: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
|
||||
business_logger.info(f"获取RAG内容: end_user_id={end_user_id}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
from app.models.document_model import Document
|
||||
@@ -562,63 +563,76 @@ def get_rag_content(
|
||||
if not documents:
|
||||
business_logger.warning(f"未找到文件: {file_name}")
|
||||
return {
|
||||
"total": 0,
|
||||
"contents": []
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": 0,
|
||||
"hasnext": False,
|
||||
},
|
||||
"items": []
|
||||
}
|
||||
|
||||
business_logger.info(f"找到 {len(documents)} 个文档记录")
|
||||
|
||||
# 3. 获取所有chunks的page_content
|
||||
all_contents = []
|
||||
total_chunks = 0
|
||||
# 3. 按全局偏移量计算当前页数据
|
||||
# 全局偏移范围:[offset_start, offset_end)
|
||||
offset_start = (page - 1) * pagesize
|
||||
offset_end = offset_start + pagesize
|
||||
|
||||
global_total = 0 # 所有文档的 chunk 总数
|
||||
page_contents = [] # 当前页的内容
|
||||
|
||||
for document in documents:
|
||||
try:
|
||||
# 获取知识库信息
|
||||
kb = knowledge_repository.get_knowledge_by_id(db, document.kb_id)
|
||||
if not kb:
|
||||
business_logger.warning(f"知识库不存在: kb_id={document.kb_id}")
|
||||
continue
|
||||
|
||||
# 初始化向量服务
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=kb)
|
||||
|
||||
# 获取该文档的所有chunks(分页获取)
|
||||
page = 1
|
||||
pagesize = 100 # 每页100条
|
||||
# 先用 pagesize=1 获取该文档的 chunk 总数
|
||||
doc_total, _ = vector_service.search_by_segment(
|
||||
document_id=str(document.id),
|
||||
query=None,
|
||||
pagesize=1,
|
||||
page=1,
|
||||
asc=True
|
||||
)
|
||||
|
||||
while True:
|
||||
total, items = vector_service.search_by_segment(
|
||||
doc_offset_start = global_total # 该文档在全局中的起始偏移
|
||||
doc_offset_end = global_total + doc_total # 该文档在全局中的结束偏移
|
||||
global_total += doc_total
|
||||
|
||||
# 当前页与该文档无交集,跳过
|
||||
if doc_offset_end <= offset_start or doc_offset_start >= offset_end:
|
||||
continue
|
||||
|
||||
# 计算需要从该文档取的局部范围
|
||||
local_start = max(offset_start - doc_offset_start, 0)
|
||||
local_end = min(offset_end - doc_offset_start, doc_total)
|
||||
need_count = local_end - local_start
|
||||
|
||||
# 换算成 ES 分页参数(ES page 从1开始)
|
||||
es_page = (local_start // pagesize) + 1
|
||||
es_offset_in_page = local_start % pagesize
|
||||
|
||||
fetched = []
|
||||
while len(fetched) < es_offset_in_page + need_count:
|
||||
_, items = vector_service.search_by_segment(
|
||||
document_id=str(document.id),
|
||||
query=None,
|
||||
pagesize=pagesize,
|
||||
page=page,
|
||||
page=es_page,
|
||||
asc=True
|
||||
)
|
||||
|
||||
if not items:
|
||||
break
|
||||
|
||||
# 提取page_content
|
||||
for item in items:
|
||||
all_contents.append(item.page_content)
|
||||
total_chunks += 1
|
||||
|
||||
# # 如果达到limit限制,直接返回
|
||||
# if limit > 0 and total_chunks >= limit:
|
||||
# business_logger.info(f"已达到limit限制: {limit}")
|
||||
# return {
|
||||
# "total": total_chunks,
|
||||
# "contents": all_contents[:limit]
|
||||
# }
|
||||
|
||||
# 检查是否还有下一页
|
||||
if page * pagesize >= total:
|
||||
break
|
||||
|
||||
page += 1
|
||||
fetched.extend(items)
|
||||
es_page += 1
|
||||
|
||||
business_logger.info(f"文档 {document.id} 获取了 {len(items)} 个chunks")
|
||||
slice_items = fetched[es_offset_in_page: es_offset_in_page + need_count]
|
||||
page_contents.extend([item.page_content for item in slice_items])
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
|
||||
@@ -626,11 +640,16 @@ def get_rag_content(
|
||||
|
||||
# 4. 返回结果
|
||||
result = {
|
||||
"total": total_chunks,
|
||||
"contents": all_contents[:limit] if limit > 0 else all_contents
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": global_total,
|
||||
"hasnext": offset_end < global_total,
|
||||
},
|
||||
"items": page_contents
|
||||
}
|
||||
|
||||
business_logger.info(f"成功获取RAG内容: total={total_chunks}, 返回={len(result['contents'])} 条")
|
||||
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)} 条")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
@@ -646,59 +665,26 @@ async def get_chunk_summary_and_tags(
|
||||
current_user: User
|
||||
) -> dict:
|
||||
"""
|
||||
获取chunk的总结、标签和人物形象
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID
|
||||
limit: 返回的chunk数量限制
|
||||
max_tags: 最大标签数量
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
包含summary、tags和personas的字典
|
||||
纯读库:从end_user表返回RAG摘要、标签和人物形象缓存。
|
||||
无数据时返回空结构,不触发LLM生成。
|
||||
"""
|
||||
business_logger.info(f"获取chunk摘要、标签和人物形象: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. 获取chunk内容
|
||||
rag_content = get_rag_content(end_user_id, limit, db, current_user)
|
||||
chunks = rag_content.get("contents", [])
|
||||
|
||||
if not chunks:
|
||||
business_logger.warning(f"未找到chunk内容: end_user_id={end_user_id}")
|
||||
return {
|
||||
"summary": "暂无内容",
|
||||
"tags": [],
|
||||
"personas": []
|
||||
}
|
||||
|
||||
# 2. 导入RAG工具函数
|
||||
from app.core.rag_utils import generate_chunk_summary, extract_chunk_tags, extract_chunk_persona
|
||||
|
||||
# 3. 并发生成摘要、提取标签和人物形象
|
||||
import asyncio
|
||||
summary_task = generate_chunk_summary(chunks, max_chunks=limit)
|
||||
tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit)
|
||||
personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit)
|
||||
|
||||
summary, tags_with_freq, personas = await asyncio.gather(summary_task, tags_task, personas_task)
|
||||
|
||||
# 4. 格式化标签数据
|
||||
tags = [{"tag": tag, "frequency": freq} for tag, freq in tags_with_freq]
|
||||
|
||||
result = {
|
||||
"summary": summary,
|
||||
"tags": tags,
|
||||
"personas": personas
|
||||
}
|
||||
|
||||
business_logger.info(f"成功获取chunk摘要、{len(tags)} 个标签和 {len(personas)} 个人物形象")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取chunk摘要、标签和人物形象失败: end_user_id={end_user_id} - {str(e)}")
|
||||
raise
|
||||
import json
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
business_logger.info(f"读取chunk摘要/标签/人物形象缓存: end_user_id={end_user_id}")
|
||||
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||
|
||||
if not end_user:
|
||||
return {"summary": "", "tags": [], "personas": [], "generated": False}
|
||||
|
||||
return {
|
||||
"summary": end_user.user_summary or "",
|
||||
"tags": json.loads(end_user.rag_tags) if end_user.rag_tags else [],
|
||||
"personas": json.loads(end_user.rag_personas) if end_user.rag_personas else [],
|
||||
"generated": bool(end_user.user_summary),
|
||||
}
|
||||
|
||||
|
||||
async def get_chunk_insight(
|
||||
@@ -708,43 +694,98 @@ async def get_chunk_insight(
|
||||
current_user: User
|
||||
) -> dict:
|
||||
"""
|
||||
获取chunk的洞察分析
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID
|
||||
limit: 返回的chunk数量限制
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
包含insight的字典
|
||||
纯读库:从end_user表返回RAG洞察缓存。
|
||||
无数据时返回空结构,不触发LLM生成。
|
||||
"""
|
||||
business_logger.info(f"获取chunk洞察: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. 获取chunk内容
|
||||
rag_content = get_rag_content(end_user_id, limit, db, current_user)
|
||||
chunks = rag_content.get("contents", [])
|
||||
|
||||
if not chunks:
|
||||
business_logger.warning(f"未找到chunk内容: end_user_id={end_user_id}")
|
||||
return {
|
||||
"insight": "暂无足够数据生成洞察报告"
|
||||
}
|
||||
|
||||
# 2. 导入RAG工具函数
|
||||
from app.core.rag_utils import generate_chunk_insight
|
||||
|
||||
# 3. 生成洞察
|
||||
insight = await generate_chunk_insight(chunks, max_chunks=limit)
|
||||
|
||||
result = {
|
||||
"insight": insight
|
||||
}
|
||||
|
||||
business_logger.info("成功获取chunk洞察")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取chunk洞察失败: end_user_id={end_user_id} - {str(e)}")
|
||||
raise
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
business_logger.info(f"读取chunk洞察缓存: end_user_id={end_user_id}")
|
||||
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||
|
||||
if not end_user:
|
||||
return {"insight": "", "behavior_pattern": "", "key_findings": "", "growth_trajectory": "", "generated": False}
|
||||
|
||||
return {
|
||||
"insight": end_user.memory_insight or "",
|
||||
"behavior_pattern": end_user.behavior_pattern or "",
|
||||
"key_findings": end_user.key_findings or "",
|
||||
"growth_trajectory": end_user.growth_trajectory or "",
|
||||
"generated": bool(end_user.memory_insight),
|
||||
}
|
||||
|
||||
|
||||
async def generate_rag_profile(
|
||||
end_user_id: str,
|
||||
limit: int,
|
||||
max_tags: int,
|
||||
db: Session,
|
||||
current_user: User,
|
||||
) -> dict:
|
||||
"""
|
||||
生产接口:为RAG存储模式的end_user全量重新生成并持久化完整画像数据。
|
||||
每次调用都会重新生成,覆盖已有数据。
|
||||
|
||||
生成内容:
|
||||
- user_summary / rag_tags / rag_personas
|
||||
- memory_insight / behavior_pattern / key_findings / growth_trajectory
|
||||
"""
|
||||
import json
|
||||
import asyncio
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.core.rag_utils import (
|
||||
generate_chunk_summary,
|
||||
extract_chunk_tags,
|
||||
extract_chunk_persona,
|
||||
generate_chunk_insight_sections,
|
||||
)
|
||||
|
||||
business_logger.info(f"开始生产RAG画像: end_user_id={end_user_id}, 操作者: {current_user.username}")
|
||||
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||
|
||||
if not end_user:
|
||||
raise ValueError(f"end_user {end_user_id} 不存在")
|
||||
|
||||
rag_content = get_rag_content(end_user_id, page=1, pagesize=limit, db=db, current_user=current_user)
|
||||
chunks = rag_content.get("items", [])
|
||||
|
||||
if not chunks:
|
||||
business_logger.warning(f"未找到chunk内容,无法生产RAG画像: end_user_id={end_user_id}")
|
||||
raise ValueError("暂无chunk内容,无法生成画像")
|
||||
|
||||
summary, tags_with_freq, personas, insight_sections = await asyncio.gather(
|
||||
generate_chunk_summary(chunks, max_chunks=limit, end_user_id=end_user_id),
|
||||
extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit, end_user_id=end_user_id),
|
||||
extract_chunk_persona(chunks, max_personas=5, max_chunks=limit, end_user_id=end_user_id),
|
||||
generate_chunk_insight_sections(chunks, max_chunks=limit, end_user_id=end_user_id),
|
||||
)
|
||||
|
||||
tags = [{"tag": tag, "frequency": freq} for tag, freq in tags_with_freq]
|
||||
|
||||
repo.update_rag_summary_tags(
|
||||
end_user_id=end_user.id,
|
||||
user_summary=summary,
|
||||
rag_tags=json.dumps(tags, ensure_ascii=False),
|
||||
rag_personas=json.dumps(personas, ensure_ascii=False),
|
||||
)
|
||||
|
||||
repo.update_memory_insight(
|
||||
end_user_id=end_user.id,
|
||||
memory_insight=insight_sections.get("memory_insight", ""),
|
||||
behavior_pattern=insight_sections.get("behavior_pattern", ""),
|
||||
key_findings=insight_sections.get("key_findings", ""),
|
||||
growth_trajectory=insight_sections.get("growth_trajectory", ""),
|
||||
)
|
||||
|
||||
business_logger.info(f"RAG画像生产完成: end_user_id={end_user_id}, tags={len(tags)}, personas={len(personas)}")
|
||||
|
||||
return {
|
||||
"end_user_id": end_user_id,
|
||||
"summary_length": len(summary),
|
||||
"tags_count": len(tags),
|
||||
"personas_count": len(personas),
|
||||
"insight_generated": bool(insight_sections.get("memory_insight")),
|
||||
}
|
||||
@@ -8,32 +8,42 @@
|
||||
- Bedrock/Anthropic: 仅支持 base64 格式
|
||||
- OpenAI: 支持 URL 和 base64 格式
|
||||
"""
|
||||
import uuid
|
||||
import httpx
|
||||
import base64
|
||||
from typing import List, Dict, Any, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlalchemy.orm import Session
|
||||
from docx import Document
|
||||
import io
|
||||
import PyPDF2
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
import PyPDF2
|
||||
import httpx
|
||||
import magic
|
||||
from docx import Document
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
TEXT_MIME = ['text/plain', 'text/x-markdown']
|
||||
PDF_MIME = ['application/pdf']
|
||||
DOC_MIME = [
|
||||
'application/msword',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
||||
]
|
||||
|
||||
|
||||
class MultimodalFormatStrategy(ABC):
|
||||
"""多模态格式策略基类"""
|
||||
def __init__(self, file: FileInput):
|
||||
self.file = file
|
||||
|
||||
@abstractmethod
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""格式化图片"""
|
||||
pass
|
||||
|
||||
@@ -43,7 +53,7 @@ class MultimodalFormatStrategy(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]:
|
||||
async def format_audio(self, file_type: str, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""格式化音频"""
|
||||
pass
|
||||
|
||||
@@ -56,7 +66,7 @@ class MultimodalFormatStrategy(ABC):
|
||||
class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
"""通义千问策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""通义千问图片格式:{"type": "image", "image": "url"}"""
|
||||
return {
|
||||
"type": "image",
|
||||
@@ -70,7 +80,13 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def format_audio(
|
||||
self,
|
||||
file_type: str,
|
||||
url: str,
|
||||
content: bytes | None = None,
|
||||
transcription: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
通义千问音频格式
|
||||
- 原生支持: qwen-audio 系列
|
||||
@@ -98,44 +114,37 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
"""Bedrock/Anthropic 策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Bedrock/Anthropic 格式: base64 编码
|
||||
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||
"""
|
||||
from mimetypes import guess_type
|
||||
|
||||
logger.info(f"下载并编码图片: {url}")
|
||||
|
||||
# 下载图片
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# 获取图片数据
|
||||
image_data = response.content
|
||||
if content is None:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
self.file.set_content(content)
|
||||
|
||||
# 确定 media type
|
||||
content_type = response.headers.get("content-type")
|
||||
if content_type and content_type.startswith("image/"):
|
||||
media_type = content_type
|
||||
else:
|
||||
guessed_type, _ = guess_type(url)
|
||||
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||
content_type = magic.from_buffer(content, mime=True)
|
||||
media_type = content_type if content_type.startswith("image/") else "image/jpeg"
|
||||
base64_data = base64.b64encode(content).decode("utf-8")
|
||||
|
||||
# 转换为 base64
|
||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": base64_data
|
||||
}
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": base64_data
|
||||
}
|
||||
}
|
||||
|
||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
||||
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
|
||||
@@ -152,7 +161,12 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
}
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def format_audio(
|
||||
self, file_type: str,
|
||||
url: str,
|
||||
content: bytes | None = None,
|
||||
transcription: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Bedrock/Anthropic 音频格式
|
||||
不支持原生音频,必须转录为文本
|
||||
@@ -178,7 +192,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
"""OpenAI 策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
||||
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
||||
return {
|
||||
"type": "image_url",
|
||||
@@ -194,7 +208,13 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def format_audio(
|
||||
self,
|
||||
file_type: str,
|
||||
url: str,
|
||||
content: bytes | None = None,
|
||||
transcription: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
OpenAI 音频格式
|
||||
- gpt-4o-audio 系列支持原生音频(需要 base64 编码)
|
||||
@@ -208,31 +228,35 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
|
||||
# OpenAI 音频需要 base64 编码
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
audio_data = response.content
|
||||
base64_audio = base64.b64encode(audio_data).decode('utf-8')
|
||||
# 1. 优先从 file_type (MIME) 取扩展名
|
||||
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
|
||||
# 2. 从响应头 content-type 取
|
||||
if not file_ext:
|
||||
ct = response.headers.get("content-type", "")
|
||||
file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None
|
||||
# 3. 从 URL 路径取扩展名
|
||||
if not file_ext:
|
||||
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
|
||||
# 4. 默认 wav
|
||||
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
|
||||
file_ext = "wav" if not file_ext else file_ext
|
||||
audio_data = content
|
||||
if content is None:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
audio_data = response.content
|
||||
self.file.set_content(audio_data)
|
||||
base64_audio = base64.b64encode(audio_data).decode('utf-8')
|
||||
|
||||
return {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": f"data:;base64,{base64_audio}",
|
||||
"format": file_ext
|
||||
}
|
||||
# 1. 优先从 file_type (MIME) 取扩展名
|
||||
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
|
||||
# 2. 从响应头 content-type 取
|
||||
if not file_ext:
|
||||
content_type = magic.from_buffer(audio_data, mime=True)
|
||||
file_ext = content_type.split('/')[-1].split(';')[0].strip() if '/' in content_type else None
|
||||
# 3. 从 URL 路径取扩展名
|
||||
if not file_ext:
|
||||
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
|
||||
# 4. 默认 wav
|
||||
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
|
||||
file_ext = "wav" if not file_ext else file_ext
|
||||
|
||||
return {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": f"data:;base64,{base64_audio}",
|
||||
"format": file_ext
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"下载音频失败: {e}")
|
||||
return {
|
||||
@@ -262,7 +286,8 @@ PROVIDER_STRATEGIES = {
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
|
||||
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False):
|
||||
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
|
||||
enable_audio_transcription: bool = False, is_omni: bool = False):
|
||||
"""
|
||||
初始化多模态服务
|
||||
|
||||
@@ -305,10 +330,9 @@ class MultimodalService:
|
||||
logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
|
||||
strategy_class = DashScopeFormatStrategy
|
||||
|
||||
strategy = strategy_class()
|
||||
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
strategy = strategy_class(file)
|
||||
try:
|
||||
if file.type == FileType.IMAGE:
|
||||
content = await self._process_image(file, strategy)
|
||||
@@ -355,7 +379,7 @@ class MultimodalService:
|
||||
"""
|
||||
try:
|
||||
url = await self.get_file_url(file)
|
||||
return await strategy.format_image(url)
|
||||
return await strategy.format_image(url, content=file.get_content())
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败: {e}", exc_info=True)
|
||||
return {
|
||||
@@ -415,11 +439,13 @@ class MultimodalService:
|
||||
# 远程文档暂不支持提取
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n[远程文档,暂不支持内容提取]\n</document>"
|
||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||
}
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
text = await self._extract_document_text(file.upload_file_id)
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
|
||||
text = await self._extract_document_text(file)
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file.upload_file_id
|
||||
).first()
|
||||
@@ -454,7 +480,7 @@ class MultimodalService:
|
||||
else:
|
||||
logger.warning(f"Provider {self.provider} 不支持音频转文本")
|
||||
|
||||
return await strategy.format_audio(file.file_type, url, transcription)
|
||||
return await strategy.format_audio(file.file_type, url, file.get_content(), transcription)
|
||||
except Exception as e:
|
||||
logger.error(f"处理音频失败: {e}", exc_info=True)
|
||||
return {
|
||||
@@ -500,8 +526,6 @@ class MultimodalService:
|
||||
return file.url
|
||||
else:
|
||||
file_id = file.upload_file_id
|
||||
print("="*50)
|
||||
print("file_id",file_id)
|
||||
|
||||
# 查询 FileMetadata
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
@@ -519,66 +543,44 @@ class MultimodalService:
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
return f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
|
||||
async def _extract_document_text(self, file: FileInput) -> str:
|
||||
"""
|
||||
提取文档文本内容
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
file: 文件输入
|
||||
|
||||
Returns:
|
||||
str: 提取的文本内容
|
||||
"""
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file_id,
|
||||
FileMetadata.status == "completed"
|
||||
).first()
|
||||
|
||||
if not file_metadata:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
file_ext = file_metadata.file_ext.lower()
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file_url = f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
if file_ext in ['.txt', '.md', '.markdown']:
|
||||
return await self._read_text_file(file_url)
|
||||
elif file_ext == '.pdf':
|
||||
return await self._extract_pdf_text(file_url)
|
||||
elif file_ext in ['.doc', '.docx']:
|
||||
return await self._extract_word_text(file_url)
|
||||
else:
|
||||
return f"[不支持的文档格式: {file_ext}]"
|
||||
|
||||
@staticmethod
|
||||
async def _read_text_file(file_url: str) -> str:
|
||||
"""读取纯文本文件"""
|
||||
try:
|
||||
# 下载文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
file_content = file.get_content()
|
||||
if not file_content:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file.url)
|
||||
response.raise_for_status()
|
||||
file_content = response.content
|
||||
file.set_content(file_content)
|
||||
file_mime_type = magic.from_buffer(file_content, mime=True)
|
||||
if file_mime_type in TEXT_MIME:
|
||||
return file_content.decode("utf-8")
|
||||
elif file_mime_type in PDF_MIME:
|
||||
return await self._extract_pdf_text(file_content)
|
||||
elif file_mime_type in DOC_MIME:
|
||||
return await self._extract_word_text(file_content)
|
||||
else:
|
||||
return f"[Unsupported file type: {file_mime_type}]"
|
||||
except Exception as e:
|
||||
logger.error(f"读取文本文件失败: {e}")
|
||||
return f"[文件读取失败: {str(e)}]"
|
||||
logger.error(f"Failed to load file. - {e}")
|
||||
return "[Failed to load file.]"
|
||||
|
||||
@staticmethod
|
||||
async def _extract_pdf_text(file_url: str) -> str:
|
||||
async def _extract_pdf_text(file_content: bytes) -> str:
|
||||
"""提取 PDF 文本"""
|
||||
try:
|
||||
# 下载 PDF 文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
pdf_data = response.content
|
||||
|
||||
# 使用 BytesIO 读取 PDF
|
||||
text_parts = []
|
||||
pdf_file = io.BytesIO(pdf_data)
|
||||
pdf_file = io.BytesIO(file_content)
|
||||
pdf_reader = PyPDF2.PdfReader(pdf_file)
|
||||
for page in pdf_reader.pages:
|
||||
text_parts.append(page.extract_text())
|
||||
@@ -588,17 +590,11 @@ class MultimodalService:
|
||||
return f"[PDF 提取失败: {str(e)}]"
|
||||
|
||||
@staticmethod
|
||||
async def _extract_word_text(file_url: str) -> str:
|
||||
async def _extract_word_text(file_content: bytes) -> str:
|
||||
"""提取 Word 文档文本"""
|
||||
try:
|
||||
# 下载 Word 文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
word_data = response.content
|
||||
|
||||
# 使用 BytesIO 读取 Word 文档
|
||||
word_file = io.BytesIO(word_data)
|
||||
word_file = io.BytesIO(file_content)
|
||||
doc = Document(word_file)
|
||||
text_parts = [paragraph.text for paragraph in doc.paragraphs]
|
||||
return '\n'.join(text_parts)
|
||||
|
||||
@@ -120,7 +120,8 @@ async def run_pilot_extraction(
|
||||
"pruning_switch": memory_config.pruning_enabled,
|
||||
"pruning_scene": memory_config.pruning_scene,
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
"llm_model_id": str(memory_config.llm_model_id),
|
||||
"scene_id": str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||
"ontology_classes": memory_config.ontology_classes,
|
||||
}
|
||||
config = PruningConfig(**pruning_config_dict)
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ class ToolService:
|
||||
"""检查工具名称是否重复"""
|
||||
query = self.db.query(ToolConfig).filter(
|
||||
ToolConfig.name == name,
|
||||
ToolConfig.tool_type == tool_type.value,
|
||||
ToolConfig.tool_type == tool_type,
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
)
|
||||
if exclude_id:
|
||||
@@ -93,7 +93,44 @@ class ToolService:
|
||||
if query.first():
|
||||
raise BusinessException(f"工具名称 '{name}' 已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
def create_tool(
|
||||
def _check_mcp_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, config: Dict[str, Any]):
|
||||
"""检查MCP工具是否重复:市场来源按market_id+market_config_id+mcp_service_id判断(名称无关),自建按name+tool_type判断"""
|
||||
from app.models.tool_model import MCPSourceChannel
|
||||
source_channel = config.get("source_channel")
|
||||
is_market_source = (
|
||||
source_channel is not None
|
||||
and source_channel != MCPSourceChannel.SELF_HOSTED
|
||||
)
|
||||
if is_market_source:
|
||||
exists = (
|
||||
self.db.query(ToolConfig)
|
||||
.join(MCPToolConfig, MCPToolConfig.id == ToolConfig.id)
|
||||
.filter(
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.tool_type == tool_type,
|
||||
MCPToolConfig.source_channel == source_channel,
|
||||
MCPToolConfig.market_id == config.get("market_id"),
|
||||
MCPToolConfig.market_config_id == config.get("market_config_id"),
|
||||
MCPToolConfig.mcp_service_id == config.get("mcp_service_id"),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if exists:
|
||||
raise BusinessException(f"该MCP服务已添加", BizCode.DUPLICATE_NAME)
|
||||
else:
|
||||
exists = (
|
||||
self.db.query(ToolConfig)
|
||||
.filter(
|
||||
ToolConfig.name == name,
|
||||
ToolConfig.tool_type == tool_type,
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if exists:
|
||||
raise BusinessException(f"工具 '{name}' 已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
async def create_tool(
|
||||
self,
|
||||
name: str,
|
||||
tool_type: ToolType,
|
||||
@@ -106,7 +143,19 @@ class ToolService:
|
||||
"""创建工具"""
|
||||
if tool_type == ToolType.BUILTIN:
|
||||
raise ValueError("内置工具不允许创建")
|
||||
self._check_name_duplicate(name, tool_type, tenant_id)
|
||||
|
||||
cfg = config or {}
|
||||
if tool_type == ToolType.MCP:
|
||||
self._check_mcp_duplicate(name, tool_type, tenant_id, cfg)
|
||||
# 创建前测试连接
|
||||
test_result = await self._test_mcp_connection_by_config(cfg)
|
||||
if not test_result["success"]:
|
||||
raise BusinessException(f"MCP连接测试失败: {test_result['message']}", BizCode.INVALID_PARAMETER)
|
||||
# 将发现的工具列表写回 config
|
||||
if "available_tools" in test_result:
|
||||
cfg["available_tools"] = test_result["available_tools"]
|
||||
else:
|
||||
self._check_name_duplicate(name, tool_type, tenant_id)
|
||||
|
||||
try:
|
||||
# 创建基础配置
|
||||
@@ -117,19 +166,22 @@ class ToolService:
|
||||
tool_type=tool_type.value,
|
||||
tenant_id=tenant_id,
|
||||
status=ToolStatus.AVAILABLE.value,
|
||||
config_data=config or {},
|
||||
config_data=cfg,
|
||||
tags=tags
|
||||
)
|
||||
self.db.add(tool_config)
|
||||
self.db.flush()
|
||||
|
||||
# 创建类型特定配置
|
||||
self._create_type_config(tool_config, config or {})
|
||||
self._create_type_config(tool_config, cfg)
|
||||
|
||||
self.db.commit()
|
||||
logger.info(f"工具创建成功: {tool_config.id}")
|
||||
return str(tool_config.id)
|
||||
|
||||
except BusinessException:
|
||||
self.db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"创建工具失败: {e}")
|
||||
@@ -910,7 +962,11 @@ class ToolService:
|
||||
config_data.update({
|
||||
"last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None,
|
||||
"health_status": mcp_config.health_status,
|
||||
"available_tools": available_tools_display
|
||||
"available_tools": available_tools_display,
|
||||
"source_channel": mcp_config.source_channel,
|
||||
"market_id": mcp_config.market_id,
|
||||
"market_config_id": mcp_config.market_config_id,
|
||||
"mcp_service_id": mcp_config.mcp_service_id
|
||||
})
|
||||
|
||||
return ToolInfo(
|
||||
@@ -965,7 +1021,11 @@ class ToolService:
|
||||
id=tool_config.id,
|
||||
server_url=config.get("server_url"),
|
||||
connection_config=config.get("connection_config", {}),
|
||||
available_tools=config.get("available_tools", [])
|
||||
available_tools=config.get("available_tools", []),
|
||||
source_channel=config.get("source_channel", "self_hosted"),
|
||||
market_id=config.get("market_id"),
|
||||
market_config_id=config.get("market_config_id"),
|
||||
mcp_service_id=config.get("mcp_service_id"),
|
||||
)
|
||||
self.db.add(mcp_config)
|
||||
|
||||
@@ -1018,6 +1078,14 @@ class ToolService:
|
||||
mcp_config.server_url = config.get("server_url")
|
||||
mcp_config.connection_config = config.get("connection_config", {})
|
||||
mcp_config.available_tools = config.get("available_tools", [])
|
||||
if config.get("source_channel") is not None:
|
||||
mcp_config.source_channel = config.get("source_channel")
|
||||
if config.get("market_id") is not None:
|
||||
mcp_config.market_id = config.get("market_id")
|
||||
if config.get("market_config_id") is not None:
|
||||
mcp_config.market_config_id = config.get("market_config_id")
|
||||
if config.get("mcp_service_id") is not None:
|
||||
mcp_config.mcp_service_id = config.get("mcp_service_id")
|
||||
|
||||
@staticmethod
|
||||
def _determine_initial_status(tool_info: Dict[str, Any]) -> str:
|
||||
@@ -1149,6 +1217,27 @@ class ToolService:
|
||||
logger.error(f"加载内置工具配置失败: {e}")
|
||||
return {}
|
||||
|
||||
async def _test_mcp_connection_by_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""根据配置参数直接测试MCP连接(创建前调用,无需已存在的工具记录)"""
|
||||
server_url = config.get("server_url")
|
||||
if not server_url:
|
||||
return {"success": False, "message": "server_url不能为空"}
|
||||
connection_config = config.get("connection_config") or {}
|
||||
try:
|
||||
test_result = await self.mcp_tool_manager.test_tool_connection(server_url, connection_config)
|
||||
if not test_result["success"]:
|
||||
return test_result
|
||||
success_flag, tools, error = await self.mcp_tool_manager.discover_tools(server_url, connection_config)
|
||||
if not success_flag:
|
||||
return {"success": False, "message": f"获取工具列表失败: {error}"}
|
||||
tool_list = [
|
||||
{tool["name"]: {"description": tool.get("description", ""), "inputSchema": tool.get("inputSchema", {})}}
|
||||
for tool in tools if tool.get("name")
|
||||
]
|
||||
return {"success": True, "message": "MCP连接测试成功", "available_tools": tool_list}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"连接测试异常: {str(e)}"}
|
||||
|
||||
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
||||
"""测试MCP连接并自动同步工具列表"""
|
||||
try:
|
||||
|
||||
@@ -458,7 +458,7 @@ class WorkflowService:
|
||||
type=file.type,
|
||||
url=await self.multimodal_service.get_file_url(file),
|
||||
transfer_method=file.transfer_method,
|
||||
file_id=str(file.upload_file_id),
|
||||
file_id=str(file.upload_file_id) if file.upload_file_id else None,
|
||||
origin_file_type=file.file_type,
|
||||
is_file=True
|
||||
).model_dump()
|
||||
|
||||
@@ -2,11 +2,11 @@ import datetime
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from os import getenv
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config.default_ontology_initializer import DefaultOntologyInitializer
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException, PermissionDeniedException
|
||||
@@ -30,17 +30,15 @@ from app.schemas.workspace_schema import (
|
||||
WorkspaceModelsUpdate,
|
||||
WorkspaceUpdate,
|
||||
)
|
||||
from app.config.default_ontology_initializer import DefaultOntologyInitializer
|
||||
|
||||
# 获取业务逻辑专用日志器
|
||||
business_logger = get_business_logger()
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
def switch_workspace(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
):
|
||||
"""切换工作空间"""
|
||||
business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}")
|
||||
@@ -60,31 +58,32 @@ def switch_workspace(
|
||||
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
|
||||
|
||||
def delete_workspace_member(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
member_id: uuid.UUID,
|
||||
user: User,
|
||||
):
|
||||
"""删除工作空间成员"""
|
||||
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
|
||||
if not workspace_member:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND)
|
||||
def delete_workspace_member(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
member_id: uuid.UUID,
|
||||
user: User,
|
||||
):
|
||||
"""删除工作空间成员"""
|
||||
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
|
||||
if not workspace_member:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND)
|
||||
|
||||
if workspace_member.workspace_id != workspace_id:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_NOT_FOUND)
|
||||
if workspace_member.workspace_id != workspace_id:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}",
|
||||
BizCode.WORKSPACE_NOT_FOUND)
|
||||
|
||||
try:
|
||||
workspace_member.is_active = False
|
||||
workspace_member.user.current_workspace_id = None
|
||||
db.commit()
|
||||
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
||||
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
try:
|
||||
workspace_member.is_active = False
|
||||
workspace_member.user.current_workspace_id = None
|
||||
db.commit()
|
||||
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
||||
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
|
||||
|
||||
def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
|
||||
@@ -102,19 +101,19 @@ def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
|
||||
"""
|
||||
business_logger.debug(f"获取用户工作空间列表: {user.username} (ID: {user.id})")
|
||||
workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
|
||||
|
||||
|
||||
# Ensure each neo4j workspace has a default memory config
|
||||
for workspace in workspaces:
|
||||
if workspace.storage_type == 'neo4j':
|
||||
_ensure_default_memory_config(db, workspace)
|
||||
_ensure_default_ontology_scenes(db, workspace)
|
||||
|
||||
|
||||
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
|
||||
return workspaces
|
||||
|
||||
|
||||
def _create_workspace_only(
|
||||
db: Session, workspace: WorkspaceCreate, owner: User
|
||||
db: Session, workspace: WorkspaceCreate, owner: User
|
||||
) -> Workspace:
|
||||
business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}")
|
||||
|
||||
@@ -130,6 +129,7 @@ def _create_workspace_only(
|
||||
business_logger.error(f"创建工作空间失败: {workspace.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_workspace(
|
||||
db: Session, workspace: WorkspaceCreate, user: User, language: str = "zh"
|
||||
) -> Workspace:
|
||||
@@ -137,9 +137,14 @@ def create_workspace(
|
||||
f"创建工作空间: {workspace.name}, 创建者: {user.username}, "
|
||||
f"storage_type: {workspace.storage_type}"
|
||||
)
|
||||
llm=workspace.llm
|
||||
embedding=workspace.embedding
|
||||
rerank=workspace.rerank
|
||||
if workspace_repository.get_workspaces_by_name(db=db, name=workspace.name, tenant_id=user.tenant_id):
|
||||
raise BusinessException(
|
||||
message="同名工作空间已存在",
|
||||
code=BizCode.RESOURCE_ALREADY_EXISTS
|
||||
)
|
||||
llm = workspace.llm
|
||||
embedding = workspace.embedding
|
||||
rerank = workspace.rerank
|
||||
try:
|
||||
# Create the workspace without adding any members
|
||||
business_logger.debug(f"创建工作空间: {workspace.name}")
|
||||
@@ -158,26 +163,26 @@ def create_workspace(
|
||||
success, error_msg = initializer.initialize_default_scenes(
|
||||
db_workspace.id, language=language
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
business_logger.info(
|
||||
f"为工作空间 {db_workspace.id} 创建默认本体场景成功 (language={language})"
|
||||
)
|
||||
|
||||
# 获取默认场景ID,优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景
|
||||
|
||||
# 获取默认场景ID,优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景
|
||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||
from app.config.default_ontology_config import (
|
||||
ONLINE_EDUCATION_SCENE,
|
||||
ONLINE_EDUCATION_SCENE,
|
||||
EMOTIONAL_COMPANION_SCENE,
|
||||
get_scene_name
|
||||
)
|
||||
|
||||
|
||||
scene_repo = OntologySceneRepository(db)
|
||||
|
||||
|
||||
# 优先尝试获取教育场景
|
||||
education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language)
|
||||
education_scene = scene_repo.get_by_name(education_scene_name, db_workspace.id)
|
||||
|
||||
|
||||
if education_scene:
|
||||
default_scene_id = education_scene.scene_id
|
||||
default_scene_name = education_scene.scene_name
|
||||
@@ -188,7 +193,7 @@ def create_workspace(
|
||||
# 如果教育场景不存在,尝试获取情感陪伴场景
|
||||
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
|
||||
companion_scene = scene_repo.get_by_name(companion_scene_name, db_workspace.id)
|
||||
|
||||
|
||||
if companion_scene:
|
||||
default_scene_id = companion_scene.scene_id
|
||||
default_scene_name = companion_scene.scene_name
|
||||
@@ -255,10 +260,10 @@ def create_workspace(
|
||||
avatar='',
|
||||
type=KnowledgeType.General,
|
||||
permission_id=PermissionType.Memory,
|
||||
embedding_id=uuid.UUID(getenv('KB_embedding_id')) if None else embedding,
|
||||
reranker_id=uuid.UUID(getenv('KB_reranker_id')) if None else rerank,
|
||||
llm_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
|
||||
image2text_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
|
||||
embedding_id=embedding,
|
||||
reranker_id=rerank,
|
||||
llm_id=llm,
|
||||
image2text_id=llm,
|
||||
parser_config={
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 256,
|
||||
@@ -293,7 +298,7 @@ def create_workspace(
|
||||
business_logger.info(
|
||||
f"工作空间 {db_workspace.id} 及相关资源创建完成并已提交"
|
||||
)
|
||||
|
||||
|
||||
return db_workspace
|
||||
|
||||
except Exception as e:
|
||||
@@ -303,11 +308,11 @@ def create_workspace(
|
||||
|
||||
|
||||
def update_workspace(
|
||||
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
|
||||
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
|
||||
) -> Workspace:
|
||||
business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
|
||||
db_workspace = _check_workspace_admin_permission(db,workspace_id,user)
|
||||
db_workspace = _check_workspace_admin_permission(db, workspace_id, user)
|
||||
try:
|
||||
# 更新工作空间
|
||||
business_logger.debug(f"执行工作空间更新: {db_workspace.name} (ID: {workspace_id})")
|
||||
@@ -327,7 +332,7 @@ def update_workspace(
|
||||
|
||||
|
||||
def get_workspace_members(
|
||||
db: Session, workspace_id: uuid.UUID, user: User
|
||||
db: Session, workspace_id: uuid.UUID, user: User
|
||||
) -> List[WorkspaceMember]:
|
||||
"""获取某工作空间的成员列表(关系序列化由模型关系支持)"""
|
||||
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
@@ -371,7 +376,6 @@ def get_workspace_members(
|
||||
return members
|
||||
|
||||
|
||||
|
||||
# ==================== 邀请相关服务方法 ====================
|
||||
|
||||
def _generate_invite_token() -> tuple[str, str]:
|
||||
@@ -464,13 +468,14 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
|
||||
|
||||
|
||||
def create_workspace_invite(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
user: User
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
user: User
|
||||
) -> WorkspaceInviteResponse:
|
||||
"""创建工作空间邀请"""
|
||||
business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
|
||||
business_logger.info(
|
||||
f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
|
||||
|
||||
try:
|
||||
# 检查权限
|
||||
@@ -533,17 +538,18 @@ def create_workspace_invite(
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
|
||||
business_logger.error(
|
||||
f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_invites(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
status: Optional[InviteStatus] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
status: Optional[InviteStatus] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[WorkspaceInviteResponse]:
|
||||
"""获取工作空间邀请列表"""
|
||||
business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
@@ -604,9 +610,9 @@ def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
|
||||
|
||||
|
||||
def accept_workspace_invite(
|
||||
db: Session,
|
||||
accept_request: InviteAcceptRequest,
|
||||
user: User
|
||||
db: Session,
|
||||
accept_request: InviteAcceptRequest,
|
||||
user: User
|
||||
) -> dict:
|
||||
"""接受工作空间邀请"""
|
||||
business_logger.info(f"接受工作空间邀请: 用户 {user.username}")
|
||||
@@ -694,7 +700,8 @@ def accept_workspace_invite(
|
||||
# 获取工作空间信息
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
||||
|
||||
business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
|
||||
business_logger.info(
|
||||
f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
|
||||
|
||||
return {
|
||||
"message": "Successfully joined the workspace",
|
||||
@@ -709,13 +716,14 @@ def accept_workspace_invite(
|
||||
|
||||
|
||||
def revoke_workspace_invite(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_id: uuid.UUID,
|
||||
user: User
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_id: uuid.UUID,
|
||||
user: User
|
||||
) -> dict:
|
||||
"""撤销工作空间邀请"""
|
||||
business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
|
||||
business_logger.info(
|
||||
f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
|
||||
|
||||
try:
|
||||
# 检查权限
|
||||
@@ -744,13 +752,14 @@ def revoke_workspace_invite(
|
||||
|
||||
|
||||
def update_workspace_member_roles(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
updates: List[WorkspaceMemberUpdate],
|
||||
user: User,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
updates: List[WorkspaceMemberUpdate],
|
||||
user: User,
|
||||
) -> List[WorkspaceMember]:
|
||||
"""更新工作空间成员角色"""
|
||||
business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
|
||||
business_logger.info(
|
||||
f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
|
||||
|
||||
# 检查管理员权限
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
@@ -764,7 +773,8 @@ def update_workspace_member_roles(
|
||||
for upd in updates:
|
||||
# 检查成员是否存在
|
||||
if upd.id not in member_map:
|
||||
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}",
|
||||
BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
|
||||
member = member_map[upd.id]
|
||||
|
||||
@@ -916,10 +926,10 @@ def get_workspace_models_configs(
|
||||
|
||||
|
||||
def update_workspace_models_configs(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
models_update: WorkspaceModelsUpdate,
|
||||
user: User,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
models_update: WorkspaceModelsUpdate,
|
||||
user: User,
|
||||
) -> Workspace:
|
||||
"""更新工作空间的模型配置(llm, embedding, rerank)
|
||||
|
||||
@@ -966,6 +976,126 @@ def update_workspace_models_configs(
|
||||
raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
|
||||
|
||||
def _fill_workspace_configs_model_defaults(
|
||||
db: Session,
|
||||
workspace: Workspace
|
||||
) -> None:
|
||||
"""Fill empty model fields for all memory configs in a workspace.
|
||||
|
||||
Updates llm_id, embedding_id, rerank_id, reflection_model_id, and emotion_model_id
|
||||
if they are None, using the corresponding workspace default models.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
workspace: The workspace containing default model settings
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
# Get all configs for this workspace
|
||||
configs = db.query(MemoryConfig).filter(
|
||||
MemoryConfig.workspace_id == workspace.id
|
||||
).all()
|
||||
|
||||
if not configs:
|
||||
return
|
||||
|
||||
# Map of memory_config field -> workspace field
|
||||
model_field_mappings = [
|
||||
("llm_id", "llm"),
|
||||
("embedding_id", "embedding"),
|
||||
("rerank_id", "rerank"),
|
||||
("reflection_model_id", "llm"), # reflection uses LLM
|
||||
("emotion_model_id", "llm"), # emotion uses LLM
|
||||
]
|
||||
|
||||
configs_updated = 0
|
||||
|
||||
for memory_config in configs:
|
||||
updated_fields = []
|
||||
|
||||
for config_field, workspace_field in model_field_mappings:
|
||||
config_value = getattr(memory_config, config_field, None)
|
||||
workspace_value = getattr(workspace, workspace_field, None)
|
||||
|
||||
if not config_value and workspace_value:
|
||||
setattr(memory_config, config_field, workspace_value)
|
||||
updated_fields.append(config_field)
|
||||
|
||||
if updated_fields:
|
||||
configs_updated += 1
|
||||
business_logger.debug(
|
||||
f"Updated memory config {memory_config.config_id} fields: {updated_fields}"
|
||||
)
|
||||
|
||||
if configs_updated > 0:
|
||||
try:
|
||||
db.commit()
|
||||
business_logger.info(
|
||||
f"Updated {configs_updated} memory configs in workspace {workspace.id} with default models"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(
|
||||
f"Failed to update memory configs in workspace {workspace.id}: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def _create_default_memory_config(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
workspace_name: str,
|
||||
llm_id: Optional[uuid.UUID] = None,
|
||||
embedding_id: Optional[uuid.UUID] = None,
|
||||
rerank_id: Optional[uuid.UUID] = None,
|
||||
scene_id: Optional[uuid.UUID] = None,
|
||||
pruning_scene_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Create a default memory config for a newly created workspace.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
workspace_id: The workspace ID
|
||||
workspace_name: The workspace name (used for config naming)
|
||||
llm_id: Optional LLM model ID
|
||||
embedding_id: Optional embedding model ID
|
||||
rerank_id: Optional rerank model ID
|
||||
scene_id: Optional ontology scene ID (默认关联教育场景)
|
||||
pruning_scene_name: Optional pruning scene name,取自 ontology_scene.scene_name
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
config_id = uuid.uuid4()
|
||||
|
||||
default_config = MemoryConfig(
|
||||
config_id=config_id,
|
||||
config_name=f"{workspace_name} 默认配置",
|
||||
config_desc="工作空间创建时自动生成的默认记忆配置",
|
||||
workspace_id=workspace_id,
|
||||
llm_id=str(llm_id) if llm_id else None,
|
||||
embedding_id=str(embedding_id) if embedding_id else None,
|
||||
rerank_id=str(rerank_id) if rerank_id else None,
|
||||
scene_id=scene_id, # 关联本体场景ID(默认为"在线教育"场景)
|
||||
pruning_scene=pruning_scene_name, # 语义剪枝场景直接使用 scene_name
|
||||
state=True, # Active by default
|
||||
is_default=True, # Mark as workspace default
|
||||
)
|
||||
|
||||
db.add(default_config)
|
||||
db.flush() # 使用 flush 而不是 commit,让调用者统一提交
|
||||
|
||||
business_logger.info(
|
||||
"Created default memory config for workspace",
|
||||
extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"config_id": str(config_id),
|
||||
"config_name": default_config.config_name,
|
||||
"scene_id": str(scene_id) if scene_id else None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ==================== 检查配置相关服务 ====================
|
||||
|
||||
def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
"""Ensure a workspace has a default memory config, creating one if missing.
|
||||
|
||||
@@ -976,19 +1106,19 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
workspace: The workspace to check
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
|
||||
# Check if default config exists for this workspace
|
||||
existing_default = db.query(MemoryConfig).filter(
|
||||
MemoryConfig.workspace_id == workspace.id,
|
||||
MemoryConfig.is_default == True
|
||||
).first()
|
||||
|
||||
|
||||
if not existing_default:
|
||||
# No default config exists, create one
|
||||
business_logger.info(
|
||||
f"Workspace {workspace.id} missing default memory config, creating one"
|
||||
)
|
||||
|
||||
|
||||
# 尝试获取默认场景ID,优先教育场景,其次情感陪伴场景
|
||||
default_scene_id = None
|
||||
try:
|
||||
@@ -998,7 +1128,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
EMOTIONAL_COMPANION_SCENE,
|
||||
get_scene_name
|
||||
)
|
||||
|
||||
|
||||
scene_repo = OntologySceneRepository(db)
|
||||
# 尝试中文和英文场景名称
|
||||
for language in ["zh", "en"]:
|
||||
@@ -1011,7 +1141,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
f"找到教育场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={education_scene_name}"
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
# 如果教育场景不存在,尝试情感陪伴场景
|
||||
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
|
||||
companion_scene = scene_repo.get_by_name(companion_scene_name, workspace.id)
|
||||
@@ -1025,7 +1155,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
business_logger.warning(
|
||||
f"获取默认场景失败,将创建不关联场景的记忆配置: {str(scene_error)}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
_create_default_memory_config(
|
||||
db=db,
|
||||
@@ -1040,75 +1170,11 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||
business_logger.error(
|
||||
f"Failed to create default memory config for workspace {workspace.id}: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# Fill empty model fields for ALL configs in this workspace
|
||||
_fill_workspace_configs_model_defaults(db, workspace)
|
||||
|
||||
|
||||
def _fill_workspace_configs_model_defaults(
|
||||
db: Session,
|
||||
workspace: Workspace
|
||||
) -> None:
|
||||
"""Fill empty model fields for all memory configs in a workspace.
|
||||
|
||||
Updates llm_id, embedding_id, rerank_id, reflection_model_id, and emotion_model_id
|
||||
if they are None, using the corresponding workspace default models.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
workspace: The workspace containing default model settings
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
# Get all configs for this workspace
|
||||
configs = db.query(MemoryConfig).filter(
|
||||
MemoryConfig.workspace_id == workspace.id
|
||||
).all()
|
||||
|
||||
if not configs:
|
||||
return
|
||||
|
||||
# Map of memory_config field -> workspace field
|
||||
model_field_mappings = [
|
||||
("llm_id", "llm"),
|
||||
("embedding_id", "embedding"),
|
||||
("rerank_id", "rerank"),
|
||||
("reflection_model_id", "llm"), # reflection uses LLM
|
||||
("emotion_model_id", "llm"), # emotion uses LLM
|
||||
]
|
||||
|
||||
configs_updated = 0
|
||||
|
||||
for memory_config in configs:
|
||||
updated_fields = []
|
||||
|
||||
for config_field, workspace_field in model_field_mappings:
|
||||
config_value = getattr(memory_config, config_field, None)
|
||||
workspace_value = getattr(workspace, workspace_field, None)
|
||||
|
||||
if not config_value and workspace_value:
|
||||
setattr(memory_config, config_field, workspace_value)
|
||||
updated_fields.append(config_field)
|
||||
|
||||
if updated_fields:
|
||||
configs_updated += 1
|
||||
business_logger.debug(
|
||||
f"Updated memory config {memory_config.config_id} fields: {updated_fields}"
|
||||
)
|
||||
|
||||
if configs_updated > 0:
|
||||
try:
|
||||
db.commit()
|
||||
business_logger.info(
|
||||
f"Updated {configs_updated} memory configs in workspace {workspace.id} with default models"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(
|
||||
f"Failed to update memory configs in workspace {workspace.id}: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
|
||||
"""Ensure a workspace has default ontology scenes, creating them if missing.
|
||||
|
||||
@@ -1153,57 +1219,3 @@ def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
|
||||
business_logger.error(
|
||||
f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def _create_default_memory_config(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
workspace_name: str,
|
||||
llm_id: Optional[uuid.UUID] = None,
|
||||
embedding_id: Optional[uuid.UUID] = None,
|
||||
rerank_id: Optional[uuid.UUID] = None,
|
||||
scene_id: Optional[uuid.UUID] = None,
|
||||
pruning_scene_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Create a default memory config for a newly created workspace.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
workspace_id: The workspace ID
|
||||
workspace_name: The workspace name (used for config naming)
|
||||
llm_id: Optional LLM model ID
|
||||
embedding_id: Optional embedding model ID
|
||||
rerank_id: Optional rerank model ID
|
||||
scene_id: Optional ontology scene ID (默认关联教育场景)
|
||||
pruning_scene_name: Optional pruning scene name,取自 ontology_scene.scene_name
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
config_id = uuid.uuid4()
|
||||
|
||||
default_config = MemoryConfig(
|
||||
config_id=config_id,
|
||||
config_name=f"{workspace_name} 默认配置",
|
||||
config_desc="工作空间创建时自动生成的默认记忆配置",
|
||||
workspace_id=workspace_id,
|
||||
llm_id=str(llm_id) if llm_id else None,
|
||||
embedding_id=str(embedding_id) if embedding_id else None,
|
||||
rerank_id=str(rerank_id) if rerank_id else None,
|
||||
scene_id=scene_id, # 关联本体场景ID(默认为"在线教育"场景)
|
||||
pruning_scene=pruning_scene_name, # 语义剪枝场景直接使用 scene_name
|
||||
state=True, # Active by default
|
||||
is_default=True, # Mark as workspace default
|
||||
)
|
||||
|
||||
db.add(default_config)
|
||||
db.flush() # 使用 flush 而不是 commit,让调用者统一提交
|
||||
|
||||
business_logger.info(
|
||||
"Created default memory config for workspace",
|
||||
extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"config_id": str(config_id),
|
||||
"config_name": default_config.config_name,
|
||||
"scene_id": str(scene_id) if scene_id else None,
|
||||
}
|
||||
)
|
||||
|
||||
376
api/app/tasks.py
376
api/app/tasks.py
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -14,6 +15,62 @@ from uuid import UUID
|
||||
|
||||
import redis
|
||||
import requests
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 模块级同步 Redis 连接池,供 Celery 任务共享使用
|
||||
# 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致
|
||||
# 使用连接池而非单例客户端,提供更好的并发性能和自动重连
|
||||
_sync_redis_pool: redis.ConnectionPool = None
|
||||
|
||||
def _get_or_create_redis_pool() -> redis.ConnectionPool:
|
||||
"""获取或创建 Redis 连接池(懒初始化)"""
|
||||
global _sync_redis_pool
|
||||
if _sync_redis_pool is None:
|
||||
try:
|
||||
_sync_redis_pool = redis.ConnectionPool(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB_CELERY_BACKEND,
|
||||
password=settings.REDIS_PASSWORD,
|
||||
decode_responses=True,
|
||||
max_connections=10,
|
||||
socket_connect_timeout=5,
|
||||
socket_timeout=5,
|
||||
retry_on_timeout=True,
|
||||
health_check_interval=30,
|
||||
)
|
||||
logger.info("Redis connection pool created for Celery tasks")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Redis connection pool: {e}", exc_info=True)
|
||||
return None
|
||||
return _sync_redis_pool
|
||||
|
||||
def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
"""获取同步 Redis 客户端(使用连接池)
|
||||
|
||||
使用连接池提供的客户端,支持自动重连和健康检查。
|
||||
如果 Redis 不可用,返回 None,调用方应优雅降级。
|
||||
|
||||
Returns:
|
||||
redis.StrictRedis: Redis 客户端实例,如果连接失败则返回 None
|
||||
"""
|
||||
try:
|
||||
pool = _get_or_create_redis_pool()
|
||||
if pool is None:
|
||||
return None
|
||||
|
||||
client = redis.StrictRedis(connection_pool=pool)
|
||||
# 验证连接可用性
|
||||
client.ping()
|
||||
return client
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis connection failed: {e}", exc_info=True)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
# Import a unified Celery instance
|
||||
from app.celery_app import celery_app
|
||||
@@ -1090,6 +1147,22 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||
|
||||
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
|
||||
try:
|
||||
_r = get_sync_redis_client()
|
||||
if _r is not None:
|
||||
from datetime import timedelta as _td
|
||||
from datetime import timezone as _tz
|
||||
_CST = _tz(_td(hours=8))
|
||||
_now_cst = datetime.now(_CST).replace(tzinfo=None).isoformat()
|
||||
_r.set(
|
||||
f"write_message:last_done:{end_user_id}",
|
||||
_now_cst,
|
||||
ex=86400 * 30,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}")
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"result": result,
|
||||
@@ -2149,12 +2222,16 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from sqlalchemy import select, func
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from app.repositories.implicit_emotions_storage_repository import (
|
||||
ImplicitEmotionsStorageRepository,
|
||||
TimeFilterUnavailableError,
|
||||
)
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
|
||||
@@ -2167,18 +2244,27 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有已存储数据的用户ID(分批次处理)
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
|
||||
|
||||
# 先统计总数用于日志
|
||||
from sqlalchemy import func
|
||||
total_users = db.execute(
|
||||
select(func.count()).select_from(ImplicitEmotionsStorage)
|
||||
).scalar() or 0
|
||||
logger.info(f"找到 {total_users} 个需要更新的用户")
|
||||
logger.info(f"表中存量用户总数: {total_users},开始时间轴筛选")
|
||||
|
||||
# 遍历每个用户并更新数据(分批次,避免一次性加载所有ID)
|
||||
for end_user_id in repo.get_all_user_ids(batch_size=100):
|
||||
# 构建 Redis 同步客户端,用于时间轴筛选
|
||||
_redis_client = get_sync_redis_client()
|
||||
|
||||
# 只处理 last_done > updated_at 的用户(有新记忆写入的用户)
|
||||
# Redis 不可用时回退到全量处理
|
||||
try:
|
||||
refresh_iter = repo.get_users_needing_refresh(_redis_client, batch_size=100)
|
||||
except TimeFilterUnavailableError as e:
|
||||
logger.warning(f"时间轴筛选不可用,回退到全量刷新: {e}")
|
||||
refresh_iter = repo.get_all_user_ids(batch_size=100)
|
||||
|
||||
for end_user_id in refresh_iter:
|
||||
logger.info(f"开始处理用户: {end_user_id}")
|
||||
user_start_time = time.time()
|
||||
|
||||
@@ -2264,10 +2350,10 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
user_results.append(error_info)
|
||||
logger.error(f"处理用户 {end_user_id} 时出错: {str(e)}")
|
||||
|
||||
# ---- 处理增量用户(当天新增、尚未初始化的用户)----
|
||||
# ---- 当天新增用户兜底初始化 ----
|
||||
new_users_initialized = 0
|
||||
new_users_failed = 0
|
||||
logger.info("开始处理当天新增的增量用户初始化")
|
||||
logger.info("开始处理当天新增用户的兜底初始化")
|
||||
|
||||
for end_user_id in repo.get_new_user_ids_today(batch_size=100):
|
||||
logger.info(f"开始初始化新用户: {end_user_id}")
|
||||
@@ -2281,35 +2367,27 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
|
||||
await implicit_service.save_profile_cache(
|
||||
end_user_id=end_user_id,
|
||||
profile_data=profile_data,
|
||||
db=db
|
||||
end_user_id=end_user_id, profile_data=profile_data, db=db
|
||||
)
|
||||
implicit_success = True
|
||||
logger.info(f"成功初始化新用户 {end_user_id} 的隐性记忆画像")
|
||||
except Exception as e:
|
||||
error_msg = f"隐性记忆初始化失败: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(f"新用户 {end_user_id} {error_msg}")
|
||||
errors.append(f"隐性记忆初始化失败: {str(e)}")
|
||||
logger.error(f"新用户 {end_user_id} 隐性记忆初始化失败: {e}")
|
||||
|
||||
try:
|
||||
emotion_service = EmotionAnalyticsService()
|
||||
suggestions_data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=end_user_id,
|
||||
db=db,
|
||||
language="zh"
|
||||
end_user_id=end_user_id, db=db, language="zh"
|
||||
)
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=end_user_id,
|
||||
suggestions_data=suggestions_data,
|
||||
db=db
|
||||
end_user_id=end_user_id, suggestions_data=suggestions_data, db=db
|
||||
)
|
||||
emotion_success = True
|
||||
logger.info(f"成功初始化新用户 {end_user_id} 的情绪建议")
|
||||
except Exception as e:
|
||||
error_msg = f"情绪建议初始化失败: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(f"新用户 {end_user_id} {error_msg}")
|
||||
errors.append(f"情绪建议初始化失败: {str(e)}")
|
||||
logger.error(f"新用户 {end_user_id} 情绪建议初始化失败: {e}")
|
||||
|
||||
if implicit_success or emotion_success:
|
||||
new_users_initialized += 1
|
||||
@@ -2319,7 +2397,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
user_elapsed = time.time() - user_start_time
|
||||
user_results.append({
|
||||
"end_user_id": end_user_id,
|
||||
"type": "init",
|
||||
"type": "new_user_init",
|
||||
"implicit_success": implicit_success,
|
||||
"emotion_success": emotion_success,
|
||||
"errors": errors,
|
||||
@@ -2331,7 +2409,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
user_elapsed = time.time() - user_start_time
|
||||
user_results.append({
|
||||
"end_user_id": end_user_id,
|
||||
"type": "init",
|
||||
"type": "new_user_init",
|
||||
"implicit_success": False,
|
||||
"emotion_success": False,
|
||||
"errors": [str(e)],
|
||||
@@ -2339,27 +2417,24 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
})
|
||||
logger.error(f"初始化新用户 {end_user_id} 时出错: {str(e)}")
|
||||
|
||||
logger.info(
|
||||
f"增量用户初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}"
|
||||
)
|
||||
# ---- 增量用户处理结束 ----
|
||||
logger.info(f"当天新增用户兜底初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}")
|
||||
# ---- 新增用户兜底初始化结束 ----
|
||||
|
||||
# 记录总体统计信息
|
||||
logger.info(
|
||||
f"隐性记忆和情绪数据更新定时任务完成: "
|
||||
f"存量用户总数={total_users}, "
|
||||
f"隐性记忆成功={successful_implicit}, "
|
||||
f"情绪建议成功={successful_emotion}, "
|
||||
f"存量失败={failed}, "
|
||||
f"增量初始化成功={new_users_initialized}, "
|
||||
f"增量初始化失败={new_users_failed}"
|
||||
f"新增用户初始化成功={new_users_initialized}, "
|
||||
f"新增用户初始化失败={new_users_failed}"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": (
|
||||
f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;"
|
||||
f"增量新用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败"
|
||||
f"当天新增用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败"
|
||||
),
|
||||
"total_users": total_users,
|
||||
"successful_implicit": successful_implicit,
|
||||
@@ -2367,7 +2442,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
"failed": failed,
|
||||
"new_users_initialized": new_users_initialized,
|
||||
"new_users_failed": new_users_failed,
|
||||
"user_results": user_results[:50] # 只保留前50个用户的详细结果
|
||||
"user_results": user_results[:50]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -2416,3 +2491,232 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.init_implicit_emotions_for_users",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
# 触发型任务标识,区别于 periodic_tasks 队列中的定时任务
|
||||
triggered=True,
|
||||
)
|
||||
def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
|
||||
"""事件触发任务:对指定用户列表做存在性检查,无记录则执行首次初始化。
|
||||
|
||||
由 /dashboard/end_users 接口触发,已有数据的用户直接跳过。
|
||||
存量用户的数据刷新由定时任务 update_implicit_emotions_storage 负责。
|
||||
|
||||
Args:
|
||||
end_user_ids: 需要检查的用户ID列表
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.implicit_emotions_storage_repository import (
|
||||
ImplicitEmotionsStorageRepository,
|
||||
)
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
failed = 0
|
||||
skipped = 0
|
||||
|
||||
with get_db_context() as db:
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
|
||||
for end_user_id in end_user_ids:
|
||||
existing = repo.get_by_end_user_id(end_user_id)
|
||||
if existing is not None:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
logger.info(f"用户 {end_user_id} 无记录,开始初始化")
|
||||
implicit_ok = False
|
||||
emotion_ok = False
|
||||
try:
|
||||
try:
|
||||
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
|
||||
await implicit_service.save_profile_cache(
|
||||
end_user_id=end_user_id, profile_data=profile_data, db=db
|
||||
)
|
||||
implicit_ok = True
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {end_user_id} 隐性记忆初始化失败: {e}")
|
||||
|
||||
try:
|
||||
emotion_service = EmotionAnalyticsService()
|
||||
suggestions_data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=end_user_id, db=db, language="zh"
|
||||
)
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=end_user_id, suggestions_data=suggestions_data, db=db
|
||||
)
|
||||
emotion_ok = True
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {end_user_id} 情绪建议初始化失败: {e}")
|
||||
|
||||
if implicit_ok or emotion_ok:
|
||||
initialized += 1
|
||||
else:
|
||||
failed += 1
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
logger.error(f"用户 {end_user_id} 初始化异常: {e}")
|
||||
|
||||
logger.info(f"按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"initialized": initialized,
|
||||
"skipped": skipped,
|
||||
"failed": failed,
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
result["task_id"] = self.request.id
|
||||
return result
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.init_interest_distribution_for_users",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
)
|
||||
def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
|
||||
"""事件触发任务:检查指定用户列表的兴趣分布缓存,无缓存则生成并写入 Redis。
|
||||
|
||||
由 /dashboard/end_users 接口触发,已有缓存的用户直接跳过。
|
||||
默认生成中文(zh)兴趣分布数据。
|
||||
|
||||
Args:
|
||||
end_user_ids: 需要检查的用户ID列表
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
failed = 0
|
||||
skipped = 0
|
||||
language = "zh"
|
||||
|
||||
service = MemoryAgentService()
|
||||
|
||||
with get_db_context() as db:
|
||||
for end_user_id in end_user_ids:
|
||||
# 存在性检查:缓存有数据则跳过
|
||||
cached = await InterestMemoryCache.get_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
)
|
||||
if cached is not None:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成")
|
||||
try:
|
||||
result = await service.get_interest_distribution_by_user(
|
||||
end_user_id=end_user_id,
|
||||
limit=5,
|
||||
language=language,
|
||||
)
|
||||
await InterestMemoryCache.set_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
data=result,
|
||||
expire=INTEREST_CACHE_EXPIRE,
|
||||
)
|
||||
initialized += 1
|
||||
logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功")
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}")
|
||||
|
||||
logger.info(f"兴趣分布按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"initialized": initialized,
|
||||
"skipped": skipped,
|
||||
"failed": failed,
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
result["task_id"] = self.request.id
|
||||
return result
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
@@ -1,4 +1,38 @@
|
||||
{
|
||||
"v0.2.7": {
|
||||
"introduction": {
|
||||
"codeName": "武陵",
|
||||
"releaseDate": "2026-3-13",
|
||||
"upgradePosition": "🐻 应用可移植性、工具生态扩展与记忆智能精细化",
|
||||
"coreUpgrades": [
|
||||
"1. 应用管理与可移植性<br>* 应用导入/导出:全面支持 Agent 配置和工作流定义的导入导出,实现跨环境无缝迁移、备份和共享",
|
||||
"2. 工具生态扩展 🔌<br>* MCP 广场集成:工具管理接入 MCP 广场,提供集中式工具发现、浏览和集成枢纽",
|
||||
"3. 工作流增强 📝<br>* 备注节点:新增备注节点类型,支持工作流图中的内联文档和上下文说明,提升协作效率",
|
||||
"4. 记忆智能精细化 🧠<br>* 隐性记忆与情绪记忆生成逻辑优化:含数据存在性校验、时间轴筛选和兴趣分布缓存校验<br>* 兴趣分布生成逻辑改进:优化算法产生更准确的用户兴趣画像",
|
||||
"5. 用户体验改进 🎨<br>* 知识库分享加载状态:增加加载指示器,改善感知响应速度",
|
||||
"6. 稳健性与缺陷修复 🔧<br>* 应用调试终端用户管理:修复调试会话错误创建 end_user 记录问题<br>* 知识库数据集创建流程:解决创建数据集后无法进入下一步的缺陷<br>* RAG 空间记忆生成失败:修复记忆生成失败和存储中断的关键问题<br>* 应用字符限制强制执行:增加条件校验防止过长输入<br>* 语义剪枝情绪/兴趣保留:优化剪枝逻辑防止误删情绪和兴趣片段<br>* 语义剪枝效果优化:增强算法平衡记忆压缩与信息保留",
|
||||
"<br>",
|
||||
"v0.2.8 及更远的未来将引入多模态记忆能力,实现知识库和模型的分服务部署,为应用增加语音输入支持,并扩展应用能力至语音回复、BI 可视化、PPT 生成和直接生图。应用会话分享和联网搜索功能将得到修复和增强。记忆检索基准测试和情景记忆聚类算法将增强上下文召回和时序推理能力。通往真正智能、多模态、上下文感知应用的旅程仍在继续。",
|
||||
"记忆熊,智慧致远 🐻✨"
|
||||
]
|
||||
},
|
||||
"introduction_en": {
|
||||
"codeName": "WuLing",
|
||||
"releaseDate": "2026-3-13",
|
||||
"upgradePosition": "🐻 Application portability, tool ecosystem expansion, and memory intelligence refinement",
|
||||
"coreUpgrades": [
|
||||
"1. Application Management & Portability<br>* Application Import/Export: Full support for importing and exporting agent configurations and workflow definitions, enabling seamless cross-environment migration, backup, and sharing",
|
||||
"2. Tool Ecosystem Expansion 🔌<br>* MCP Marketplace Integration: Tool management now includes MCP Marketplace access for centralized tool discovery, browsing, and integration",
|
||||
"3. Workflow Enhancements 📝<br>* Annotation Node: Introduced annotation node type for inline documentation and contextual notes within workflow graphs, improving collaboration",
|
||||
"4. Memory Intelligence Refinement 🧠<br>* Implicit & Emotional Memory Generation Logic: Comprehensive optimization including data existence validation, timeline filtering, and interest distribution cache validation<br>* Interest Distribution Generation Logic: Refined algorithm for more accurate user interest profiles",
|
||||
"5. User Experience Improvements 🎨<br>* Knowledge Base Sharing Loading State: Added loading indicators to improve perceived responsiveness",
|
||||
"6. Robustness & Bug Fixes 🔧<br>* End User Management in App Debugging: Fixed incorrect end_user record creation during debugging sessions<br>* Knowledge Base Dataset Creation Flow: Resolved bug preventing next step after dataset creation<br>* RAG Space Memory Generation Failure: Fixed critical memory generation and storage interruption issue<br>* Application Character Limit Enforcement: Added conditional validation to prevent excessively long input<br>* Semantic Pruning Emotion/Interest Preservation: Optimized pruning logic to prevent incorrect deletion of emotional and interest fragments<br>* Semantic Pruning Effectiveness: Enhanced algorithm balance between memory compression and information retention",
|
||||
"<br>",
|
||||
"Looking forward to v0.2.8 and beyond, we will introduce multimodal memory capabilities with distributed service deployment for knowledge bases and models, enabling voice input for applications and expanding application capabilities with voice responses, BI visualizations, PPT generation, and direct image creation. Application conversation sharing and web search functionality will be restored and enhanced. Memory retrieval benchmarking and episodic memory clustering algorithms will enhance contextual recall and temporal reasoning. The journey toward truly intelligent, multimodal, context-aware applications continues.",
|
||||
"MemoryBear, Wisdom Reaching Far 🐻✨"
|
||||
]
|
||||
}
|
||||
},
|
||||
"v0.2.6": {
|
||||
"introduction": {
|
||||
"codeName": "听剑",
|
||||
|
||||
@@ -49,7 +49,7 @@ services:
|
||||
networks:
|
||||
- celery
|
||||
|
||||
# Periodic worker - Scheduled/beat tasks (prefork, low concurrency)
|
||||
# Periodic worker - Scheduled/beat tasks + API-triggered tasks (prefork, low concurrency)
|
||||
worker-periodic:
|
||||
image: redbear-mem-open:latest
|
||||
container_name: worker-periodic
|
||||
|
||||
@@ -75,7 +75,7 @@ REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
ENABLE_SINGLE_SESSION=
|
||||
|
||||
# File Upload
|
||||
MAX_FILE_SIZE=52428800 # 50MB:10 * 1024 * 1024
|
||||
MAX_FILE_SIZE=52428800 # 50MB:50 * 1024 * 1024
|
||||
FILE_PATH=/files
|
||||
|
||||
FILE_LOCAL_SERVER_URL="http://localhost:8000/api"
|
||||
|
||||
36
api/migrations/versions/1ac07dc7366f_202603061644.py
Normal file
36
api/migrations/versions/1ac07dc7366f_202603061644.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""202603061644
|
||||
|
||||
Revision ID: 1ac07dc7366f
|
||||
Revises: 6a4641cf192b
|
||||
Create Date: 2026-03-06 16:51:10.152305
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1ac07dc7366f'
|
||||
down_revision: Union[str, None] = '6a4641cf192b'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('mcp_tool_configs', sa.Column('source_channel', sa.String(length=50), server_default=sa.text("'self_hosted'"), nullable=False, comment='来源渠道'))
|
||||
op.add_column('mcp_tool_configs', sa.Column('market_id', sa.UUID(), nullable=True, comment='渠道市场id'))
|
||||
op.add_column('mcp_tool_configs', sa.Column('market_config_id', sa.UUID(), nullable=True, comment='渠道市场配置id'))
|
||||
op.add_column('mcp_tool_configs', sa.Column('mcp_service_id', sa.String(length=255), nullable=True, comment='mcp服务id'))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('mcp_tool_configs', 'mcp_service_id')
|
||||
op.drop_column('mcp_tool_configs', 'market_config_id')
|
||||
op.drop_column('mcp_tool_configs', 'market_id')
|
||||
op.drop_column('mcp_tool_configs', 'source_channel')
|
||||
# ### end Alembic commands ###
|
||||
34
api/migrations/versions/fb834419b18f_202603101453.py
Normal file
34
api/migrations/versions/fb834419b18f_202603101453.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""202603101453
|
||||
|
||||
Revision ID: fb834419b18f
|
||||
Revises: 1ac07dc7366f
|
||||
Create Date: 2026-03-10 14:46:48.038643
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'fb834419b18f'
|
||||
down_revision: Union[str, None] = '1ac07dc7366f'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('end_users', sa.Column('rag_tags', sa.Text(), nullable=True, comment='RAG模式下提取的标签列表(JSON格式)'))
|
||||
op.add_column('end_users', sa.Column('rag_personas', sa.Text(), nullable=True, comment='RAG模式下提取的人物形象列表(JSON格式)'))
|
||||
op.add_column('end_users', sa.Column('rag_summary_updated_at', sa.DateTime(), nullable=True, comment='RAG摘要/标签/人物形象最后更新时间'))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('end_users', 'rag_summary_updated_at')
|
||||
op.drop_column('end_users', 'rag_personas')
|
||||
op.drop_column('end_users', 'rag_tags')
|
||||
# ### end Alembic commands ###
|
||||
@@ -145,6 +145,8 @@ dependencies = [
|
||||
"lxml>=4.9.0",
|
||||
"httpx>=0.28.0",
|
||||
"modelscope>=1.34.0",
|
||||
"python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'",
|
||||
"python-magic-bin>=0.4.14; sys_platform=='win32'",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
Submodule redbear-mem-benchmark updated: 8494e82498...c3bbc6931c
@@ -43,6 +43,7 @@
|
||||
"i18next": "^25.6.0",
|
||||
"js-yaml": "^4.1.1",
|
||||
"lexical": "^0.39.0",
|
||||
"mammoth": "^1.12.0",
|
||||
"mermaid": "^11.12.1",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
@@ -58,6 +59,7 @@
|
||||
"remark-gfm": "^4.0.1",
|
||||
"remark-math": "^6.0.0",
|
||||
"tailwindcss": "^4.1.14",
|
||||
"xlsx": "^0.18.5",
|
||||
"zustand": "^5.0.8"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -135,4 +135,12 @@ export const getExperienceConfig = (share_token: string) => {
|
||||
'Authorization': `Bearer ${localStorage.getItem(`shareToken_${share_token}`)}`
|
||||
}
|
||||
})
|
||||
}
|
||||
// Export application
|
||||
export const appExport = (app_id: string, appName: string, data?: { release_version: string }) => {
|
||||
return request.getDownloadFile(`/apps/${app_id}/export`, `${appName}.yml`, data)
|
||||
}
|
||||
// Import application
|
||||
export const appImport = (formData: FormData) => {
|
||||
return request.uploadFile(`/apps/import`, formData)
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 14:00:06
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-04 10:58:41
|
||||
* @Last Modified time: 2026-03-12 18:25:06
|
||||
*/
|
||||
import { request } from '@/utils/request'
|
||||
import type {
|
||||
@@ -118,8 +118,9 @@ export const getChunkInsight = (end_user_id: string) => {
|
||||
return request.get(`/dashboard/chunk_insight`, { end_user_id })
|
||||
}
|
||||
// RAG User Memory - Storage content
|
||||
export const getRagContent = (end_user_id: string) => {
|
||||
return request.get(`/dashboard/rag_content`, { end_user_id, limit: 20 })
|
||||
export const getRagContentUrl = '/dashboard/rag_content'
|
||||
export const getRagContent = (end_user_id: string, page = 1, pagesize = 20) => {
|
||||
return request.get(getRagContentUrl, { end_user_id, page, pagesize })
|
||||
}
|
||||
// Emotion distribution analysis
|
||||
export const getWordCloud = (end_user_id: string) => {
|
||||
@@ -224,6 +225,10 @@ export const getConversationDetail = (end_user_id: string, conversation_id: stri
|
||||
export const forgetTrigger = (data: { max_merge_batch_size: number; min_days_since_access: number; end_user_id: string;}) => {
|
||||
return request.post(`/memory/forget-memory/trigger`, data)
|
||||
}
|
||||
// RAG type - Refresh RAG user summary and memory insight
|
||||
export const generateRagProfile = (end_user_id: string) => {
|
||||
return request.post(`/dashboard/generate_rag_profile`, { end_user_id })
|
||||
}
|
||||
/*************** end User Memory APIs ******************************/
|
||||
|
||||
/****************** Memory Management APIs *******************************/
|
||||
|
||||
@@ -41,12 +41,12 @@ export const deleteCompositeModel = (model_id: string) => {
|
||||
return request.delete(`/models/composite/${model_id}`)
|
||||
}
|
||||
// Create API keys for all matching models by provider
|
||||
export const updateProviderApiKeys = (data: KeyConfigModalForm) => {
|
||||
return request.post('/models/provider/apikeys', data)
|
||||
export const updateProviderApiKeys = (data: KeyConfigModalForm, signal?: AbortSignal) => {
|
||||
return request.post('/models/provider/apikeys', data, { signal })
|
||||
}
|
||||
// Create model API key
|
||||
export const addModelApiKey = (model_id: string, data: MultiKeyForm) => {
|
||||
return request.post(`/models/${model_id}/apikeys`, data)
|
||||
export const addModelApiKey = (model_id: string, data: MultiKeyForm, signal?: AbortSignal) => {
|
||||
return request.post(`/models/${model_id}/apikeys`, data, { signal })
|
||||
}
|
||||
// Delete model API key
|
||||
export const deleteModelApiKey = (api_key_id: string) => {
|
||||
@@ -65,10 +65,10 @@ export const addModelPlaza = (model_base_id: string) => {
|
||||
return request.post(`/models/model_plaza/${model_base_id}/add`)
|
||||
}
|
||||
// Create custom model
|
||||
export const addCustomModel = (data: CustomModelForm) => {
|
||||
return request.post('/models', data)
|
||||
export const addCustomModel = (data: CustomModelForm, signal?: AbortSignal) => {
|
||||
return request.post('/models', data, { signal })
|
||||
}
|
||||
// Update custom model
|
||||
export const updateCustomModel = (model_base_id: string, data: CustomModelForm) => {
|
||||
return request.put(`/models/${model_base_id}`, data)
|
||||
export const updateCustomModel = (model_base_id: string, data: CustomModelForm, signal?: AbortSignal) => {
|
||||
return request.put(`/models/${model_base_id}`, data, { signal })
|
||||
}
|
||||
@@ -1,17 +1,17 @@
|
||||
import { request } from '@/utils/request'
|
||||
import type { Query, CustomToolItem, ExecuteData, MCPToolItem, InnerToolItem } from '@/views/ToolManagement/types'
|
||||
import type { Query, MarketQuery, CustomToolItem, ExecuteData, MCPToolItem, InnerToolItem } from '@/views/ToolManagement/types'
|
||||
|
||||
// 工具列表
|
||||
export const getTools = (data: Query) => {
|
||||
return request.get('/tools', data)
|
||||
}
|
||||
// 创建MCP工具
|
||||
export const addTool = (values: MCPToolItem | CustomToolItem) => {
|
||||
return request.post('/tools', values)
|
||||
export const addTool = (values: MCPToolItem | CustomToolItem, config?: { signal?: AbortSignal }) => {
|
||||
return request.post('/tools', values, config)
|
||||
}
|
||||
// 更新工具
|
||||
export const updateTool = (tool_id: string, data: MCPToolItem | InnerToolItem | CustomToolItem) => {
|
||||
return request.put(`/tools/${tool_id}`, data)
|
||||
export const updateTool = (tool_id: string, data: MCPToolItem | InnerToolItem | CustomToolItem, config?: { signal?: AbortSignal }) => {
|
||||
return request.put(`/tools/${tool_id}`, data, config)
|
||||
}
|
||||
// 删除工具
|
||||
export const deleteTool = (tool_id: string) => {
|
||||
@@ -33,4 +33,44 @@ export const getToolDetail = (tool_id: string) => {
|
||||
}
|
||||
export const getToolMethods = (tool_id: string) => {
|
||||
return request.get(`/tools/${tool_id}/methods`)
|
||||
}
|
||||
|
||||
// MCP市场列表
|
||||
export const getMarketTools = (data?: Record<string, any>) => {
|
||||
return request.get('/mcp_markets/mcp_markets', data)
|
||||
}
|
||||
// 市场配置创建
|
||||
export const createMarketConfig = (values: {
|
||||
mcp_market_id: string;
|
||||
token: string;
|
||||
status: number;
|
||||
}) => {
|
||||
return request.post('/mcp_market_configs/mcp_market_config', values)
|
||||
}
|
||||
// 市场配置更新
|
||||
export const updateMarketConfig = (values: {
|
||||
mcp_market_config_id: string;
|
||||
token: string;
|
||||
status: number;
|
||||
}) => {
|
||||
return request.put(`/mcp_market_configs/${values.mcp_market_config_id}`, values)
|
||||
}
|
||||
// 市场根据id获取配置
|
||||
export const getMarketConfig = (mcp_market_id: string) => {
|
||||
return request.get(`/mcp_market_configs/mcp_market_id/${mcp_market_id}`)
|
||||
}
|
||||
// 市场MCP列表
|
||||
export const getMarketMCPs = (data: MarketQuery) => {
|
||||
return request.get('/mcp_market_configs/mcp_servers', data)
|
||||
}
|
||||
// 根据配置ID serverId 获取MCP服务详情
|
||||
export const getMarketMCPDetail = (data:{
|
||||
mcp_market_config_id: string;
|
||||
server_id: string;
|
||||
}) => {
|
||||
return request.get(`/mcp_market_configs/mcp_server`,data)
|
||||
}
|
||||
// 市场已激活MCP列表
|
||||
export const getMarketMCPsActivated = (data: MarketQuery) => {
|
||||
return request.get('/mcp_market_configs/operational_mcp_servers', data)
|
||||
}
|
||||
@@ -1,20 +1,18 @@
|
||||
import { useState, useEffect, type FC } from 'react';
|
||||
import { Spin, Alert, Button } from 'antd';
|
||||
import { ReloadOutlined } from '@ant-design/icons';
|
||||
import { Spin, Alert, Button, Table } from 'antd';
|
||||
import { ReloadOutlined, DownloadOutlined } from '@ant-design/icons';
|
||||
import RbMarkdown from '../Markdown';
|
||||
import { cookieUtils } from '@/utils/request'
|
||||
|
||||
type PreviewMode = 'office' | 'google';
|
||||
import { cookieUtils } from '@/utils/request';
|
||||
import mammoth from 'mammoth';
|
||||
import * as XLSX from 'xlsx';
|
||||
|
||||
interface DocumentPreviewProps {
|
||||
fileUrl: string;
|
||||
fileName?: string;
|
||||
fileExt?: string; // 文件扩展名(优先使用)
|
||||
fileExt?: string;
|
||||
width?: string | number;
|
||||
height?: string | number;
|
||||
className?: string;
|
||||
mode?: PreviewMode; // 预览模式
|
||||
showModeSwitch?: boolean; // 是否显示模式切换按钮
|
||||
}
|
||||
|
||||
const DocumentPreview: FC<DocumentPreviewProps> = ({
|
||||
@@ -24,18 +22,19 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
|
||||
width = '100%',
|
||||
height = '600px',
|
||||
className = '',
|
||||
mode = 'office',
|
||||
showModeSwitch = true,
|
||||
}) => {
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState(false);
|
||||
const [currentMode, setCurrentMode] = useState<PreviewMode>(mode);
|
||||
const [errorMessage, setErrorMessage] = useState<string>('');
|
||||
const [textContent, setTextContent] = useState<string>('');
|
||||
const [htmlContent, setHtmlContent] = useState<string>('');
|
||||
const [excelData, setExcelData] = useState<{ sheetName: string; data: any[][] }[]>([]);
|
||||
|
||||
// 支持的文件类型
|
||||
const supportedTypes = ['.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.pdf', '.txt', '.md', '.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'];
|
||||
// 支持预览的文件类型
|
||||
const previewableTypes = ['.pdf', '.txt', '.md', '.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp', '.doc', '.docx', '.xls', '.xlsx'];
|
||||
// PPT 暂不支持
|
||||
const downloadOnlyTypes = ['.ppt', '.pptx'];
|
||||
|
||||
// 获取文件扩展名(优先使用 fileExt prop)
|
||||
const getFileExtension = () => {
|
||||
if (fileExt) {
|
||||
return fileExt.toLowerCase().startsWith('.') ? fileExt.toLowerCase() : `.${fileExt.toLowerCase()}`;
|
||||
@@ -45,67 +44,25 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
|
||||
return match ? `.${match[1].toLowerCase()}` : '';
|
||||
};
|
||||
|
||||
// 检查是否为文本文件
|
||||
const isTextFile = () => {
|
||||
const ext = getFileExtension();
|
||||
return ext === '.txt';
|
||||
};
|
||||
|
||||
// 检查是否为 Markdown 文件
|
||||
const isMarkdownFile = () => {
|
||||
const ext = getFileExtension();
|
||||
return ext === '.md';
|
||||
};
|
||||
|
||||
// 检查是否为图片文件
|
||||
const isTextFile = () => getFileExtension() === '.txt';
|
||||
const isMarkdownFile = () => getFileExtension() === '.md';
|
||||
const isImageFile = () => {
|
||||
const ext = getFileExtension();
|
||||
const imageExts = ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'];
|
||||
return imageExts.includes(ext);
|
||||
};
|
||||
|
||||
// 检查文件类型是否支持
|
||||
const isSupportedFile = () => {
|
||||
const ext = getFileExtension();
|
||||
return ext && supportedTypes.includes(ext);
|
||||
return imageExts.includes(getFileExtension());
|
||||
};
|
||||
const isPdfFile = () => getFileExtension() === '.pdf';
|
||||
const isWordFile = () => ['.doc', '.docx'].includes(getFileExtension());
|
||||
const isExcelFile = () => ['.xls', '.xlsx'].includes(getFileExtension());
|
||||
const isPreviewable = () => previewableTypes.includes(getFileExtension());
|
||||
const isDownloadOnly = () => downloadOnlyTypes.includes(getFileExtension());
|
||||
|
||||
// 检查是否为 PDF 文件
|
||||
const isPdfFile = () => {
|
||||
const ext = getFileExtension();
|
||||
return ext === '.pdf';
|
||||
};
|
||||
|
||||
// 构建预览 URL
|
||||
const getPreviewUrl = () => {
|
||||
// 处理文件 URL,如果是完整的 URL,转换为代理路径
|
||||
let requestUrl = fileUrl;
|
||||
|
||||
// 如果是完整的 https://devapi.mem.redbearai.com 开头的 URL,提取路径部分
|
||||
// 这样可以通过代理访问,避免 CORS 问题
|
||||
if (fileUrl.includes('devapi.mem.redbearai.com')) {
|
||||
const url = new URL(fileUrl);
|
||||
requestUrl = url.pathname; // 只取路径部分,例如 /api/files/xxx
|
||||
}
|
||||
|
||||
// 对于 PDF 文件,直接使用浏览器内置预览
|
||||
if (isPdfFile()) {
|
||||
return requestUrl;
|
||||
}
|
||||
|
||||
// 确保 fileUrl 是完整的 URL(用于第三方预览服务)
|
||||
let fullUrl = fileUrl;
|
||||
if (!fileUrl.startsWith('http')) {
|
||||
fullUrl = `${window.location.origin}${fileUrl.startsWith('/') ? '' : '/'}${fileUrl}`;
|
||||
}
|
||||
console.log('预览 URL:', fullUrl);
|
||||
// 根据模式选择预览服务
|
||||
if (currentMode === 'google') {
|
||||
return `https://docs.google.com/viewer?url=${encodeURIComponent(fullUrl)}&embedded=true`;
|
||||
}
|
||||
|
||||
// 默认使用 Microsoft Office Online Viewer
|
||||
return `https://view.officeapps.live.com/op/embed.aspx?src=${encodeURIComponent(fullUrl)}`;
|
||||
const handleDownload = () => {
|
||||
const link = document.createElement('a');
|
||||
link.href = fileUrl;
|
||||
link.download = fileName || 'document';
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
};
|
||||
|
||||
const handleLoad = () => {
|
||||
@@ -113,20 +70,24 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
|
||||
setError(false);
|
||||
};
|
||||
|
||||
const handleError = () => {
|
||||
const handleError = (msg?: string) => {
|
||||
setLoading(false);
|
||||
setError(true);
|
||||
if (msg) setErrorMessage(msg);
|
||||
};
|
||||
|
||||
const handleRetry = () => {
|
||||
setLoading(true);
|
||||
setError(false);
|
||||
setErrorMessage('');
|
||||
|
||||
if (isTextFile() || isMarkdownFile()) {
|
||||
// 重新加载文本文件
|
||||
loadTextFile();
|
||||
} else if (isWordFile()) {
|
||||
loadWordFile();
|
||||
} else if (isExcelFile()) {
|
||||
loadExcelFile();
|
||||
} else {
|
||||
// 强制重新加载 iframe
|
||||
const iframe = document.querySelector(`iframe[title="${fileName || '文档预览'}"]`) as HTMLIFrameElement;
|
||||
if (iframe) {
|
||||
iframe.src = iframe.src;
|
||||
@@ -134,82 +95,164 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
|
||||
}
|
||||
};
|
||||
|
||||
const handleSwitchMode = () => {
|
||||
setCurrentMode(prev => prev === 'office' ? 'google' : 'office');
|
||||
setLoading(true);
|
||||
setError(false);
|
||||
};
|
||||
|
||||
// 加载文本文件内容
|
||||
const loadTextFile = async () => {
|
||||
setLoading(true);
|
||||
setError(false);
|
||||
setErrorMessage('');
|
||||
try {
|
||||
// 处理文件 URL,如果是完整的 URL,转换为代理路径
|
||||
let requestUrl = fileUrl;
|
||||
|
||||
// 如果是完整的 https://devapi.mem.redbearai.com 开头的 URL,提取路径部分
|
||||
if (fileUrl.includes('devapi.mem.redbearai.com')) {
|
||||
const url = new URL(fileUrl);
|
||||
requestUrl = url.pathname; // 只取路径部分,例如 /api/files/xxx
|
||||
requestUrl = url.pathname;
|
||||
}
|
||||
|
||||
const response = await fetch(requestUrl, {
|
||||
credentials: 'include', // 包含认证信息
|
||||
credentials: 'include',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${cookieUtils.get('authToken') || ''}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to load file');
|
||||
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
// 检查响应的 Content-Type
|
||||
const contentType = response.headers.get('Content-Type') || '';
|
||||
console.log('文件 Content-Type:', contentType);
|
||||
|
||||
// 如果是图片类型,显示错误提示
|
||||
if (contentType.startsWith('image/')) {
|
||||
setError(true);
|
||||
setTextContent('');
|
||||
setLoading(false);
|
||||
console.error('文件实际是图片类型,但被标记为 txt');
|
||||
handleError('文件实际是图片类型,但被标记为文本文件');
|
||||
return;
|
||||
}
|
||||
|
||||
const text = await response.text();
|
||||
|
||||
// 检查是否是二进制数据(如 PNG 文件头)
|
||||
if (text.startsWith('\x89PNG') || text.startsWith('<27>PNG')) {
|
||||
setError(true);
|
||||
setTextContent('');
|
||||
setLoading(false);
|
||||
console.error('文件内容是 PNG 图片,但扩展名是 txt');
|
||||
handleError('文件内容是图片,但扩展名是文本');
|
||||
return;
|
||||
}
|
||||
|
||||
setTextContent(text);
|
||||
setLoading(false);
|
||||
} catch (err) {
|
||||
} catch (err: any) {
|
||||
console.error('加载文本文件失败:', err);
|
||||
setError(true);
|
||||
setLoading(false);
|
||||
handleError(err.message || '加载文本文件失败');
|
||||
}
|
||||
};
|
||||
|
||||
const loadWordFile = async () => {
|
||||
setLoading(true);
|
||||
setError(false);
|
||||
setErrorMessage('');
|
||||
try {
|
||||
let requestUrl = fileUrl;
|
||||
|
||||
if (fileUrl.includes('devapi.mem.redbearai.com')) {
|
||||
const url = new URL(fileUrl);
|
||||
requestUrl = url.pathname;
|
||||
}
|
||||
|
||||
const response = await fetch(requestUrl, {
|
||||
credentials: 'include',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${cookieUtils.get('authToken') || ''}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
const arrayBuffer = await response.arrayBuffer();
|
||||
const result = await mammoth.convertToHtml({ arrayBuffer });
|
||||
setHtmlContent(result.value);
|
||||
setLoading(false);
|
||||
} catch (err: any) {
|
||||
console.error('加载 Word 文件失败:', err);
|
||||
handleError(err.message || '加载 Word 文件失败,文件可能已损坏');
|
||||
}
|
||||
};
|
||||
|
||||
const loadExcelFile = async () => {
|
||||
setLoading(true);
|
||||
setError(false);
|
||||
setErrorMessage('');
|
||||
try {
|
||||
let requestUrl = fileUrl;
|
||||
|
||||
if (fileUrl.includes('devapi.mem.redbearai.com')) {
|
||||
const url = new URL(fileUrl);
|
||||
requestUrl = url.pathname;
|
||||
}
|
||||
|
||||
const response = await fetch(requestUrl, {
|
||||
credentials: 'include',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${cookieUtils.get('authToken') || ''}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
const arrayBuffer = await response.arrayBuffer();
|
||||
const workbook = XLSX.read(arrayBuffer, { type: 'array' });
|
||||
|
||||
const sheets = workbook.SheetNames.map(sheetName => {
|
||||
const worksheet = workbook.Sheets[sheetName];
|
||||
const data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][];
|
||||
return { sheetName, data };
|
||||
});
|
||||
|
||||
setExcelData(sheets);
|
||||
setLoading(false);
|
||||
} catch (err: any) {
|
||||
console.error('加载 Excel 文件失败:', err);
|
||||
handleError(err.message || '加载 Excel 文件失败,文件可能已损坏');
|
||||
}
|
||||
};
|
||||
|
||||
// 当文件是 txt 或 md 时,加载文本内容
|
||||
useEffect(() => {
|
||||
if (isTextFile() || isMarkdownFile()) {
|
||||
loadTextFile();
|
||||
} else if (isWordFile()) {
|
||||
loadWordFile();
|
||||
} else if (isExcelFile()) {
|
||||
loadExcelFile();
|
||||
}
|
||||
}, [fileUrl]);
|
||||
|
||||
if (!isSupportedFile()) {
|
||||
// PPT 文件只提供下载
|
||||
if (isDownloadOnly()) {
|
||||
return (
|
||||
<div className={`rb:relative rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded rb:border rb:border-gray-200 ${className}`} style={{ width, height }}>
|
||||
<Alert
|
||||
message="PowerPoint 文档预览"
|
||||
description={
|
||||
<div className="rb:text-center">
|
||||
<p className="rb:mb-4">PPT 文件暂不支持在线预览,请下载后查看</p>
|
||||
<Button
|
||||
type="primary"
|
||||
icon={<DownloadOutlined />}
|
||||
onClick={handleDownload}
|
||||
>
|
||||
下载文件
|
||||
</Button>
|
||||
</div>
|
||||
}
|
||||
type="info"
|
||||
showIcon
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!isPreviewable()) {
|
||||
return (
|
||||
<Alert
|
||||
message="不支持的文件类型"
|
||||
description={`仅支持以下文件类型:${supportedTypes.join(', ')}`}
|
||||
description={`仅支持预览:${previewableTypes.join(', ')}`}
|
||||
type="warning"
|
||||
showIcon
|
||||
/>
|
||||
@@ -230,23 +273,26 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
|
||||
message="预览失败"
|
||||
description={
|
||||
<div>
|
||||
<p>无法加载文档预览,可能的原因:</p>
|
||||
<ul className="rb:list-disc rb:pl-5 rb:mt-2">
|
||||
<li>文件需要认证访问,Office 预览服务无法访问</li>
|
||||
<li>文件 URL 无法公开访问(需要配置公开访问或临时签名 URL)</li>
|
||||
<li>文件大小超过限制(Office 预览通常限制 10MB)</li>
|
||||
<li>预览服务暂时不可用</li>
|
||||
<p className="rb:mb-2">无法加载文档预览</p>
|
||||
{errorMessage && (
|
||||
<p className="rb:text-sm rb:text-red-600 rb:mb-3">
|
||||
错误详情:{errorMessage}
|
||||
</p>
|
||||
)}
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:mb-3">可能的原因:</p>
|
||||
<ul className="rb:list-disc rb:pl-5 rb:text-sm rb:text-gray-600 rb:mb-3">
|
||||
<li>文件 URL 无法访问(401/403/404)</li>
|
||||
<li>认证 token 已过期</li>
|
||||
<li>文件格式损坏或不匹配</li>
|
||||
<li>网络连接问题</li>
|
||||
</ul>
|
||||
<p className="rb:mt-2 rb:text-gray-600">建议:请下载文件到本地查看</p>
|
||||
<div className="rb:mt-4 rb:flex rb:gap-2">
|
||||
<Button icon={<ReloadOutlined />} onClick={handleRetry}>
|
||||
重试
|
||||
</Button>
|
||||
{showModeSwitch && !isPdfFile() && (
|
||||
<Button onClick={handleSwitchMode}>
|
||||
切换到 {currentMode === 'office' ? 'Google' : 'Office'} 预览
|
||||
</Button>
|
||||
)}
|
||||
<Button icon={<DownloadOutlined />} onClick={handleDownload}>
|
||||
下载文件
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
@@ -256,26 +302,23 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 图片文件预览 */}
|
||||
{isImageFile() && !error && !loading && (
|
||||
<div className="rb:w-full rb:h-full rb:overflow-auto rb:bg-gray-50 rb:flex rb:items-center rb:justify-center">
|
||||
<img
|
||||
src={fileUrl}
|
||||
alt={fileName || '图片预览'}
|
||||
className="rb:max-w-full rb:max-h-full rb:object-contain"
|
||||
onError={() => setError(true)}
|
||||
onError={() => handleError('图片加载失败')}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Markdown 文件预览 */}
|
||||
{isMarkdownFile() && !error && !loading && (
|
||||
<div className="rb:w-full rb:h-full rb:overflow-auto rb:bg-white rb:p-6 rb:rounded rb:border rb:border-gray-200">
|
||||
<RbMarkdown content={textContent} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 文本文件预览 */}
|
||||
{isTextFile() && !error && !loading && (
|
||||
<div className="rb:w-full rb:h-full rb:overflow-auto rb:bg-white rb:p-4 rb:rounded rb:border rb:border-gray-200">
|
||||
<pre className="rb:whitespace-pre-wrap rb:text-sm rb:text-gray-800 rb:font-mono">
|
||||
@@ -284,44 +327,52 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* PDF 文件预览(使用浏览器内置预览) */}
|
||||
{isWordFile() && !error && !loading && (
|
||||
<div className="rb:w-full rb:h-full rb:overflow-auto rb:bg-white rb:p-6 rb:rounded rb:border rb:border-gray-200">
|
||||
<div
|
||||
className="rb:prose rb:max-w-none"
|
||||
dangerouslySetInnerHTML={{ __html: htmlContent }}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isExcelFile() && !error && !loading && (
|
||||
<div className="rb:w-full rb:h-full rb:overflow-auto rb:bg-white rb:p-4 rb:rounded rb:border rb:border-gray-200">
|
||||
{excelData.map((sheet, index) => (
|
||||
<div key={index} className="rb:mb-6">
|
||||
<h3 className="rb:text-lg rb:font-semibold rb:mb-3">{sheet.sheetName}</h3>
|
||||
{sheet.data.length > 0 && (
|
||||
<Table
|
||||
dataSource={sheet.data.slice(1).map((row, idx) => ({ key: idx, ...row }))}
|
||||
columns={sheet.data[0]?.map((header: any, colIdx: number) => ({
|
||||
title: header || `列 ${colIdx + 1}`,
|
||||
dataIndex: colIdx,
|
||||
key: colIdx,
|
||||
width: 150,
|
||||
})) || []}
|
||||
pagination={false}
|
||||
scroll={{ x: 'max-content' }}
|
||||
size="small"
|
||||
bordered
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isPdfFile() && !error && !loading && (
|
||||
<iframe
|
||||
src={getPreviewUrl()}
|
||||
src={fileUrl}
|
||||
width="100%"
|
||||
height="100%"
|
||||
title={fileName || 'PDF 预览'}
|
||||
className="rb:border-0"
|
||||
style={{ border: 'none' }}
|
||||
onLoad={handleLoad}
|
||||
onError={handleError}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Office 文件预览 */}
|
||||
{!isTextFile() && !isMarkdownFile() && !isImageFile() && !isPdfFile() && (
|
||||
<>
|
||||
{showModeSwitch && !loading && !error && (
|
||||
<div className="rb:absolute rb:top-2 rb:right-2 rb:z-20">
|
||||
<Button size="small" onClick={handleSwitchMode}>
|
||||
切换到 {currentMode === 'office' ? 'Google' : 'Office'} 预览
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!error && (
|
||||
<iframe
|
||||
src={getPreviewUrl()}
|
||||
width="100%"
|
||||
height="100%"
|
||||
onLoad={handleLoad}
|
||||
onError={handleError}
|
||||
title={fileName || '文档预览'}
|
||||
className="rb:border-0"
|
||||
style={{ display: loading ? 'none' : 'block', border: 'none' }}
|
||||
sandbox="allow-scripts allow-same-origin allow-popups"
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-02 15:18:19
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 15:44:42
|
||||
* @Last Modified time: 2026-03-12 18:36:19
|
||||
*/
|
||||
/**
|
||||
* PageScrollList Component
|
||||
@@ -60,8 +60,8 @@ interface PageScrollListProps<T, Q = Record<string, unknown>> {
|
||||
|
||||
/** Infinite scroll list component with pagination support */
|
||||
const PageScrollList = forwardRef(<T, Q = Record<string, unknown>>({
|
||||
renderItem,
|
||||
query,
|
||||
renderItem,
|
||||
query,
|
||||
url,
|
||||
column = 4,
|
||||
className = '',
|
||||
@@ -69,68 +69,70 @@ const PageScrollList = forwardRef(<T, Q = Record<string, unknown>>({
|
||||
}: PageScrollListProps<T, Q>, ref: React.Ref<PageScrollListRef>) => {
|
||||
/** Expose refresh method to parent component */
|
||||
useImperativeHandle(ref, () => ({
|
||||
refresh,
|
||||
refresh: () => {
|
||||
pageRef.current = 1;
|
||||
loadingRef.current = false;
|
||||
setHasMore(true);
|
||||
setData([]);
|
||||
loadMoreData(true);
|
||||
},
|
||||
}));
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [data, setData] = useState<T[]>([]);
|
||||
const [page, setPage] = useState(1);
|
||||
const [hasMore, setHasMore] = useState(true);
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
const pageRef = useRef(1);
|
||||
const loadingRef = useRef(false);
|
||||
const hasMoreRef = useRef(true);
|
||||
|
||||
/** Load more data from API with pagination */
|
||||
const loadMoreData = (flag?: boolean) => {
|
||||
if (!flag && (loading || !hasMore)) {
|
||||
return;
|
||||
}
|
||||
const loadMoreData = (reset?: boolean) => {
|
||||
if (loadingRef.current || (!reset && !hasMoreRef.current)) return;
|
||||
loadingRef.current = true;
|
||||
setLoading(true);
|
||||
const currentPage = reset ? 1 : pageRef.current;
|
||||
request.get(url, {
|
||||
page: page,
|
||||
page: currentPage,
|
||||
pagesize: PAGE_SIZE,
|
||||
...(query||{}),
|
||||
...(query || {}),
|
||||
})
|
||||
.then((res) => {
|
||||
const response = res as ApiResponse<T>;
|
||||
const results = Array.isArray(response.items) ? response.items : Array.isArray(response) ? response as T[] : [];
|
||||
// Replace data if flag is true, otherwise append
|
||||
if (flag) {
|
||||
setData(results);
|
||||
} else {
|
||||
setData(data.concat(results));
|
||||
}
|
||||
setPage(response.page.page + 1);
|
||||
pageRef.current = response.page.page + 1;
|
||||
setData(prev => reset ? results : [...prev, ...results]);
|
||||
hasMoreRef.current = response.page?.hasnext;
|
||||
setHasMore(response.page?.hasnext);
|
||||
setLoading(false);
|
||||
console.log(`${results.length} more items loaded!`);
|
||||
})
|
||||
.catch(() => {
|
||||
setLoading(false);
|
||||
hasMoreRef.current = false;
|
||||
setHasMore(false);
|
||||
console.error('Failed to load data');
|
||||
})
|
||||
.finally(() => {
|
||||
loadingRef.current = false;
|
||||
setLoading(false);
|
||||
// 内容不足以填满容器时,主动继续加载
|
||||
setTimeout(() => {
|
||||
const el = scrollRef.current;
|
||||
console.log(el, el?.scrollHeight, el?.clientHeight, hasMoreRef.current)
|
||||
if (el && hasMoreRef.current && el.scrollHeight <= el.clientHeight) {
|
||||
loadMoreData();
|
||||
}
|
||||
}, 0);
|
||||
});
|
||||
};
|
||||
|
||||
/** Reset list to initial state and reload data */
|
||||
const refresh = () => {
|
||||
setPage(1);
|
||||
/** Reset and reload when query parameters change */
|
||||
const queryKey = JSON.stringify(query);
|
||||
useEffect(() => {
|
||||
pageRef.current = 1;
|
||||
loadingRef.current = false;
|
||||
hasMoreRef.current = true;
|
||||
setHasMore(true);
|
||||
setData([]);
|
||||
}
|
||||
loadMoreData(true);
|
||||
}, [queryKey]);
|
||||
|
||||
/** Refresh when query parameters change */
|
||||
useEffect(() => {
|
||||
refresh()
|
||||
}, [query]);
|
||||
|
||||
/** Load initial data when list is reset */
|
||||
useEffect(() => {
|
||||
if (page === 1 && hasMore && data.length === 0) {
|
||||
loadMoreData(true);
|
||||
}
|
||||
}, [page, hasMore, data])
|
||||
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
@@ -140,7 +142,7 @@ const PageScrollList = forwardRef(<T, Q = Record<string, unknown>>({
|
||||
>
|
||||
<InfiniteScroll
|
||||
dataLength={data.length}
|
||||
next={loadMoreData}
|
||||
next={() => loadMoreData()}
|
||||
hasMore={hasMore}
|
||||
loader={loading && needLoading ? <PageLoading /> : false}
|
||||
// endMessage={<Divider plain>It is all, nothing more 🤐</Divider>}
|
||||
|
||||
@@ -1370,7 +1370,7 @@ export const en = {
|
||||
gotoList: 'Return to Application List',
|
||||
gotoDetail: 'View Details',
|
||||
dify: 'Dify',
|
||||
pleaseUploadFile: 'Please upload workflow file',
|
||||
pleaseUploadFile: 'Please upload file',
|
||||
},
|
||||
userMemory: {
|
||||
userMemory: 'User Memory',
|
||||
@@ -1807,6 +1807,29 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
error_desc: 'API is configured but connection error',
|
||||
|
||||
testConnectionSuccess: 'Test Connection Successful',
|
||||
refreshSuccess: 'Refresh Successful',
|
||||
refreshFailed: 'Refresh Failed',
|
||||
|
||||
// Market related
|
||||
marketSelectTitle: 'Select an MCP Market',
|
||||
marketSelectDesc: 'Choose a market source from the left, configure the connection to browse MCP services',
|
||||
marketRefreshSuccess: 'List refreshed',
|
||||
marketActivated: 'Activated',
|
||||
marketInDatabase: 'In Database',
|
||||
marketAdd: 'Add',
|
||||
marketRefresh: 'Refresh',
|
||||
marketConfigBtn: 'Configure',
|
||||
marketConfigConnection: 'Configure Connection',
|
||||
marketNoData: 'No Data',
|
||||
marketNoDataDesc: 'This market currently has no available services',
|
||||
marketNoSearchResult: 'No Search Results',
|
||||
marketNoSearchResultDesc: 'No matching services found, please try other keywords',
|
||||
marketNoServices: 'No MCP Services Available',
|
||||
marketNotConnected: 'Not Connected to This Market',
|
||||
marketNoServicesDesc: 'This market currently has no available services',
|
||||
marketNotConnectedDesc: 'Click the "Configure" button in the upper right corner to set connection information',
|
||||
marketSearchPlaceholder: 'Search services...',
|
||||
marketVisit: 'Visit Market',
|
||||
serviceEndpoint: 'Service Endpoint URL',
|
||||
serviceEndpointPlaceholder: 'URL of the service endpoint',
|
||||
serviceEndpointExtra: 'Complete access address of the MCP service',
|
||||
@@ -1960,6 +1983,21 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
viewDetail: 'View Details',
|
||||
textLink: 'Test Connection',
|
||||
noResult: 'Processing results will be displayed here',
|
||||
|
||||
marketConfig: 'Configure {{name}}',
|
||||
marketSaveAndConnect: 'Save & Connect',
|
||||
marketUrl: 'Market URL',
|
||||
marketUrlPlaceholder: 'Market URL',
|
||||
marketCopy: 'Copy',
|
||||
marketApiKeyOptional: 'Optional',
|
||||
marketApiKeyRequired: 'API Key is required',
|
||||
marketApiKeyExtra: 'Some markets require an API Key to access the full service list',
|
||||
marketApiKeyPlaceholder: 'Enter API Key to access more services',
|
||||
marketConnectionStatus: 'Connection Status',
|
||||
marketConnected: '● Connected',
|
||||
marketDisconnected: '○ Disconnected',
|
||||
marketConnecting: 'Connecting to {{name}}...',
|
||||
marketConfigUpdated: '{{name}} configuration updated',
|
||||
serverUrlInvalid: 'Must start with http:// or https://, and cannot have leading or trailing spaces',
|
||||
requestHeaderKeyInvalid: 'Only English letters, numbers, hyphens (-), and underscores (_) are allowed, and cannot start or end with a hyphen or underscore',
|
||||
},
|
||||
@@ -2008,6 +2046,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
self_optimization: 'Self Optimization',
|
||||
process_evolution: 'Process Evolution',
|
||||
unknown: 'Unknown Node',
|
||||
notes: 'Sticky Note',
|
||||
|
||||
clickToConfigure: 'Click to configure node parameters',
|
||||
nodeProperties: 'Node Properties',
|
||||
@@ -2195,6 +2234,12 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
output_variables: 'Output Variables',
|
||||
refreshTip: 'Sync function signature to code',
|
||||
},
|
||||
notes: {
|
||||
showAuth: 'Show Author',
|
||||
enterLink: 'Enter Link URL',
|
||||
placeholder: 'Enter note...',
|
||||
removeLink: 'Remove Link',
|
||||
},
|
||||
name: 'Key',
|
||||
type: 'Type',
|
||||
value: 'Value',
|
||||
|
||||
@@ -754,7 +754,7 @@ export const zh = {
|
||||
gotoList: '返回应用列表',
|
||||
gotoDetail: '查看详情',
|
||||
dify: 'Dify',
|
||||
pleaseUploadFile: '请上传工作流文件',
|
||||
pleaseUploadFile: '请上传文件',
|
||||
},
|
||||
table: {
|
||||
totalRecords: '共 {{total}} 条记录'
|
||||
@@ -1803,6 +1803,29 @@ export const zh = {
|
||||
error_desc: 'API 已配置但链接异常',
|
||||
|
||||
testConnectionSuccess: '测试连接成功',
|
||||
refreshSuccess: '刷新成功',
|
||||
refreshFailed: '刷新失败',
|
||||
|
||||
// Market 相关
|
||||
marketSelectTitle: '选择一个 MCP 市场',
|
||||
marketSelectDesc: '从左侧选择一个市场源,配置连接后即可浏览该市场的 MCP 服务',
|
||||
marketRefreshSuccess: '列表已刷新',
|
||||
marketActivated: '已激活',
|
||||
marketInDatabase: '已入库',
|
||||
marketAdd: '添加',
|
||||
marketRefresh: '刷新',
|
||||
marketConfigBtn: '配置',
|
||||
marketConfigConnection: '配置连接',
|
||||
marketNoData: '暂无数据',
|
||||
marketNoDataDesc: '该市场暂时没有可用的服务',
|
||||
marketNoSearchResult: '无搜索结果',
|
||||
marketNoSearchResultDesc: '未找到匹配的服务,请尝试其他关键词',
|
||||
marketNoServices: '暂无可用的 MCP 服务',
|
||||
marketNotConnected: '尚未连接此市场',
|
||||
marketNoServicesDesc: '该市场暂时没有可用的服务',
|
||||
marketNotConnectedDesc: '点击右上角"配置"按钮设置连接信息',
|
||||
marketSearchPlaceholder: '搜索服务...',
|
||||
marketVisit: '前往市场',
|
||||
serviceEndpoint: '服务端点 URL',
|
||||
serviceEndpointPlaceholder: '服务端点的 URL',
|
||||
serviceEndpointExtra: 'MCP服务的完整访问地址',
|
||||
@@ -1956,6 +1979,21 @@ export const zh = {
|
||||
viewDetail: '查看详情',
|
||||
textLink: '测试连接',
|
||||
noResult: '处理结果将显示在这里',
|
||||
|
||||
marketConfig: '配置 {{name}}',
|
||||
marketSaveAndConnect: '保存并连接',
|
||||
marketUrl: '市场地址',
|
||||
marketUrlPlaceholder: '市场地址',
|
||||
marketCopy: '复制',
|
||||
marketApiKeyOptional: '可选',
|
||||
marketApiKeyRequired: '请输入 API Key',
|
||||
marketApiKeyExtra: '部分市场需要 API Key 才能获取完整的服务列表',
|
||||
marketApiKeyPlaceholder: '输入 API Key 以获取更多服务',
|
||||
marketConnectionStatus: '连接状态',
|
||||
marketConnected: '● 已连接',
|
||||
marketDisconnected: '○ 未连接',
|
||||
marketConnecting: '正在连接 {{name}}...',
|
||||
marketConfigUpdated: '{{name}} 配置已更新',
|
||||
serverUrlInvalid: '必须以 http:// 或 https:// 开头,且不能有前后空格',
|
||||
requestHeaderKeyInvalid: '只支持英文、数字、连字符(-)、下划线(_),不能以连字符或下划线开头结尾',
|
||||
},
|
||||
@@ -2004,6 +2042,7 @@ export const zh = {
|
||||
self_optimization: '自我优化',
|
||||
process_evolution: '流程演化',
|
||||
unknown: '未知节点',
|
||||
notes: '便签',
|
||||
|
||||
clickToConfigure: '点击配置节点参数',
|
||||
nodeProperties: '节点属性',
|
||||
@@ -2194,6 +2233,12 @@ export const zh = {
|
||||
unknown: {
|
||||
replaceNodeType: '替换节点'
|
||||
},
|
||||
notes: {
|
||||
showAuth: '显示作者',
|
||||
enterLink: '输入链接 URL',
|
||||
placeholder: '输入注释...',
|
||||
removeLink: '取消链接',
|
||||
},
|
||||
name: '键',
|
||||
type: '类型',
|
||||
value: '值',
|
||||
|
||||
@@ -347,6 +347,23 @@ export const request = {
|
||||
document.body.removeChild(link);
|
||||
callback?.()
|
||||
});
|
||||
},
|
||||
getDownloadFile(url: string, fileName: string, data?: unknown, callback?: () => void) {
|
||||
service.get(url, {
|
||||
params: paramFilter(data as Record<string, string | number | boolean | ObjectWithPush | null | undefined>),
|
||||
responseType: "blob",
|
||||
})
|
||||
.then(res => {
|
||||
const link = document.createElement("a");
|
||||
const blob = new Blob([res as unknown as BlobPart]);
|
||||
link.style.display = "none";
|
||||
link.href = URL.createObjectURL(blob);
|
||||
link.setAttribute("download", decodeURI(fileName || fileName));
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
callback?.()
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import { Button, Space, Input, Form, App } from 'antd';
|
||||
|
||||
import Tag, { type TagProps } from './components/Tag'
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import { getReleaseList, rollbackRelease } from '@/api/application'
|
||||
import { getReleaseList, rollbackRelease, appExport } from '@/api/application'
|
||||
import ReleaseModal from './components/ReleaseModal'
|
||||
import ReleaseShareModal from './components/ReleaseShareModal'
|
||||
import type { Release, ReleaseModalRef, ReleaseShareModalRef } from './types'
|
||||
@@ -67,6 +67,9 @@ const ReleasePage: FC<{data: Application; refresh: () => void}> = ({data, refres
|
||||
message.success(t('common.operateSuccess'))
|
||||
})
|
||||
}
|
||||
const handleExport = () => {
|
||||
appExport(data.id, data.name)
|
||||
}
|
||||
return (
|
||||
<div className="rb:flex rb:h-[calc(100vh-64px)]">
|
||||
<div className="rb:h-full rb:overflow-y-auto rb:w-108 rb:flex-[0_0_auto] rb:border-r rb:border-[#DFE4ED] rb:p-4">
|
||||
@@ -123,7 +126,7 @@ const ReleasePage: FC<{data: Application; refresh: () => void}> = ({data, refres
|
||||
|
||||
<Space size={10}>
|
||||
{selectedVersion && <>
|
||||
{/* <Button>{t('application.exportDSLFile')}</Button> */}
|
||||
{data?.type !== 'multi_agent' && <Button onClick={handleExport}>{t('common.export')}</Button>}
|
||||
{data.current_release_id !== selectedVersion.id && <Button onClick={handleRollback}>{t('application.willRollToThisVersion')}</Button>}
|
||||
<Button type="primary" ghost onClick={() => releaseShareModalRef.current?.handleOpen()}>{t('application.share')}</Button>
|
||||
</>}
|
||||
|
||||
@@ -19,9 +19,8 @@ import deleteIcon from '@/assets/images/delete_hover.svg'
|
||||
import type { Application, ApplicationModalRef } from '@/views/ApplicationManagement/types';
|
||||
import ApplicationModal from '@/views/ApplicationManagement/components/ApplicationModal'
|
||||
import type { CopyModalRef, AgentRef, ClusterRef, WorkflowRef } from '../types'
|
||||
import { deleteApplication } from '@/api/application'
|
||||
import { deleteApplication, appExport } from '@/api/application'
|
||||
import CopyModal from './CopyModal'
|
||||
import { exportToYaml } from '@/utils/yamlExport';
|
||||
|
||||
const { Header } = Layout;
|
||||
|
||||
@@ -85,16 +84,16 @@ const ConfigHeader: FC<ConfigHeaderProps> = ({
|
||||
* Handle menu item click
|
||||
*/
|
||||
const handleClick: MenuProps['onClick'] = ({ key }) => {
|
||||
if (!application) return
|
||||
switch (key) {
|
||||
case 'edit':
|
||||
applicationModalRef.current?.handleOpen(application as Application)
|
||||
applicationModalRef.current?.handleOpen(application)
|
||||
break;
|
||||
case 'copy':
|
||||
copyModalRef.current?.handleOpen()
|
||||
break;
|
||||
case 'export':
|
||||
console.log('export', workflowRef?.current?.config)
|
||||
exportToYaml(workflowRef?.current?.config, application?.name ?`${application?.name}.yml`: undefined)
|
||||
appExport(application.id, application.name)
|
||||
break;
|
||||
case 'delete':
|
||||
handleDelete()
|
||||
@@ -153,7 +152,7 @@ const ConfigHeader: FC<ConfigHeaderProps> = ({
|
||||
* Format dropdown menu items
|
||||
*/
|
||||
const formatMenuItems = useMemo(() => {
|
||||
const items = (application?.type === 'workflow' ? ['edit', 'copy', 'export', 'delete'] : ['edit', 'copy', 'delete']).map(key => ({
|
||||
const items = (application?.type !== 'multi_agent' ? ['edit', 'copy', 'export', 'delete'] : ['edit', 'copy', 'delete']).map(key => ({
|
||||
key,
|
||||
icon: <img src={menuIcons[key]} className="rb:w-4 rb:h-4 rb:mr-2" />,
|
||||
label: t(`common.${key}`),
|
||||
|
||||
256
web/src/views/ApplicationManagement/components/UploadModal.tsx
Normal file
256
web/src/views/ApplicationManagement/components/UploadModal.tsx
Normal file
@@ -0,0 +1,256 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-28 14:08:14
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-12 17:19:46
|
||||
*/
|
||||
/**
|
||||
* UploadModal Component
|
||||
*
|
||||
* This component provides a modal for uploading workflow files with a multi-step process:
|
||||
* 1. Upload - Select platform and file
|
||||
* 2. Complex - Show warnings and errors if any
|
||||
* 3. SureInfo - Confirm and edit workflow information
|
||||
* 4. Completed - Show success message and options
|
||||
*/
|
||||
import { forwardRef, useImperativeHandle, useState, useMemo } from 'react';
|
||||
import { Form, Steps, Flex, Alert, Button, Result, message } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { Application, UploadModalRef } from '../types'
|
||||
import RbModal from '@/components/RbModal'
|
||||
import UploadFiles from '@/components/Upload/UploadFiles'
|
||||
import { appImport } from '@/api/application'
|
||||
|
||||
/**
|
||||
* Props for UploadModal component
|
||||
*/
|
||||
interface UploadModalProps {
|
||||
/** Function to refresh the parent component after workflow import */
|
||||
refresh: () => void;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Steps definition for the upload process
|
||||
*/
|
||||
const steps = [
|
||||
'upload', // Step 1: File upload
|
||||
'complex', // Step 2: Error/warning display
|
||||
'completed' // Step 4: Success message
|
||||
]
|
||||
/**
|
||||
* UploadModal component
|
||||
*
|
||||
* @param {UploadModalProps} props - Component props
|
||||
* @param {React.Ref<UploadModalRef>} ref - Ref for imperative methods
|
||||
*/
|
||||
const UploadModal = forwardRef<UploadModalRef, UploadModalProps>(({
|
||||
refresh
|
||||
}, ref) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
// State management
|
||||
const [visible, setVisible] = useState(false); // Modal visibility
|
||||
const [form] = Form.useForm<{ file: File[] }>(); // Form instance
|
||||
const [loading, setLoading] = useState(false); // Loading state
|
||||
const [current, setCurrent] = useState<number>(0); // Current step
|
||||
const [appId, setAppId] = useState<string | null>(null); // Imported application ID
|
||||
const [warnings, setWarnings] = useState<string[]>([])
|
||||
|
||||
/**
|
||||
* Handle modal close
|
||||
* Resets all states and form fields
|
||||
*/
|
||||
const handleClose = () => {
|
||||
refresh()
|
||||
setVisible(false);
|
||||
form.resetFields();
|
||||
setCurrent(0);
|
||||
setAppId(null);
|
||||
setLoading(false);
|
||||
setWarnings([])
|
||||
};
|
||||
|
||||
/**
|
||||
* Handle modal open
|
||||
* Resets form fields and shows modal
|
||||
*/
|
||||
const handleOpen = () => {
|
||||
form.resetFields();
|
||||
setVisible(true);
|
||||
};
|
||||
|
||||
/**
|
||||
* Handle save/submit action
|
||||
* Processes different logic based on current step
|
||||
*/
|
||||
const handleSave = () => {
|
||||
const values = form.getFieldsValue();
|
||||
|
||||
switch(current) {
|
||||
case 0: // Step 1: Upload file
|
||||
if (!values.file || values.file.length === 0) {
|
||||
message.warning(t('application.pleaseUploadFile'));
|
||||
return;
|
||||
}
|
||||
const formData = new FormData();
|
||||
formData.append('file', values.file[0]);
|
||||
|
||||
setLoading(true)
|
||||
// Call import API
|
||||
appImport(formData)
|
||||
.then(res => {
|
||||
const { warnings, app } = res as { warnings: string[]; app: Application };
|
||||
|
||||
setAppId(app?.id)
|
||||
if (warnings.length) {
|
||||
setCurrent(1)
|
||||
setWarnings(warnings)
|
||||
} else {
|
||||
setCurrent(2)
|
||||
}
|
||||
})
|
||||
.finally(() => setLoading(false));
|
||||
break;
|
||||
case 2:
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
// Expose methods to parent component via ref
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
handleClose
|
||||
}));
|
||||
|
||||
/**
|
||||
* Handle navigation after successful import
|
||||
* @param {string} type - Navigation type ('detail' or 'list')
|
||||
*/
|
||||
const handleJump = (type: string) => {
|
||||
handleClose();
|
||||
refresh();
|
||||
setTimeout(() => {
|
||||
switch (type) {
|
||||
case 'detail':
|
||||
// Open application detail page in new tab
|
||||
window.open(`/#/application/config/${appId}`, '_blank');
|
||||
break;
|
||||
}
|
||||
}, 100)
|
||||
};
|
||||
|
||||
/**
|
||||
* Generate modal footer based on current step
|
||||
*/
|
||||
const getFooter = useMemo(() => {
|
||||
switch (current) {
|
||||
case 0: // Step 1: Upload
|
||||
return [
|
||||
<Button key="back" onClick={handleClose}>
|
||||
{t('common.cancel')}
|
||||
</Button>,
|
||||
<Button
|
||||
key="confirm"
|
||||
type="primary"
|
||||
loading={loading}
|
||||
onClick={handleSave}
|
||||
>
|
||||
{t('common.confirm')}
|
||||
</Button>
|
||||
];
|
||||
case 1:
|
||||
return [
|
||||
<Button key="back" onClick={() => handleJump('list')}>
|
||||
{t('application.gotoList')}
|
||||
</Button>,
|
||||
<Button
|
||||
key="submit"
|
||||
type="primary"
|
||||
loading={loading}
|
||||
onClick={() => handleJump('detail')}
|
||||
>
|
||||
{t('application.gotoDetail')}
|
||||
</Button>
|
||||
]
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}, [current, loading]);
|
||||
return (
|
||||
<RbModal
|
||||
title={t('application.import')}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t('common.confirm')}
|
||||
onOk={handleSave}
|
||||
footer={getFooter}
|
||||
>
|
||||
{/* Steps indicator */}
|
||||
<div className='rb:p-3 rb:bg-[#FBFDFF] rb:rounded-lg rb:border rb:border-[#DFE4ED] rb:mb-3'>
|
||||
<Steps
|
||||
labelPlacement="vertical"
|
||||
size="small"
|
||||
current={current}
|
||||
items={steps.map(key => ({ title: t(`application.${key}`) }))}
|
||||
/>
|
||||
</div>
|
||||
{current === 0 &&
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
>
|
||||
<Form.Item
|
||||
name="file"
|
||||
valuePropName="fileList"
|
||||
noStyle
|
||||
>
|
||||
<UploadFiles
|
||||
isAutoUpload={false}
|
||||
isCanDrag={true}
|
||||
fileSize={100}
|
||||
maxCount={1}
|
||||
fileType={['yml']}
|
||||
/>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
}
|
||||
{/* Step 2: Error/warning display */}
|
||||
{current === 1 &&
|
||||
<Flex vertical gap={12}>
|
||||
{warnings.map((vo, index) => (
|
||||
<Alert
|
||||
key={index}
|
||||
message={<div>{vo}</div>}
|
||||
type="warning"
|
||||
showIcon
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
}
|
||||
{current === 2 &&
|
||||
<Result
|
||||
status="success"
|
||||
title={t('application.importSuccess')}
|
||||
subTitle={t('application.importSuccessDesc')}
|
||||
extra={[
|
||||
<Button key="back" onClick={() => handleJump('list')}>
|
||||
{t('application.gotoList')}
|
||||
</Button>,
|
||||
<Button
|
||||
key="submit"
|
||||
type="primary"
|
||||
loading={loading}
|
||||
onClick={() => handleJump('detail')}
|
||||
>
|
||||
{t('application.gotoDetail')}
|
||||
</Button>
|
||||
]}
|
||||
/>
|
||||
}
|
||||
</RbModal>
|
||||
);
|
||||
});
|
||||
|
||||
export default UploadModal;
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-28 14:08:14
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-06 12:05:46
|
||||
* @Last Modified time: 2026-03-12 17:19:33
|
||||
*/
|
||||
/**
|
||||
* UploadWorkflowModal Component
|
||||
@@ -72,6 +72,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
||||
setFirstFormData(null);
|
||||
setAppId(null);
|
||||
setLoading(false);
|
||||
refresh()
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -25,6 +25,7 @@ import { getApplicationListUrl, deleteApplication } from '@/api/application'
|
||||
import PageScrollList, { type PageScrollListRef } from '@/components/PageScrollList'
|
||||
import { formatDateTime } from '@/utils/format';
|
||||
import UploadWorkflowModal from './components/UploadWorkflowModal'
|
||||
import UploadModal from './components/UploadModal'
|
||||
|
||||
/**
|
||||
* Application management main component
|
||||
@@ -37,6 +38,7 @@ const ApplicationManagement: React.FC = () => {
|
||||
const applicationModalRef = useRef<ApplicationModalRef>(null);
|
||||
const scrollListRef = useRef<PageScrollListRef>(null)
|
||||
const uploadWorkflowModalRef = useRef<UploadWorkflowModalRef>(null);
|
||||
const uploadModalRef = useRef<UploadWorkflowModalRef>(null);
|
||||
|
||||
useEffect(() => {
|
||||
// Convert URLSearchParams to a plain object for easier access
|
||||
@@ -91,6 +93,8 @@ const ApplicationManagement: React.FC = () => {
|
||||
case 'thirdParty':
|
||||
handleImport()
|
||||
break;
|
||||
case 'import':
|
||||
uploadModalRef.current?.handleOpen()
|
||||
}
|
||||
}
|
||||
return (
|
||||
@@ -121,6 +125,7 @@ const ApplicationManagement: React.FC = () => {
|
||||
<Dropdown
|
||||
menu={{ items: [
|
||||
{ key: 'thirdParty', label: t('application.importWorkflow') },
|
||||
{ key: 'import', label: t('application.import') },
|
||||
], onClick: handleClick }}
|
||||
placement="bottomRight"
|
||||
>
|
||||
@@ -186,6 +191,10 @@ const ApplicationManagement: React.FC = () => {
|
||||
ref={uploadWorkflowModalRef}
|
||||
refresh={refresh}
|
||||
/>
|
||||
<UploadModal
|
||||
ref={uploadModalRef}
|
||||
refresh={refresh}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -233,4 +233,12 @@ export interface UploadData extends WorkflowConfig {
|
||||
export interface UploadWorkflowModalRef {
|
||||
/** Open the upload workflow modal */
|
||||
handleOpen: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Upload app modal ref interface
|
||||
*/
|
||||
export interface UploadModalRef {
|
||||
/** Open the upload workflow modal */
|
||||
handleOpen: () => void;
|
||||
}
|
||||
@@ -82,6 +82,7 @@ const CreateDataset = () => {
|
||||
const [form] = Form.useForm<ContentFormData>();
|
||||
const [data, setData] = useState<KnowledgeBaseDocumentData[]>([]);
|
||||
const [rechunkFileIds, setRechunkFileIds] = useState<string[]>(initialFileIds);
|
||||
const [textFormValid, setTextFormValid] = useState<boolean>(false);
|
||||
|
||||
const [pollingLoading, setPollingLoading] = useState<boolean>(false);
|
||||
const pollingTimerRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
@@ -624,7 +625,16 @@ const CreateDataset = () => {
|
||||
)}
|
||||
{source && source === 'text' && (
|
||||
<div className='rb:flex rb:w-full rb:flex-col rb:mt-10 rb:px-40'>
|
||||
<Form form={form} layout="vertical">
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
onValuesChange={() => {
|
||||
// 检查表单字段是否都已填写
|
||||
const values = form.getFieldsValue();
|
||||
const isValid = !!(values.title?.trim() && values.content?.trim());
|
||||
setTextFormValid(isValid);
|
||||
}}
|
||||
>
|
||||
<Form.Item
|
||||
name="title"
|
||||
label={t('knowledgeBase.title')}
|
||||
@@ -845,7 +855,11 @@ const CreateDataset = () => {
|
||||
<Button
|
||||
type='primary'
|
||||
onClick={current === 2 ? handleStartUpload : handleNext}
|
||||
disabled={pollingLoading || (current === 0 && rechunkFileIds.length === 0)}
|
||||
disabled={
|
||||
pollingLoading ||
|
||||
(current === 0 && source === 'local' && rechunkFileIds.length === 0) ||
|
||||
(current === 0 && source === 'text' && !textFormValid)
|
||||
}
|
||||
>
|
||||
{current === 2 ? t('knowledgeBase.startUploading') || 'Start Upload' : t('common.next') || 'Next'}
|
||||
</Button>
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* @Author: yujiangping
|
||||
* @Date: 2025-11-10 18:52:55
|
||||
* @LastEditors: yujiangping
|
||||
* @LastEditTime: 2026-03-03 14:46:08
|
||||
* @LastEditTime: 2026-03-09 16:39:07
|
||||
*/
|
||||
import { forwardRef, useImperativeHandle, useState, useRef } from 'react';
|
||||
import { Switch } from 'antd';
|
||||
@@ -58,16 +58,21 @@ const ShareModal = forwardRef<ShareModalRef,ShareModalRefProps>(({ handleShare:
|
||||
}
|
||||
|
||||
const handleShare = async() => {
|
||||
const workspaceIds = spaceList
|
||||
.map(item => item.target_kb?.workspace_id)
|
||||
.filter(Boolean)
|
||||
.join(',');
|
||||
|
||||
console.log('Workspace IDs:', workspaceIds);
|
||||
shareSpaceModalRef?.current?.handleOpen(kbId,knowledgeBase,workspaceIds);
|
||||
|
||||
// Close modal after sharing
|
||||
handleClose();
|
||||
setLoading(true);
|
||||
try {
|
||||
const workspaceIds = spaceList
|
||||
.map(item => item.target_kb?.workspace_id)
|
||||
.filter(Boolean)
|
||||
.join(',');
|
||||
|
||||
console.log('Workspace IDs:', workspaceIds);
|
||||
shareSpaceModalRef?.current?.handleOpen(kbId,knowledgeBase,workspaceIds);
|
||||
|
||||
// Close modal after sharing
|
||||
handleClose();
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}
|
||||
const handleChange = (checked: boolean, item: any) => {
|
||||
// Toggle shared knowledge base status
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* @Author: yujiangping
|
||||
* @Date: 2025-11-10 18:52:55
|
||||
* @LastEditors: yujiangping
|
||||
* @LastEditTime: 2025-12-03 18:44:58
|
||||
* @LastEditTime: 2026-03-09 16:34:51
|
||||
*/
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Switch } from 'antd';
|
||||
@@ -50,34 +50,38 @@ const ShareModal = forwardRef<ShareModalRef,ShareModalRefProps>(({ handleShare:
|
||||
setSpaceList(filteredItems as SpaceItem[]);
|
||||
}
|
||||
const handleShare = async() => {
|
||||
|
||||
// Get all data with checked = true
|
||||
const checkedItems = spaceList.filter(item => item.is_active);
|
||||
debugger
|
||||
// Get currently selected item (corresponding to curIndex)
|
||||
const selectedItem = curIndex !== -1 ? spaceList[curIndex] : null;
|
||||
if(!selectedItem){
|
||||
messageApi.error(t('knowledgeBase.selectSpace'));
|
||||
return;
|
||||
}
|
||||
const payload = {
|
||||
source_kb_id: kbId ?? '',
|
||||
target_workspace_id: selectedItem?.id ?? '',
|
||||
}
|
||||
const respose = await shareKnowledgeBase(payload)
|
||||
if(respose){
|
||||
messageApi.success(t('knowledgeBase.shareSuccess'));
|
||||
}else{
|
||||
messageApi.error(t('knowledgeBase.shareFailed'));
|
||||
}
|
||||
// Call parent component's callback function with selected data
|
||||
onShare?.({
|
||||
checkedItems,
|
||||
selectedItem
|
||||
});
|
||||
|
||||
// Close modal after sharing
|
||||
handleClose();
|
||||
setLoading(true);
|
||||
try {
|
||||
const payload = {
|
||||
source_kb_id: kbId ?? '',
|
||||
target_workspace_id: selectedItem?.id ?? '',
|
||||
}
|
||||
const respose = await shareKnowledgeBase(payload)
|
||||
if(respose){
|
||||
messageApi.success(t('knowledgeBase.shareSuccess'));
|
||||
}else{
|
||||
messageApi.error(t('knowledgeBase.shareFailed'));
|
||||
}
|
||||
// Call parent component's callback function with selected data
|
||||
onShare?.({
|
||||
checkedItems,
|
||||
selectedItem
|
||||
});
|
||||
|
||||
// Close modal after sharing
|
||||
handleClose();
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}
|
||||
const handleClick = (index: number, checked: boolean) => {
|
||||
if (!checked) return;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:49:28
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-04 11:31:43
|
||||
* @Last Modified time: 2026-03-11 15:08:24
|
||||
*/
|
||||
/**
|
||||
* Custom Model Modal
|
||||
@@ -11,7 +11,7 @@
|
||||
*/
|
||||
|
||||
import { forwardRef, useEffect, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, App, Checkbox } from 'antd';
|
||||
import { Form, Input, App, Checkbox, Button } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { CustomModelForm, ModelListItem, CustomModelModalRef, CustomModelModalProps } from '../types';
|
||||
@@ -35,6 +35,7 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
const [isEdit, setIsEdit] = useState(false);
|
||||
const [form] = Form.useForm<CustomModelForm>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [abortController, setAbortController] = useState<AbortController | null>(null)
|
||||
const modelType = Form.useWatch(['type'], form);
|
||||
const isOmni = Form.useWatch(['is_omni'], form);
|
||||
|
||||
@@ -46,6 +47,8 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
|
||||
/** Close modal and reset state */
|
||||
const handleClose = () => {
|
||||
abortController?.abort()
|
||||
setAbortController(null)
|
||||
setModel({} as ModelListItem);
|
||||
form.resetFields();
|
||||
setLoading(false)
|
||||
@@ -73,8 +76,10 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
/** Update or create custom model */
|
||||
const handleUpdate = (data: CustomModelForm) => {
|
||||
setLoading(true)
|
||||
const controller = new AbortController()
|
||||
setAbortController(controller)
|
||||
const { type, provider, ...rest} = data
|
||||
const res = isEdit ? updateCustomModel(model.id, rest) : addCustomModel(data)
|
||||
const res = isEdit ? updateCustomModel(model.id, rest, controller.signal) : addCustomModel(data, controller.signal)
|
||||
|
||||
res.then(() => {
|
||||
refresh?.(isEdit)
|
||||
@@ -124,15 +129,15 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
}));
|
||||
console.log('modelType', modelType)
|
||||
return (
|
||||
<RbModal
|
||||
title={isEdit ? `${model.name} - ${t('modelNew.modelConfiguration')}` : t('modelNew.createCustomModel')}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t(`common.${isEdit ? 'save' : 'create'}`)}
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
footer={[
|
||||
<Button key="cancel" onClick={handleClose}>{t('common.cancel')}</Button>,
|
||||
<Button key="confirm" type="primary" loading={loading} onClick={handleSave}>{t(`common.${isEdit ? 'save' : 'create'}`)}</Button>,
|
||||
]}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:49:40
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 16:49:40
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-11 15:12:17
|
||||
*/
|
||||
/**
|
||||
* Key Configuration Modal
|
||||
@@ -11,7 +11,7 @@
|
||||
*/
|
||||
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, App } from 'antd';
|
||||
import { Form, Input, App, Button } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { KeyConfigModalForm, ProviderModelItem, KeyConfigModalRef, KeyConfigModalProps } from '../types';
|
||||
@@ -30,9 +30,12 @@ const KeyConfigModal = forwardRef<KeyConfigModalRef, KeyConfigModalProps>(({
|
||||
const [model, setModel] = useState<ProviderModelItem>({} as ProviderModelItem);
|
||||
const [form] = Form.useForm<KeyConfigModalForm>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [abortController, setAbortController] = useState<AbortController | null>(null)
|
||||
|
||||
/** Close modal and reset state */
|
||||
const handleClose = () => {
|
||||
abortController?.abort()
|
||||
setAbortController(null)
|
||||
setModel({} as ProviderModelItem);
|
||||
form.resetFields();
|
||||
setLoading(false)
|
||||
@@ -51,10 +54,13 @@ const KeyConfigModal = forwardRef<KeyConfigModalRef, KeyConfigModalProps>(({
|
||||
.then((values) => {
|
||||
setLoading(true)
|
||||
|
||||
const controller = new AbortController()
|
||||
setAbortController(controller)
|
||||
|
||||
updateProviderApiKeys({
|
||||
...values,
|
||||
provider: model.provider
|
||||
}).then((res) => {
|
||||
}, controller.signal).then((res) => {
|
||||
if (refresh) {
|
||||
refresh();
|
||||
}
|
||||
@@ -81,9 +87,10 @@ const KeyConfigModal = forwardRef<KeyConfigModalRef, KeyConfigModalProps>(({
|
||||
title={`${model.provider} - ${t('modelNew.keyConfig')}`}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t(`common.save`)}
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
footer={[
|
||||
<Button key="cancel" onClick={handleClose}>{t('common.cancel')}</Button>,
|
||||
<Button key="confirm" type="primary" loading={loading} onClick={handleSave}>{t(`common.save`)}</Button>,
|
||||
]}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:49:55
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 16:49:55
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-11 15:11:06
|
||||
*/
|
||||
/**
|
||||
* Multi-Key Configuration Modal
|
||||
@@ -28,9 +28,12 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
|
||||
const [model, setModel] = useState<ModelListItem>({} as ModelListItem);
|
||||
const [form] = Form.useForm<MultiKeyForm>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [abortController, setAbortController] = useState<AbortController | null>(null)
|
||||
|
||||
/** Close modal and refresh parent */
|
||||
const handleClose = () => {
|
||||
abortController?.abort()
|
||||
setAbortController(null)
|
||||
setModel({} as ModelListItem);
|
||||
refresh?.()
|
||||
|
||||
@@ -60,12 +63,14 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
|
||||
.validateFields()
|
||||
.then((values) => {
|
||||
setLoading(true)
|
||||
const controller = new AbortController()
|
||||
setAbortController(controller)
|
||||
addModelApiKey(model.id, {
|
||||
...values,
|
||||
model_config_id: model.id,
|
||||
model_name: model.name,
|
||||
provider: model.provider,
|
||||
}).then(() => {
|
||||
}, controller.signal).then(() => {
|
||||
message.success(t('common.saveSuccess'))
|
||||
form.resetFields();
|
||||
getData(model)
|
||||
@@ -98,7 +103,6 @@ const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigMod
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
footer={null}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
{model.api_keys && model.api_keys.length > 0 && (
|
||||
<div className="rb:mb-4">
|
||||
|
||||
@@ -1,131 +1,342 @@
|
||||
import React, { useState, useRef, type ReactNode } from 'react';
|
||||
import { Input, Button, Spin, App } from 'antd';
|
||||
import React, { useState, useRef, useEffect, useCallback, type ReactNode } from 'react';
|
||||
import { Input, Button, App, Card, Space, Skeleton, Tag } from 'antd';
|
||||
import { SearchOutlined, SettingOutlined, GlobalOutlined, SyncOutlined } from '@ant-design/icons';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import InfiniteScroll from 'react-infinite-scroll-component';
|
||||
import MarketConfigModal, { type MarketConfigModalRef } from './components/MarketConfigModal';
|
||||
|
||||
import McpServiceModal from './components/McpServiceModal';
|
||||
import type { McpServiceModalRef } from './types';
|
||||
import pageEmptyIcon from '@/assets/images/empty/pageEmpty.png'
|
||||
import Empty from '@/components/Empty/index'
|
||||
import { getMarketTools, getMarketConfig, getMarketMCPs, getMarketMCPDetail, getMarketMCPsActivated, getTools } from '@/api/tools';
|
||||
interface MarketSource {
|
||||
id: string;
|
||||
name: string;
|
||||
category: string;
|
||||
icon: string;
|
||||
logo_url: string;
|
||||
url: string;
|
||||
desc: string;
|
||||
apiKey: string;
|
||||
description: string;
|
||||
api_key?: string;
|
||||
connected: boolean;
|
||||
mcpCount: number;
|
||||
mcp_count: number;
|
||||
created_at?: number;
|
||||
created_by?: string;
|
||||
}
|
||||
|
||||
interface MarketMcp {
|
||||
id: string;
|
||||
name: string;
|
||||
provider: string;
|
||||
type: string;
|
||||
desc: string;
|
||||
downloads?: string;
|
||||
stars?: string;
|
||||
icon: string;
|
||||
configTemplate: any;
|
||||
chinese_name?: string;
|
||||
description: string;
|
||||
logo_url: string;
|
||||
publisher: string;
|
||||
categories?: string[];
|
||||
tags?: string[];
|
||||
view_count?: number;
|
||||
activated?: boolean;
|
||||
inDatabase?: boolean;
|
||||
locales?: {
|
||||
[lang: string]: {
|
||||
name: string;
|
||||
description: string;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
interface MarketCategory {
|
||||
id: string;
|
||||
name: string;
|
||||
icon: string;
|
||||
}
|
||||
|
||||
interface MarketApiResponse {
|
||||
items: MarketSource[];
|
||||
}
|
||||
|
||||
const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () => {
|
||||
const { t } = useTranslation();
|
||||
const { t, i18n } = useTranslation();
|
||||
const { message } = App.useApp();
|
||||
|
||||
const getLocaleField = (mcp: MarketMcp, field: 'name' | 'description') => {
|
||||
const lang = i18n.language?.startsWith('zh') ? 'zh' : 'en';
|
||||
return mcp.locales?.[lang]?.[field] || mcp[field] || '';
|
||||
};
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [selectedSource, setSelectedSource] = useState<string | null>(null);
|
||||
const marketConfigModalRef = useRef<MarketConfigModalRef>(null);
|
||||
const [marketSources, setMarketSources] = useState<MarketSource[]>([
|
||||
{ id: 'smithery', name: 'Smithery', category: 'official', icon: '🔧', url: 'https://mcp.smithery.ai', desc: '官方 MCP 服务市场,提供丰富的 MCP 服务', apiKey: '', connected: false, mcpCount: 2847 },
|
||||
{ id: 'mcpmarket', name: 'MCP Market', category: 'official', icon: '🏪', url: 'https://mcpmarket.com', desc: '综合性 MCP 市场平台', apiKey: '', connected: false, mcpCount: 1523 },
|
||||
{ id: 'glama', name: 'Glama.ai MCP', category: 'official', icon: '✨', url: 'https://glama.ai/mcp', desc: 'Glama AI 提供的 MCP 服务集合', apiKey: '', connected: false, mcpCount: 892 },
|
||||
{ id: 'github-mcp', name: 'modelcontextprotocol/servers', category: 'official', icon: '🐙', url: 'https://github.com/modelcontextprotocol/servers', desc: 'GitHub 官方 MCP 服务器仓库', apiKey: '', connected: true, mcpCount: 156 },
|
||||
{ id: 'aliyun-bailian', name: '阿里云百炼 MCP', category: 'china-cloud', icon: '☁️', url: 'https://bailian.console.aliyun.com/mcp', desc: '阿里云百炼平台 MCP 市场', apiKey: '', connected: false, mcpCount: 423 },
|
||||
{ id: 'modelscope', name: '魔搭社区 MCP', category: 'china-cloud', icon: '🎭', url: 'https://modelscope.cn/mcp', desc: '阿里达摩院魔搭社区 MCP 市场', apiKey: '', connected: false, mcpCount: 312 },
|
||||
]);
|
||||
|
||||
const [categories] = useState<MarketCategory[]>([
|
||||
{ id: 'official', name: '官方/综合', icon: '🌐' },
|
||||
{ id: 'china-cloud', name: '国内云', icon: '☁️' },
|
||||
{ id: 'community', name: '社区/垂直', icon: '👥' }
|
||||
]);
|
||||
|
||||
const [mcpCache, setMcpCache] = useState<Record<string, MarketMcp[]>>({
|
||||
'github-mcp': [
|
||||
{ id: 'gh-1', name: 'Fetch', provider: 'modelcontextprotocol', type: 'Hosted', desc: '使用浏览器模拟大型语言模型检索和处理网页内容', downloads: '203.7m', stars: '308.2k', icon: '🌐', configTemplate: {} },
|
||||
{ id: 'gh-2', name: 'Filesystem', provider: 'modelcontextprotocol', type: 'Local', desc: '安全的文件系统操作,支持读写文件和目录管理', downloads: '156.2m', stars: '245.1k', icon: '📁', configTemplate: {} },
|
||||
{ id: 'gh-3', name: 'GitHub', provider: 'modelcontextprotocol', type: 'Hosted', desc: 'GitHub API 集成,支持仓库、Issue、PR 等操作', downloads: '89.4m', stars: '178.3k', icon: '🐙', configTemplate: {} },
|
||||
]
|
||||
});
|
||||
|
||||
const mcpServiceModalRef = useRef<McpServiceModalRef>(null);
|
||||
const [marketSources, setMarketSources] = useState<MarketSource[]>([]);
|
||||
const [categories, setCategories] = useState<MarketCategory[]>([]);
|
||||
const [mcpCache, setMcpCache] = useState<Record<string, MarketMcp[]>>({});
|
||||
const [mcpTotal, setMcpTotal] = useState(0);
|
||||
const [searchKeyword, setSearchKeyword] = useState('');
|
||||
const [configIdMap, setConfigIdMap] = useState<Record<string, string>>({});
|
||||
const [hasMore, setHasMore] = useState(false);
|
||||
const [activatedMcps, setActivatedMcps] = useState<string[]>([]);
|
||||
const [currentPage, setCurrentPage] = useState(1);
|
||||
const pageSize = 20;
|
||||
const searchTimerRef = useRef<number | null>(null);
|
||||
|
||||
const handleSelectSource = (sourceId: string) => {
|
||||
setSelectedSource(sourceId);
|
||||
};
|
||||
|
||||
const handleRefresh = (sourceId: string) => {
|
||||
setLoading(true);
|
||||
setTimeout(() => {
|
||||
// 模拟刷新数据
|
||||
const source = marketSources.find(s => s.id === sourceId);
|
||||
if (source) {
|
||||
message.success(`${source.name} 列表已刷新`);
|
||||
// 获取市场数据
|
||||
useEffect(() => {
|
||||
const fetchMarketData = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const response = await getMarketTools({}) as MarketApiResponse;
|
||||
if (response?.items && Array.isArray(response.items)) {
|
||||
setMarketSources(response.items);
|
||||
|
||||
// 根据 category 字段分组
|
||||
const categoryMap = new Map<string, MarketCategory>();
|
||||
response.items.forEach(item => {
|
||||
if (item.category && !categoryMap.has(item.category)) {
|
||||
categoryMap.set(item.category, {
|
||||
id: item.category,
|
||||
name: item.category
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
setCategories(Array.from(categoryMap.values()));
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('获取市场数据失败:', error);
|
||||
message.error('获取市场数据失败');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
fetchMarketData();
|
||||
}, [message]);
|
||||
|
||||
const fetchMcpList = async (sourceId: string, page = 1, append = false, keywords = '') => {
|
||||
setLoading(true);
|
||||
try {
|
||||
let configId = configIdMap[sourceId];
|
||||
|
||||
// 如果没有缓存 configId,先获取配置
|
||||
if (!configId) {
|
||||
const config: any = await getMarketConfig(sourceId);
|
||||
if (config?.id) {
|
||||
configId = config.id;
|
||||
setConfigIdMap(prev => ({ ...prev, [sourceId]: configId }));
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// 第一次加载时获取已激活列表
|
||||
let activatedIds: string[] = activatedMcps;
|
||||
if (page === 1 && !append) {
|
||||
const activatedRes: any = await getMarketMCPsActivated({ mcp_market_config_id: configId });
|
||||
if (activatedRes && Array.isArray(activatedRes)) {
|
||||
activatedIds = activatedRes.map((item: any) => item.id);
|
||||
setActivatedMcps(activatedIds);
|
||||
}
|
||||
}
|
||||
|
||||
// 获取全量工具列表,用于标记已入库的 MCP
|
||||
const allTools: any = await getTools({ tool_type: 'mcp' });
|
||||
const toolsList = Array.isArray(allTools) ? allTools : [];
|
||||
|
||||
const res: any = await getMarketMCPs({
|
||||
mcp_market_config_id: configId,
|
||||
page,
|
||||
pagesize: pageSize,
|
||||
...(keywords ? { keywords } : {})
|
||||
});
|
||||
if (res?.items && Array.isArray(res.items)) {
|
||||
// 标记已激活和已入库的 MCP
|
||||
const mcpsWithActivated = res.items.map((item: MarketMcp) => {
|
||||
// 检查是否已入库:market_id = sourceId, market_config_id = configId, mcp_service_id = item.id
|
||||
const isInDatabase = toolsList.some((tool: any) =>
|
||||
tool.config_data?.market_id === sourceId &&
|
||||
tool.config_data?.market_config_id === configId &&
|
||||
tool.config_data?.mcp_service_id === item.id
|
||||
);
|
||||
|
||||
return {
|
||||
...item,
|
||||
activated: activatedIds.includes(item.id),
|
||||
inDatabase: isInDatabase
|
||||
};
|
||||
});
|
||||
|
||||
setMcpCache(prev => ({
|
||||
...prev,
|
||||
[sourceId]: append ? [...(prev[sourceId] || []), ...mcpsWithActivated] : mcpsWithActivated
|
||||
}));
|
||||
}
|
||||
if (res?.page) {
|
||||
setMcpTotal(res.page.total || 0);
|
||||
setHasMore(!!res.page.has_next);
|
||||
setCurrentPage(res.page.page || page);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('获取 MCP 列表失败:', error);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}, 600);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOpenConfig = (sourceId: string) => {
|
||||
const loadMore = useCallback(() => {
|
||||
if (!selectedSource || loading) return;
|
||||
fetchMcpList(selectedSource, currentPage + 1, true, searchKeyword);
|
||||
}, [selectedSource, currentPage, loading, searchKeyword]);
|
||||
|
||||
const handleSearchChange = (value: string) => {
|
||||
setSearchKeyword(value);
|
||||
|
||||
// 清除之前的定时器
|
||||
if (searchTimerRef.current) {
|
||||
clearTimeout(searchTimerRef.current);
|
||||
}
|
||||
|
||||
// 如果清空搜索框,恢复原始列表
|
||||
if (!value.trim()) {
|
||||
if (selectedSource) {
|
||||
// 清除缓存,重新加载原始列表
|
||||
setMcpCache(prev => {
|
||||
const next = { ...prev };
|
||||
delete next[selectedSource];
|
||||
return next;
|
||||
});
|
||||
setCurrentPage(1);
|
||||
fetchMcpList(selectedSource, 1, false, '');
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// 设置新的定时器,500ms 后执行搜索
|
||||
searchTimerRef.current = setTimeout(() => {
|
||||
if (selectedSource) {
|
||||
// 清除缓存,重新搜索
|
||||
setMcpCache(prev => {
|
||||
const next = { ...prev };
|
||||
delete next[selectedSource];
|
||||
return next;
|
||||
});
|
||||
setCurrentPage(1);
|
||||
fetchMcpList(selectedSource, 1, false, value);
|
||||
}
|
||||
}, 500);
|
||||
};
|
||||
|
||||
const handleSelectSource = async (sourceId: string) => {
|
||||
setSelectedSource(sourceId);
|
||||
setSearchKeyword('');
|
||||
setCurrentPage(1);
|
||||
setHasMore(false);
|
||||
setMcpTotal(0);
|
||||
|
||||
// 如果缓存中已有数据,直接使用
|
||||
if (mcpCache[sourceId]) return;
|
||||
|
||||
await fetchMcpList(sourceId, 1);
|
||||
};
|
||||
|
||||
const handleRefresh = async (sourceId: string) => {
|
||||
// 清除缓存,重新从第一页加载
|
||||
setMcpCache(prev => {
|
||||
const next = { ...prev };
|
||||
delete next[sourceId];
|
||||
return next;
|
||||
});
|
||||
setCurrentPage(1);
|
||||
await fetchMcpList(sourceId, 1);
|
||||
const source = marketSources.find(s => s.id === sourceId);
|
||||
if (source) {
|
||||
message.success(`${source.name} ${t('tool.marketRefreshSuccess')}`);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOpenConfig = async (sourceId: string) => {
|
||||
const source = marketSources.find(s => s.id === sourceId);
|
||||
if (!source) return;
|
||||
try {
|
||||
const config: any = await getMarketConfig(sourceId);
|
||||
console.log('获取到的配置数据:', config);
|
||||
marketConfigModalRef.current?.handleOpen({
|
||||
...source,
|
||||
connected: config?.status === 1,
|
||||
token: config?.token || '',
|
||||
configId: config?.id || '',
|
||||
});
|
||||
} catch {
|
||||
marketConfigModalRef.current?.handleOpen(source);
|
||||
}
|
||||
};
|
||||
|
||||
const handleConnect = (sourceId: string, apiKey: string) => {
|
||||
// 更新市场源状态
|
||||
const handleOpenMcpServiceModal = async (mcp: MarketMcp) => {
|
||||
if (!selectedSource || !configIdMap[selectedSource]) return;
|
||||
try {
|
||||
const detail: any = await getMarketMCPDetail({
|
||||
mcp_market_config_id: configIdMap[selectedSource],
|
||||
server_id: mcp.id,
|
||||
});
|
||||
const source = marketSources.find(s => s.id === selectedSource);
|
||||
const toolItem = {
|
||||
name: detail.name,
|
||||
description: detail.description,
|
||||
source_channel: source?.name || '',
|
||||
market_id: selectedSource,
|
||||
market_config_id: configIdMap[selectedSource],
|
||||
mcp_service_id: mcp.id,
|
||||
config_data: {
|
||||
server_url: detail.servers?.[0]?.url || '',
|
||||
connection_config: {
|
||||
auth_type: 'none',
|
||||
timeout: 30,
|
||||
headers: {},
|
||||
},
|
||||
},
|
||||
};
|
||||
mcpServiceModalRef.current?.handleOpen(toolItem as any);
|
||||
} catch (error) {
|
||||
console.error('获取 MCP 服务详情失败:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleConnect = async (sourceId: string, configId: string) => {
|
||||
// 更新市场源状态,缓存 configId
|
||||
setMarketSources(prev => prev.map(source => {
|
||||
if (source.id === sourceId) {
|
||||
return {
|
||||
...source,
|
||||
apiKey,
|
||||
connected: true
|
||||
};
|
||||
return { ...source, connected: true };
|
||||
}
|
||||
return source;
|
||||
}));
|
||||
setConfigIdMap(prev => ({ ...prev, [sourceId]: configId }));
|
||||
|
||||
// 模拟获取MCP列表
|
||||
setTimeout(() => {
|
||||
const source = marketSources.find(s => s.id === sourceId);
|
||||
if (source && !mcpCache[sourceId]) {
|
||||
// 生成模拟数据
|
||||
const mockData: MarketMcp[] = [
|
||||
{ id: `${sourceId}-1`, name: `${source.name} 服务 1`, provider: source.name, type: 'Hosted', desc: `来自 ${source.name} 的 MCP 服务`, downloads: '10.2m', stars: '23.4k', icon: '🔧', configTemplate: {} },
|
||||
{ id: `${sourceId}-2`, name: `${source.name} 服务 2`, provider: source.name, type: 'Local', desc: `来自 ${source.name} 的本地 MCP 服务`, downloads: '8.5m', stars: '18.7k', icon: '⚙️', configTemplate: {} }
|
||||
];
|
||||
setMcpCache(prev => ({
|
||||
...prev,
|
||||
[sourceId]: mockData
|
||||
}));
|
||||
}
|
||||
message.success(`已连接 ${source?.name}`);
|
||||
}, 800);
|
||||
// 使用 fetchMcpList 获取完整的 MCP 列表(包含激活状态和入库状态)
|
||||
await fetchMcpList(sourceId, 1);
|
||||
};
|
||||
|
||||
const handleRefreshAfterAdd = async () => {
|
||||
// 添加成功后,刷新当前选中的市场源的 MCP 列表
|
||||
if (!selectedSource) return;
|
||||
|
||||
// 清除缓存并重新加载,这样会重新获取工具列表并更新 inDatabase 标记
|
||||
setMcpCache(prev => {
|
||||
const next = { ...prev };
|
||||
delete next[selectedSource];
|
||||
return next;
|
||||
});
|
||||
setCurrentPage(1);
|
||||
await fetchMcpList(selectedSource, 1);
|
||||
};
|
||||
|
||||
const renderSourceDetail = () => {
|
||||
if (!selectedSource) {
|
||||
return (
|
||||
<div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:h-full rb:text-center">
|
||||
<div className="rb:text-6xl rb:mb-4">🏪</div>
|
||||
<h3 className="rb:text-lg rb:font-semibold rb:text-gray-900 rb:mb-2">选择一个 MCP 市场</h3>
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:max-w-md">从左侧选择一个市场源,配置连接后即可浏览该市场的 MCP 服务</p>
|
||||
<Empty
|
||||
url={pageEmptyIcon}
|
||||
title={t('tool.marketSelectTitle')}
|
||||
subTitle={t('tool.marketSelectDesc')}
|
||||
size={200}
|
||||
className="rb:h-full"
|
||||
/>
|
||||
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -134,170 +345,228 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
|
||||
if (!source) return null;
|
||||
|
||||
const mcpList = mcpCache[selectedSource] || [];
|
||||
const filteredList = mcpList.filter(mcp =>
|
||||
mcp.name.toLowerCase().includes(searchKeyword.toLowerCase()) ||
|
||||
mcp.desc.toLowerCase().includes(searchKeyword.toLowerCase())
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="rb:flex rb:justify-between rb:items-start rb:pb-6 rb:border-b rb:border-gray-200 rb:mb-6">
|
||||
<div className="rb:flex rb:gap-4">
|
||||
<div className="rb:text-5xl rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-xl rb:flex-shrink-0">
|
||||
{source.icon}
|
||||
<div className="rb:flex rb:justify-between rb:items-center rb:pb-0">
|
||||
<div className="rb:flex rb:items-center rb:gap-4">
|
||||
<div className="rb:w-10 rb:h-10 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-xl rb:flex-shrink-0 rb:overflow-hidden">
|
||||
{source.logo_url ? (
|
||||
<img
|
||||
src={source.logo_url}
|
||||
alt={source.name}
|
||||
className="rb:w-full rb:h-full rb:object-cover"
|
||||
referrerPolicy="no-referrer"
|
||||
onError={(e) => {
|
||||
e.currentTarget.style.display = 'none';
|
||||
const parent = e.currentTarget.parentElement;
|
||||
if (parent) {
|
||||
parent.innerHTML = '🏪';
|
||||
parent.style.fontSize = '48px';
|
||||
}
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
<span className="rb:text-5xl">🏪</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="rb:flex-1">
|
||||
<h2 className="rb:text-xl rb:font-semibold rb:text-gray-900 rb:mb-2">{source.name}</h2>
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{source.desc}</p>
|
||||
<div className="rb:flex rb:items-center rb:flex-1">
|
||||
<h2 className="rb:text-xl rb:font-semibold rb:text-gray-900 rb:mb-2 rb:mr-2">{source.name}</h2>
|
||||
可用 MCP 服务 <span className="rb:text-gray-600 rb:font-normal">({mcpTotal})</span>
|
||||
{/* <p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{source.description}</p> */}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="rb:flex rb:gap-3">
|
||||
<div className="rb:flex rb:gap-3 rb:items-center">
|
||||
{source.connected && (
|
||||
<Button size="small" icon={<SyncOutlined />} onClick={() => handleRefresh(selectedSource)}>
|
||||
{t('tool.marketRefresh')}
|
||||
</Button>
|
||||
)}
|
||||
|
||||
<Input
|
||||
prefix={<SearchOutlined />}
|
||||
placeholder={t('tool.marketSearchPlaceholder')}
|
||||
value={searchKeyword}
|
||||
onChange={(e) => handleSearchChange(e.target.value)}
|
||||
allowClear
|
||||
style={{ width: 200 }}
|
||||
|
||||
/>
|
||||
|
||||
</div>
|
||||
<Button icon={<SettingOutlined />} onClick={() => handleOpenConfig(selectedSource)}>
|
||||
配置
|
||||
{t('tool.marketConfigBtn')}
|
||||
</Button>
|
||||
<Button type="primary" icon={<GlobalOutlined />} onClick={() => window.open(source.url, '_blank')}>
|
||||
前往市场
|
||||
{t('tool.marketVisit')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="rb:mt-6">
|
||||
<div className="rb:flex rb:justify-between rb:items-center rb:mb-5">
|
||||
<h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:m-0">
|
||||
可用 MCP 服务 <span className="rb:text-gray-600 rb:font-normal">({mcpList.length})</span>
|
||||
</h3>
|
||||
<div className="rb:flex rb:gap-3 rb:items-center">
|
||||
{source.connected && (
|
||||
<Button size="small" icon={<SyncOutlined />} onClick={() => handleRefresh(selectedSource)}>
|
||||
刷新
|
||||
</Button>
|
||||
)}
|
||||
{mcpList.length > 0 && (
|
||||
<Input
|
||||
prefix={<SearchOutlined />}
|
||||
placeholder="搜索服务..."
|
||||
value={searchKeyword}
|
||||
onChange={(e) => setSearchKeyword(e.target.value)}
|
||||
style={{ width: 200 }}
|
||||
<div id="mcpScrollableDiv" className="rb:overflow-y-auto rb:h-[calc(100vh-260px)]">
|
||||
{!loading && mcpList.length === 0 ? (
|
||||
<Empty
|
||||
url={pageEmptyIcon}
|
||||
title={searchKeyword ? t('tool.marketNoSearchResult') : t('tool.marketNoData')}
|
||||
subTitle={searchKeyword ? t('tool.marketNoSearchResultDesc') : t('tool.marketNoDataDesc')}
|
||||
size={200}
|
||||
className="rb:h-full"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{mcpList.length > 0 ? (
|
||||
<Spin spinning={loading}>
|
||||
<div className="rb:grid rb:grid-cols-1 md:rb:grid-cols-2 lg:rb:grid-cols-3 rb:gap-4">
|
||||
{filteredList.map(mcp => (
|
||||
) : (
|
||||
<InfiniteScroll
|
||||
dataLength={mcpList.length}
|
||||
next={loadMore}
|
||||
hasMore={hasMore}
|
||||
loader={null}
|
||||
scrollableTarget="mcpScrollableDiv"
|
||||
>
|
||||
<div
|
||||
className="rb:gap-4"
|
||||
style={{
|
||||
display: 'grid',
|
||||
gridTemplateColumns: 'repeat(auto-fill, minmax(300px, 1fr))'
|
||||
}}
|
||||
>
|
||||
{mcpList.map(mcp => (
|
||||
<div
|
||||
key={mcp.id}
|
||||
className="rb:bg-white rb:border rb:border-gray-200 rb:rounded-lg rb:p-4 rb:transition-all rb:duration-200 hover:rb:shadow-lg hover:rb:border-gray-300"
|
||||
className="rb:bg-white rb:border rb:border-gray-200 rb:rounded-lg rb:p-4 rb:pb-2 rb:transition-all rb:duration-200 hover:rb:shadow-lg hover:rb:border-gray-300"
|
||||
>
|
||||
<div className="rb:flex rb:justify-between rb:items-center rb:mb-3">
|
||||
<div className="rb:text-3xl rb:w-12 rb:h-12 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-lg">
|
||||
{mcp.icon}
|
||||
<div className="rb:w-12 rb:h-12 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-lg rb:overflow-hidden">
|
||||
{mcp.logo_url ? (
|
||||
<img
|
||||
src={mcp.logo_url}
|
||||
alt={getLocaleField(mcp, 'name')}
|
||||
className="rb:w-full rb:h-full rb:object-cover"
|
||||
referrerPolicy="no-referrer"
|
||||
onError={(e) => {
|
||||
e.currentTarget.style.display = 'none';
|
||||
const parent = e.currentTarget.parentElement;
|
||||
if (parent) {
|
||||
parent.innerHTML = '🔧';
|
||||
parent.style.fontSize = '24px';
|
||||
}
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
<span className="rb:text-3xl">🔧</span>
|
||||
)}
|
||||
</div>
|
||||
<span className={`rb:px-2 rb:py-1 rb:rounded rb:text-xs rb:font-medium ${
|
||||
mcp.type === 'Hosted'
|
||||
? 'rb:bg-blue-50 rb:text-blue-700'
|
||||
: 'rb:bg-gray-100 rb:text-gray-600'
|
||||
}`}>
|
||||
{mcp.type}
|
||||
</span>
|
||||
{mcp.categories?.[0] && (
|
||||
<span className="rb:px-2 rb:py-1 rb:rounded rb:text-xs rb:font-medium rb:bg-blue-50 rb:text-blue-700">
|
||||
{mcp.categories[0]}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-1">{mcp.name}</h3>
|
||||
{mcp.provider && (
|
||||
<h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-1">{getLocaleField(mcp, 'name')}</h3>
|
||||
{mcp.publisher && (
|
||||
<div className="rb:mb-2">
|
||||
<span className="rb:text-xs rb:text-gray-500">@ {mcp.provider}</span>
|
||||
<span className="rb:text-xs rb:text-gray-500">{mcp.publisher.startsWith('@') ? mcp.publisher : `@${mcp.publisher}`}</span>
|
||||
</div>
|
||||
)}
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed rb:mb-3 rb:min-h-[42px]">{mcp.desc}</p>
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:line-clamp-2 rb:mb-3 rb:min-h-10">{getLocaleField(mcp, 'description')}</p>
|
||||
<div className="rb:flex rb:gap-4 rb:mb-3 rb:pt-3 rb:border-t rb:border-gray-100">
|
||||
{mcp.downloads && (
|
||||
{mcp.view_count != null && (
|
||||
<span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
|
||||
<GlobalOutlined /> {mcp.downloads}
|
||||
</span>
|
||||
)}
|
||||
{mcp.stars && (
|
||||
<span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
|
||||
⭐ {mcp.stars}
|
||||
<GlobalOutlined /> {mcp.view_count.toLocaleString()}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="rb:flex rb:justify-end">
|
||||
<Button type="primary" size="small">
|
||||
+ 添加
|
||||
<div className={`rb:flex rb:items-center ${mcp.activated || mcp.inDatabase ? 'rb:justify-between' : 'rb:justify-end'}`}>
|
||||
<div className="rb:flex rb:gap-2">
|
||||
{mcp.activated && <Tag color="success">{t('tool.marketActivated')}</Tag>}
|
||||
{mcp.inDatabase && <Tag color="blue">{t('tool.marketInDatabase')}</Tag>}
|
||||
</div>
|
||||
<Button disabled={mcp.inDatabase} type="primary" size="small" onClick={() => handleOpenMcpServiceModal(mcp)}>
|
||||
+ {t('tool.marketAdd')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</Spin>
|
||||
) : (
|
||||
<div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:py-16 rb:text-center">
|
||||
<div className="rb:text-6xl rb:mb-4">{source.connected ? '📭' : '🔌'}</div>
|
||||
<h4 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-2">
|
||||
{source.connected ? '暂无可用的 MCP 服务' : '尚未连接此市场'}
|
||||
</h4>
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:mb-4">
|
||||
{source.connected ? '该市场暂时没有可用的服务' : '点击右上角"配置"按钮设置连接信息'}
|
||||
</p>
|
||||
{!source.connected && (
|
||||
<Button type="primary" onClick={() => handleOpenConfig(selectedSource)}>
|
||||
配置连接
|
||||
</Button>
|
||||
</div>
|
||||
</InfiniteScroll>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="rb:flex rb:gap-4 rb:h-[calc(100vh-178px)]">
|
||||
<div className="rb:flex rb:gap-4 rb:h-[calc(100vh-138px)]">
|
||||
{/* 左侧市场源列表 */}
|
||||
<div className="rb:w-70 rb:bg-white rb:rounded-lg rb:border rb:border-gray-200 rb:overflow-y-auto rb:flex-shrink-0">
|
||||
<div className="rb:p-4 rb:border-b rb:border-gray-200">
|
||||
<span className="rb:text-base rb:font-semibold rb:text-gray-900">MCP 市场</span>
|
||||
</div>
|
||||
{categories.map(cat => (
|
||||
<div key={cat.id} className="rb:py-3 rb:border-b rb:border-gray-100 last:rb:border-b-0">
|
||||
<div className="rb:flex rb:items-center rb:gap-2 rb:px-4 rb:py-2 rb:text-xs rb:font-medium rb:text-gray-500 rb:uppercase">
|
||||
<span className="rb:text-sm">{cat.icon}</span>
|
||||
<span>{cat.name}</span>
|
||||
</div>
|
||||
<div className="rb:px-2 rb:py-1">
|
||||
{marketSources
|
||||
.filter(s => s.category === cat.id)
|
||||
.map(source => (
|
||||
<div
|
||||
key={source.id}
|
||||
className={`rb:flex rb:items-center rb:gap-2 rb:px-3 rb:py-2.5 rb:rounded-md rb:cursor-pointer rb:transition-all rb:relative ${
|
||||
selectedSource === source.id
|
||||
? 'rb:bg-blue-50 rb:text-blue-600'
|
||||
: 'hover:rb:bg-gray-50'
|
||||
}`}
|
||||
onClick={() => handleSelectSource(source.id)}
|
||||
>
|
||||
<span className="rb:text-lg rb:flex-shrink-0">{source.icon}</span>
|
||||
<span className="rb:flex-1 rb:text-sm rb:font-medium rb:overflow-hidden rb:text-ellipsis rb:whitespace-nowrap">
|
||||
{source.name}
|
||||
</span>
|
||||
<span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full">
|
||||
{source.mcpCount}
|
||||
</span>
|
||||
{source.connected && (
|
||||
<span className="rb:text-green-500 rb:text-[8px] rb:ml-1">●</span>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
<div className="rb:w-80 rb:h-full rb:overflow-y-auto">
|
||||
<Space size={12} direction="vertical" className="rb:w-full">
|
||||
{categories.map(cat => (
|
||||
<Card
|
||||
key={cat.id}
|
||||
type="inner"
|
||||
title={
|
||||
<div className="rb:flex rb:items-center rb:gap-2">
|
||||
<span>{cat.name}</span>
|
||||
</div>
|
||||
}
|
||||
classNames={{
|
||||
body: "rb:p-[10px]!",
|
||||
header: "rb:bg-[#F6F8FC]!"
|
||||
}}
|
||||
>
|
||||
<Space size={8} direction="vertical" className="rb:w-full">
|
||||
{marketSources
|
||||
.filter(s => s.category === cat.id)
|
||||
.map(source => (
|
||||
<div
|
||||
key={source.id}
|
||||
className={`rb:bg-white rb:rounded-lg rb:p-2 rb:border rb:cursor-pointer rb:flex rb:items-center rb:gap-2 rb:transition-all ${
|
||||
selectedSource === source.id
|
||||
? 'rb:border-[#155EEF] rb:shadow-[0px_2px_4px_0px_rgba(33,35,50,0.15)]'
|
||||
: 'rb:border-[#DFE4ED] rb:hover:border-[#155EEF] rb:hover:shadow-[0px_2px_4px_0px_rgba(33,35,50,0.15)]'
|
||||
}`}
|
||||
onClick={() => handleSelectSource(source.id)}
|
||||
>
|
||||
<div className="rb:w-5 rb:h-5 rb:flex-shrink-0 rb:flex rb:items-center rb:justify-center rb:overflow-hidden rb:rounded rb:bg-gray-100">
|
||||
{source.logo_url ? (
|
||||
<img
|
||||
src={source.logo_url}
|
||||
alt={source.name}
|
||||
className="rb:w-full rb:h-full rb:object-cover"
|
||||
referrerPolicy="no-referrer"
|
||||
onError={(e) => {
|
||||
e.currentTarget.style.display = 'none';
|
||||
const parent = e.currentTarget.parentElement;
|
||||
if (parent) {
|
||||
parent.innerHTML = '🏪';
|
||||
parent.style.fontSize = '16px';
|
||||
}
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
<span className="rb:text-base">🏪</span>
|
||||
)}
|
||||
</div>
|
||||
<span className="rb:flex-1 rb:font-medium rb:text-[12px] rb:overflow-hidden rb:text-ellipsis rb:whitespace-nowrap">
|
||||
{source.name}
|
||||
</span>
|
||||
{/* <span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full rb:flex-shrink-0">
|
||||
{source.mcp_count}
|
||||
</span> */}
|
||||
{source.connected && (
|
||||
<span className="rb:text-green-500 rb:text-[8px] rb:flex-shrink-0">●</span>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</Space>
|
||||
</Card>
|
||||
))}
|
||||
</Space>
|
||||
</div>
|
||||
|
||||
{/* 右侧内容区 */}
|
||||
<div className="rb:flex-1 rb:bg-white rb:rounded-lg rb:border rb:border-gray-200 rb:overflow-hidden">
|
||||
<div className="rb:flex-1 rb:border-l rb:border-gray-200 rb:overflow-hidden">
|
||||
<div className="rb:h-full rb:overflow-y-auto rb:p-6">
|
||||
{renderSourceDetail()}
|
||||
</div>
|
||||
@@ -308,6 +577,10 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
|
||||
ref={marketConfigModalRef}
|
||||
onConnect={handleConnect}
|
||||
/>
|
||||
<McpServiceModal
|
||||
ref={mcpServiceModalRef}
|
||||
refresh={handleRefreshAfterAdd}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -61,7 +61,6 @@ const Mcp: React.FC<{ getStatusTag: (status: string) => ReactNode }> = ({ getSta
|
||||
getData()
|
||||
})
|
||||
};
|
||||
|
||||
// 删除服务
|
||||
const handleDeleteService = (item: ToolItem) => {
|
||||
if (!item.id) {
|
||||
|
||||
@@ -2,6 +2,7 @@ import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, Button, App, Space } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { CopyOutlined, EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
|
||||
import { createMarketConfig,updateMarketConfig } from '@/api/tools';
|
||||
import RbModal from '@/components/RbModal';
|
||||
|
||||
const FormItem = Form.Item;
|
||||
@@ -9,15 +10,16 @@ const FormItem = Form.Item;
|
||||
interface MarketSource {
|
||||
id: string;
|
||||
name: string;
|
||||
icon: string;
|
||||
logo_url: string;
|
||||
url: string;
|
||||
desc: string;
|
||||
apiKey: string;
|
||||
description: string;
|
||||
token?: string;
|
||||
connected: boolean;
|
||||
configId?: string;
|
||||
}
|
||||
|
||||
interface MarketConfigModalProps {
|
||||
onConnect: (sourceId: string, apiKey: string) => void;
|
||||
onConnect: (sourceId: string, configId: string) => void;
|
||||
}
|
||||
|
||||
export interface MarketConfigModalRef {
|
||||
@@ -35,6 +37,8 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [currentSource, setCurrentSource] = useState<MarketSource | null>(null);
|
||||
const [showApiKey, setShowApiKey] = useState(false);
|
||||
const [initialValues, setInitialValues] = useState<{ token: string }>({ token: '' });
|
||||
const formValues = Form.useWatch([], form);
|
||||
|
||||
const handleClose = () => {
|
||||
setVisible(false);
|
||||
@@ -42,32 +46,62 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
|
||||
setLoading(false);
|
||||
setCurrentSource(null);
|
||||
setShowApiKey(false);
|
||||
setInitialValues({ token: '' });
|
||||
};
|
||||
|
||||
const handleOpen = (source: MarketSource) => {
|
||||
console.log('Modal 接收到的数据:', source);
|
||||
setCurrentSource(source);
|
||||
form.setFieldsValue({
|
||||
url: source.url,
|
||||
apiKey: source.apiKey,
|
||||
});
|
||||
setInitialValues({ token: source.token || '' });
|
||||
setVisible(true);
|
||||
};
|
||||
|
||||
const handleAfterOpenChange = (open: boolean) => {
|
||||
if (open && currentSource) {
|
||||
// Modal 完全打开后再设置表单值,使用 setTimeout 确保在下一个事件循环
|
||||
setTimeout(() => {
|
||||
form.setFieldsValue({
|
||||
token: currentSource.token || '',
|
||||
});
|
||||
console.log('Modal 打开后设置表单值:', { token: currentSource.token || '' });
|
||||
console.log('当前表单所有值:', form.getFieldsValue());
|
||||
}, 100);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSave = () => {
|
||||
form
|
||||
.validateFields()
|
||||
.then((values) => {
|
||||
.then(async (values) => {
|
||||
if (!currentSource) return;
|
||||
|
||||
setLoading(true);
|
||||
|
||||
// 模拟连接延迟
|
||||
setTimeout(() => {
|
||||
onConnect(currentSource.id, values.apiKey || '');
|
||||
message.success(`正在连接 ${currentSource.name}...`);
|
||||
setLoading(false);
|
||||
try {
|
||||
let res: any;
|
||||
if (currentSource.configId) {
|
||||
// 更新配置
|
||||
res = await updateMarketConfig({
|
||||
mcp_market_config_id: currentSource.configId,
|
||||
token: values.token || '',
|
||||
status: 1,
|
||||
});
|
||||
message.success(t('tool.marketConfigUpdated', { name: currentSource.name }));
|
||||
} else {
|
||||
// 创建配置
|
||||
res = await createMarketConfig({
|
||||
mcp_market_id: currentSource.id || '',
|
||||
token: values.token || '',
|
||||
status: 1,
|
||||
});
|
||||
message.success(t('tool.marketConnecting', { name: currentSource.name }));
|
||||
}
|
||||
onConnect(currentSource.id, res.id || currentSource.configId);
|
||||
handleClose();
|
||||
}, 500);
|
||||
} catch (error) {
|
||||
console.error('保存配置失败:', error);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log('表单验证失败:', err);
|
||||
@@ -82,6 +116,9 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
|
||||
}
|
||||
};
|
||||
|
||||
// 检查是否可以保存:token 字段必须有值
|
||||
const canSave = formValues?.token?.trim().length > 0;
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
handleClose
|
||||
@@ -91,77 +128,97 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
|
||||
|
||||
return (
|
||||
<RbModal
|
||||
title={`配置 ${currentSource.name}`}
|
||||
title={t('tool.marketConfig', { name: currentSource.name })}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText="保存并连接"
|
||||
afterOpenChange={handleAfterOpenChange}
|
||||
okText={t('tool.marketSaveAndConnect')}
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
okButtonProps={{ disabled: !canSave }}
|
||||
width={600}
|
||||
>
|
||||
<div>
|
||||
{/* 市场源信息头部 */}
|
||||
<div className="rb:flex rb:gap-4 rb:mb-6 rb:p-4 rb:bg-gray-50 rb:rounded-lg">
|
||||
<div className="rb:text-4xl rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-white rb:rounded-lg rb:flex-shrink-0">
|
||||
{currentSource.icon}
|
||||
<div className="rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-white rb:rounded-lg rb:flex-shrink-0 rb:overflow-hidden">
|
||||
{currentSource.logo_url ? (
|
||||
<img
|
||||
src={currentSource.logo_url}
|
||||
alt={currentSource.name}
|
||||
className="rb:w-full rb:h-full rb:object-cover"
|
||||
onError={(e) => {
|
||||
e.currentTarget.style.display = 'none';
|
||||
const parent = e.currentTarget.parentElement;
|
||||
if (parent) {
|
||||
parent.innerHTML = '🏪';
|
||||
parent.style.fontSize = '32px';
|
||||
}
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
<span className="rb:text-4xl">🏪</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="rb:flex-1">
|
||||
<h3 className="rb:text-base rb:font-semibold rb:mb-1 rb:text-gray-900">{currentSource.name}</h3>
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{currentSource.desc}</p>
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{currentSource.description}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Form
|
||||
key={currentSource?.id || 'new'}
|
||||
form={form}
|
||||
layout="vertical"
|
||||
initialValues={initialValues}
|
||||
>
|
||||
{/* 市场地址 */}
|
||||
<FormItem
|
||||
name="url"
|
||||
label="市场地址"
|
||||
>
|
||||
<FormItem label={t('tool.marketUrl')}>
|
||||
<Space.Compact style={{ width: '100%' }}>
|
||||
<Input
|
||||
readOnly
|
||||
placeholder="市场地址"
|
||||
value={currentSource.url}
|
||||
/>
|
||||
<Button
|
||||
icon={<CopyOutlined />}
|
||||
onClick={handleCopyUrl}
|
||||
>
|
||||
复制
|
||||
{t('tool.marketCopy')}
|
||||
</Button>
|
||||
</Space.Compact>
|
||||
</FormItem>
|
||||
|
||||
{/* API Key */}
|
||||
<FormItem
|
||||
name="apiKey"
|
||||
name="token"
|
||||
label={
|
||||
<span>
|
||||
API Key <span className="rb:text-gray-400 rb:font-normal">(可选)</span>
|
||||
API Key
|
||||
</span>
|
||||
}
|
||||
extra="部分市场需要 API Key 才能获取完整的服务列表"
|
||||
rules={[
|
||||
{ required: true, message: t('tool.marketApiKeyRequired') },
|
||||
{ whitespace: true, message: t('tool.marketApiKeyRequired') }
|
||||
]}
|
||||
extra={<span style={{ display: 'inline-block', marginTop: 8 }}>{t('tool.marketApiKeyExtra')}</span>}
|
||||
>
|
||||
<Space.Compact style={{ width: '100%' }}>
|
||||
<Input
|
||||
type={showApiKey ? 'text' : 'password'}
|
||||
placeholder="输入 API Key 以获取更多服务"
|
||||
autoComplete="off"
|
||||
/>
|
||||
<Button
|
||||
icon={showApiKey ? <EyeInvisibleOutlined /> : <EyeOutlined />}
|
||||
onClick={() => setShowApiKey(!showApiKey)}
|
||||
/>
|
||||
</Space.Compact>
|
||||
<Input
|
||||
type={showApiKey ? 'text' : 'password'}
|
||||
placeholder={t('tool.marketApiKeyPlaceholder')}
|
||||
autoComplete="off"
|
||||
suffix={
|
||||
<Button
|
||||
type="text"
|
||||
size="small"
|
||||
icon={showApiKey ? <EyeInvisibleOutlined /> : <EyeOutlined />}
|
||||
onClick={() => setShowApiKey(!showApiKey)}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</FormItem>
|
||||
|
||||
{/* 连接状态 */}
|
||||
<div className="rb:flex rb:items-center rb:gap-2 rb:p-3 rb:bg-gray-50 rb:rounded rb:text-sm">
|
||||
<span className="rb:text-gray-600">连接状态:</span>
|
||||
<span className="rb:text-gray-600">{t('tool.marketConnectionStatus')}:</span>
|
||||
<span className={`rb:font-medium ${currentSource.connected ? 'rb:text-green-600' : 'rb:text-gray-400'}`}>
|
||||
{currentSource.connected ? '● 已连接' : '○ 未连接'}
|
||||
{currentSource.connected ? t('tool.marketConnected') : t('tool.marketDisconnected')}
|
||||
</span>
|
||||
</div>
|
||||
</Form>
|
||||
|
||||
@@ -41,6 +41,7 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
|
||||
const values = Form.useWatch<MCPToolItem>([], form)
|
||||
const requestHeaderModalRef = useRef<RequestHeaderModalRef>(null)
|
||||
const [requestHeaderList, setRequestHeaderList] = useState<RequestHeader[]>([])
|
||||
const abortControllerRef = useRef<AbortController | null>(null)
|
||||
|
||||
const formatTabItems = () => {
|
||||
return tabKeys.map(key => ({
|
||||
@@ -54,6 +55,12 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
|
||||
|
||||
// 封装取消方法,添加关闭弹窗逻辑
|
||||
const handleClose = () => {
|
||||
// 如果有正在进行的请求,取消它
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort();
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
|
||||
setVisible(false);
|
||||
form.resetFields();
|
||||
setLoading(false);
|
||||
@@ -70,7 +77,7 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
|
||||
config: { ...config_data }
|
||||
})
|
||||
|
||||
if (config_data.connection_config.headers) {
|
||||
if (config_data?.connection_config?.headers) {
|
||||
console.log(Object.keys(config_data.connection_config.headers).map(key => ({
|
||||
key,
|
||||
value: config_data.connection_config.headers[key]
|
||||
@@ -81,6 +88,16 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
|
||||
})))
|
||||
}
|
||||
setEditVo(data)
|
||||
} else if (data) {
|
||||
const { config_data, name, description, icon } = data
|
||||
form.setFieldsValue({
|
||||
name, description, icon,
|
||||
...(config_data ? { config: { ...config_data } } : {})
|
||||
})
|
||||
// 如果是从 Market 组件传来的数据(包含 market_id),保存完整的 data 用于后续提交
|
||||
if ((data as any).market_id) {
|
||||
setEditVo(data)
|
||||
}
|
||||
} else {
|
||||
form.resetFields();
|
||||
}
|
||||
@@ -93,6 +110,10 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
|
||||
.validateFields()
|
||||
.then(() => {
|
||||
setLoading(true);
|
||||
|
||||
// 创建 AbortController 用于取消请求
|
||||
abortControllerRef.current = new AbortController();
|
||||
|
||||
// 创建新服务对象
|
||||
const { config, ...rest } = values
|
||||
|
||||
@@ -110,17 +131,42 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
|
||||
}
|
||||
}
|
||||
}
|
||||
const request = editVo?.id ? updateTool(editVo.id, newService) : addTool(newService)
|
||||
|
||||
// 如果是从 Market 组件传来的数据,添加市场相关字段
|
||||
if ((editVo as any)?.market_id) {
|
||||
(newService.config as any).source_channel = (editVo as any).source_channel;
|
||||
(newService.config as any).market_id = (editVo as any).market_id;
|
||||
(newService.config as any).market_config_id = (editVo as any).market_config_id;
|
||||
(newService.config as any).mcp_service_id = (editVo as any).mcp_service_id;
|
||||
}
|
||||
|
||||
const request = editVo?.id
|
||||
? updateTool(editVo.id, newService, { signal: abortControllerRef.current.signal })
|
||||
: addTool(newService, { signal: abortControllerRef.current.signal })
|
||||
request.then((res: any) => {
|
||||
// 清除 AbortController
|
||||
abortControllerRef.current = null;
|
||||
|
||||
message.success(t('common.saveSuccess'));
|
||||
testConnection(res.tool_id || editVo?.id)
|
||||
.finally(() => {
|
||||
setLoading(false);
|
||||
handleClose();
|
||||
refresh()
|
||||
})
|
||||
setLoading(false);
|
||||
handleClose();
|
||||
refresh();
|
||||
|
||||
// 在后台测试连接,不阻塞用户操作
|
||||
testConnection(res.tool_id || editVo?.id).catch((err) => {
|
||||
console.error('测试连接失败:', err);
|
||||
});
|
||||
})
|
||||
.catch(() => {
|
||||
.catch((error) => {
|
||||
// 清除 AbortController
|
||||
abortControllerRef.current = null;
|
||||
|
||||
// 如果是用户主动取消,不显示错误提示
|
||||
if (error.name === 'AbortError' || error.code === 'ERR_CANCELED') {
|
||||
console.log('请求已取消');
|
||||
} else {
|
||||
message.error(t('common.saveFailed'));
|
||||
}
|
||||
setLoading(false);
|
||||
})
|
||||
})
|
||||
@@ -150,7 +196,13 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
|
||||
onCancel={handleClose}
|
||||
okText={t('tool.saveAndTest')}
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
okButtonProps={{ loading: loading }}
|
||||
footer={(_, { OkBtn, CancelBtn }) => (
|
||||
<>
|
||||
<CancelBtn />
|
||||
<OkBtn />
|
||||
</>
|
||||
)}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* @Author: yujiangping
|
||||
* @Date: 2026-01-05 17:22:23
|
||||
* @LastEditors: yujiangping
|
||||
* @LastEditTime: 2026-03-06 15:08:38
|
||||
* @LastEditTime: 2026-03-06 15:11:31
|
||||
*/
|
||||
import React, { useState } from 'react';
|
||||
import { Tabs } from 'antd';
|
||||
@@ -16,7 +16,7 @@ import Custom from './Custom';
|
||||
import Market from './Market';
|
||||
import Tag from '@/components/Tag'
|
||||
|
||||
const tabKeys = ['mcp', 'inner', 'custom'] // , 'market'
|
||||
const tabKeys = ['mcp', 'inner', 'custom', 'market'] //
|
||||
const ToolManagement: React.FC = () => {
|
||||
const { t } = useTranslation();
|
||||
const [activeTab, setActiveTab] = useState('mcp');
|
||||
@@ -54,7 +54,7 @@ const ToolManagement: React.FC = () => {
|
||||
{activeTab === 'mcp' && <Mcp getStatusTag={getStatusTag} />}
|
||||
{activeTab === 'inner' && <Inner getStatusTag={getStatusTag} />}
|
||||
{activeTab === 'custom' && <Custom getStatusTag={getStatusTag} />}
|
||||
{/* {activeTab === 'market' && <Market getStatusTag={getStatusTag} />} */}
|
||||
{activeTab === 'market' && <Market getStatusTag={getStatusTag} />}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -75,6 +75,10 @@ export interface ToolItem {
|
||||
tool_class: string;
|
||||
|
||||
schema_content: string;
|
||||
source_channel?: string;
|
||||
market_id?: string;
|
||||
market_config_id?: string;
|
||||
mcp_service_id?: string;
|
||||
};
|
||||
status: 'available' | 'unavailable';
|
||||
tags: string[];
|
||||
@@ -136,4 +140,11 @@ export interface ExecuteData {
|
||||
export interface CustomToolModalRef {
|
||||
handleOpen: (data?: ToolItem) => void;
|
||||
handleClose: () => void;
|
||||
}
|
||||
|
||||
export interface MarketQuery {
|
||||
mcp_market_config_id?: string;
|
||||
page?: number;
|
||||
pagesize?: number;
|
||||
keywords?: string;
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 17:57:11
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 17:57:11
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-12 18:00:11
|
||||
*/
|
||||
/**
|
||||
* RAG User Memory Detail View
|
||||
@@ -13,7 +13,8 @@
|
||||
import { type FC, useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import clsx from 'clsx'
|
||||
import { Row, Col, Skeleton } from 'antd'
|
||||
import { Row, Col, Skeleton, Spin, Flex, Tooltip } from 'antd'
|
||||
import { LoadingOutlined } from '@ant-design/icons';
|
||||
import { useParams } from 'react-router-dom'
|
||||
|
||||
import aboutUs from '@/assets/images/userMemory/aboutUs.svg'
|
||||
@@ -26,6 +27,7 @@ import {
|
||||
getUserProfile,
|
||||
getTotalRagMemoryCountByUser,
|
||||
getChunkInsight,
|
||||
generateRagProfile
|
||||
} from '@/api/memory'
|
||||
import Empty from '@/components/Empty'
|
||||
import ConversationMemory from './components/ConversationMemory'
|
||||
@@ -133,16 +135,46 @@ const Rag: FC = () => {
|
||||
})
|
||||
}
|
||||
const name = loading.detail ? '' : data?.name && data?.name !== '' ? data.name : id
|
||||
|
||||
const [refreshLoading, setRefreshLoading] = useState(false)
|
||||
const handleRefresh = () => {
|
||||
if (refreshLoading || !id) return
|
||||
setRefreshLoading(true)
|
||||
generateRagProfile(id as string)
|
||||
.then(() => {
|
||||
getSummary()
|
||||
getInsightReport()
|
||||
})
|
||||
.finally(() => {
|
||||
setRefreshLoading(false)
|
||||
})
|
||||
}
|
||||
return (
|
||||
<Row gutter={[16, 16]} className="rb:pb-6">
|
||||
<Row gutter={[16, 16]} className="rb:h-full!">
|
||||
<Col span={8}>
|
||||
<RbCard>
|
||||
<RbCard
|
||||
className="rb:h-[calc(100vh-104px)]!"
|
||||
bodyClassName="rb:overflow-y-auto! rb:h-full!"
|
||||
>
|
||||
<div className="rb:flex rb:items-center">
|
||||
<div className="rb:flex-[0_0_auto] rb:w-20 rb:h-20 rb:text-center rb:font-semibold rb:text-[28px] rb:leading-20 rb:rounded-lg rb:text-[#FBFDFF] rb:bg-[#155EEF]">{name?.[0]}</div>
|
||||
<div className="rb:text-[24px] rb:font-semibold rb:leading-8 rb:ml-4">
|
||||
{name}<br/>
|
||||
<div className="rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 rb:mt-2">{personas?.join(' | ')}</div>
|
||||
</div>
|
||||
<Flex>
|
||||
<div className="rb:text-[24px] rb:font-semibold rb:leading-8 rb:ml-4 rb:flex-1">
|
||||
{name}<br />
|
||||
<div className="rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 rb:mt-2">{personas?.join(' | ')}</div>
|
||||
</div>
|
||||
<Tooltip title={t('common.refresh')}>
|
||||
{refreshLoading
|
||||
? <Spin indicator={<LoadingOutlined spin />} />
|
||||
: (
|
||||
<div
|
||||
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/refresh.svg')] rb:hover:bg-[url('@/assets/images/refresh_hover.svg')]"
|
||||
onClick={handleRefresh}
|
||||
></div>
|
||||
)
|
||||
}
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
</div>
|
||||
|
||||
<div className="rb:flex rb:gap-2 rb:mb-2 rb:flex-wrap rb:mt-6.25">
|
||||
|
||||
@@ -1,74 +1,43 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 18:34:04
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 18:34:04
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-12 18:34:52
|
||||
*/
|
||||
/**
|
||||
* Conversation Memory Component
|
||||
* Displays RAG conversation memory content list
|
||||
*/
|
||||
|
||||
import { type FC, useEffect, useState } from 'react'
|
||||
import { type FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useParams } from 'react-router-dom'
|
||||
import { Skeleton, List } from 'antd';
|
||||
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import Empty from '@/components/Empty';
|
||||
import PageScrollList from '@/components/PageScrollList'
|
||||
import Markdown from '@/components/Markdown'
|
||||
import {
|
||||
getRagContent
|
||||
} from '@/api/memory'
|
||||
import { getRagContentUrl } from '@/api/memory'
|
||||
|
||||
const ConversationMemory:FC = () => {
|
||||
const ConversationMemory: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
const { id } = useParams()
|
||||
const [loading, setLoading] = useState<boolean>(true)
|
||||
const [list, setList] = useState<string[]>([])
|
||||
|
||||
useEffect(() => {
|
||||
if (!id) return
|
||||
getList()
|
||||
}, [id])
|
||||
/** Fetch conversation memory list */
|
||||
const getList = () => {
|
||||
if (!id) return
|
||||
setLoading(true)
|
||||
getRagContent(id).then((res) => {
|
||||
setList((res as { contents?: [] }).contents || [])
|
||||
})
|
||||
.finally(() => {
|
||||
setLoading(false)
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<RbCard
|
||||
<RbCard
|
||||
title={t('userMemory.conversationMemory')}
|
||||
headerClassName="rb:text-[18px]! rb:leading-[24px]"
|
||||
bodyClassName="rb:h-[100%]! rb:overflow-hidden rb:py-0!"
|
||||
bodyClassName="rb:h-[calc(100%-56px)]! rb:overflow-hidden"
|
||||
className="rb:h-[calc(100vh-104px)]!"
|
||||
>
|
||||
{loading
|
||||
? <Skeleton />
|
||||
: list.length > 0
|
||||
? <List
|
||||
dataSource={list}
|
||||
grid={{ gutter: 12, column: 1 }}
|
||||
renderItem={(item, index) => (
|
||||
<List.Item>
|
||||
<div
|
||||
key={index}
|
||||
className="rb:rounded-lg rb:border rb:border-[#DFE4ED] rb:px-4 rb:py-3 rb:bg-[#F0F3F8] rb:mt-2 rb:text-gray-800 rb:text-sm"
|
||||
>
|
||||
<Markdown content={item} />
|
||||
</div>
|
||||
</List.Item>
|
||||
)}
|
||||
/>
|
||||
: <Empty className="rb:h-full" />
|
||||
}
|
||||
<PageScrollList<string>
|
||||
url={getRagContentUrl}
|
||||
query={{ end_user_id: id }}
|
||||
column={1}
|
||||
renderItem={(item) => (
|
||||
<div className="rb:rounded-lg rb:border rb:border-[#DFE4ED] rb:px-4 rb:py-3 rb:bg-[#F0F3F8] rb:text-gray-800 rb:text-sm">
|
||||
<Markdown content={item} />
|
||||
</div>
|
||||
)}
|
||||
className="rb:h-full!"
|
||||
// className="rb:h-[calc(100%-24px)]!"
|
||||
/>
|
||||
</RbCard>
|
||||
)
|
||||
}
|
||||
export default ConversationMemory
|
||||
|
||||
export default ConversationMemory
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import type { FC } from 'react';
|
||||
import { Select } from 'antd';
|
||||
import { Select, Divider } from 'antd';
|
||||
// import { Node } from '@antv/x6';
|
||||
import type { GraphRef } from '../types'
|
||||
import { PlusOutlined, MinusOutlined } from '@ant-design/icons'
|
||||
import { PlusOutlined, MinusOutlined, FileAddOutlined } from '@ant-design/icons'
|
||||
|
||||
interface CanvasToolbarProps {
|
||||
miniMapRef: React.RefObject<HTMLDivElement>;
|
||||
@@ -14,6 +14,7 @@ interface CanvasToolbarProps {
|
||||
canRedo: boolean;
|
||||
onUndo: () => void;
|
||||
onRedo: () => void;
|
||||
addNotes: () => void;
|
||||
}
|
||||
|
||||
const CanvasToolbar: FC<CanvasToolbarProps> = ({
|
||||
@@ -26,6 +27,7 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
|
||||
// canRedo,
|
||||
// onUndo,
|
||||
// onRedo,
|
||||
addNotes,
|
||||
}) => {
|
||||
// 整理布局函数
|
||||
/*
|
||||
@@ -152,7 +154,7 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
|
||||
{/* 小地图 */}
|
||||
<div ref={miniMapRef} className="rb:absolute rb:bottom-15 rb:right-8 rb:z-1000 rb:rounded-lg rb:overflow-hidden"></div>
|
||||
{/* 缩放控制按钮 */}
|
||||
<div className="rb:h-8.5 rb:bg-[#FFFFFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:shadow-[0px_2px_6px_0px_rgba(33,35,50,0.15)] rb:px-3 rb:py-2 rb:absolute rb:bottom-5 rb:right-8 rb:flex rb:flex-row rb:gap-4 rb:z-1000">
|
||||
<div className="rb:h-8.5 rb:bg-[#FFFFFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:shadow-[0px_2px_6px_0px_rgba(33,35,50,0.15)] rb:px-3 rb:py-2 rb:absolute rb:bottom-5 rb:right-8 rb:flex rb:flex-row rb:items-center rb:gap-4 rb:z-1000">
|
||||
<MinusOutlined className="rb:text-[16px] rb:cursor-pointer" onClick={() => graphRef.current?.zoom(-0.1)} />
|
||||
<Select
|
||||
value={Math.round(zoomLevel * 100)}
|
||||
@@ -182,6 +184,8 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
|
||||
size="small"
|
||||
/>
|
||||
<PlusOutlined className="rb:text-[16px] rb:cursor-pointer" onClick={() => graphRef.current?.zoom(0.1)} />
|
||||
<Divider type="vertical" className="rb:h-4" />
|
||||
<FileAddOutlined onClick={addNotes} />
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
import { useEffect, useRef } from 'react';
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||
import { FORMAT_TEXT_COMMAND, $getSelection, $isRangeSelection, $setSelection, $isTextNode, type BaseSelection } from 'lexical';
|
||||
import { $patchStyleText } from '@lexical/selection';
|
||||
import { INSERT_UNORDERED_LIST_COMMAND, REMOVE_LIST_COMMAND, ListNode } from '@lexical/list';
|
||||
import { TOGGLE_LINK_COMMAND, LinkNode } from '@lexical/link';
|
||||
import { $getNearestNodeOfType } from '@lexical/utils';
|
||||
|
||||
export const NOTE_FORMAT_EVENT = 'note:format';
|
||||
|
||||
export interface FormatState {
|
||||
bold: boolean;
|
||||
italic: boolean;
|
||||
strikethrough: boolean;
|
||||
list: boolean;
|
||||
fontSize?: number;
|
||||
linkUrl?: string | null;
|
||||
}
|
||||
|
||||
const NoteFormatPlugin = ({ nodeId, onFormatChange, fontSize = 12 }: { nodeId: string; fontSize?: number; onFormatChange?: (state: FormatState) => void }) => {
|
||||
const [editor] = useLexicalComposerContext();
|
||||
const savedSelection = useRef<BaseSelection | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
return editor.registerUpdateListener(({ editorState }) => {
|
||||
editorState.read(() => {
|
||||
const selection = $getSelection();
|
||||
if (!$isRangeSelection(selection)) return;
|
||||
savedSelection.current = selection.clone();
|
||||
const anchorNode = selection.anchor.getNode();
|
||||
const style = 'getStyle' in anchorNode ? (anchorNode as { getStyle(): string }).getStyle() : '';
|
||||
const match = style.match(/font-size:\s*([\d.]+)px/);
|
||||
const nodeFontSize = match ? Number(match[1]) : fontSize;
|
||||
const linkNode = $getNearestNodeOfType(anchorNode, LinkNode);
|
||||
onFormatChange?.({
|
||||
bold: selection.hasFormat('bold'),
|
||||
italic: selection.hasFormat('italic'),
|
||||
strikethrough: selection.hasFormat('strikethrough'),
|
||||
list: !!$getNearestNodeOfType(anchorNode, ListNode),
|
||||
...(nodeFontSize ? { fontSize: nodeFontSize } : {}),
|
||||
linkUrl: linkNode ? linkNode.getURL() : null,
|
||||
});
|
||||
});
|
||||
});
|
||||
}, [editor, onFormatChange]);
|
||||
|
||||
useEffect(() => {
|
||||
const handler = (e: Event) => {
|
||||
const { id, format, value } = (e as CustomEvent).detail;
|
||||
if (id !== nodeId) return;
|
||||
const sel = savedSelection.current;
|
||||
const hasSelection = $isRangeSelection(sel) && !sel.isCollapsed();
|
||||
if (format === 'link' && value === null) {
|
||||
// remove link: select the entire LinkNode first
|
||||
editor.focus(() => {
|
||||
editor.update(() => {
|
||||
const s = $getSelection();
|
||||
const anchorNode = $isRangeSelection(s)
|
||||
? s.anchor.getNode()
|
||||
: savedSelection.current && $isRangeSelection(savedSelection.current)
|
||||
? savedSelection.current.anchor.getNode()
|
||||
: null;
|
||||
const linkNode = anchorNode ? $getNearestNodeOfType(anchorNode, LinkNode) : null;
|
||||
if (linkNode) {
|
||||
const children = linkNode.getChildren();
|
||||
if (children.length > 0) {
|
||||
const first = children[0];
|
||||
const last = children[children.length - 1];
|
||||
if ($isTextNode(first) && $isTextNode(last)) {
|
||||
const range = first.select(0, 0);
|
||||
range.focus.set(last.getKey(), last.getTextContentSize(), 'text');
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
editor.dispatchCommand(TOGGLE_LINK_COMMAND, null);
|
||||
});
|
||||
} else if (format === 'list') {
|
||||
editor.focus(() => {
|
||||
if (sel) editor.update(() => $setSelection(sel));
|
||||
editor.dispatchCommand(value ? INSERT_UNORDERED_LIST_COMMAND : REMOVE_LIST_COMMAND, undefined);
|
||||
editor.update(() => $setSelection(null));
|
||||
});
|
||||
} else if (hasSelection) {
|
||||
editor.focus(() => {
|
||||
editor.update(() => $setSelection(sel));
|
||||
if (format === 'bold' || format === 'italic' || format === 'strikethrough') {
|
||||
editor.dispatchCommand(FORMAT_TEXT_COMMAND, format);
|
||||
} else if (format === 'link') {
|
||||
editor.dispatchCommand(TOGGLE_LINK_COMMAND, value as string | null);
|
||||
} else if (format === 'fontSize') {
|
||||
editor.update(() => {
|
||||
$setSelection(sel);
|
||||
$patchStyleText(sel!, { 'font-size': `${value}px` });
|
||||
});
|
||||
}
|
||||
editor.update(() => $setSelection(null));
|
||||
});
|
||||
}
|
||||
};
|
||||
window.addEventListener(NOTE_FORMAT_EVENT, handler);
|
||||
return () => window.removeEventListener(NOTE_FORMAT_EVENT, handler);
|
||||
}, [editor, nodeId]);
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export default NoteFormatPlugin;
|
||||
@@ -0,0 +1,74 @@
|
||||
import { type FC, useState } from 'react';
|
||||
import { createPortal } from 'react-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Flex, Button, Input } from 'antd';
|
||||
import { EditOutlined, DisconnectOutlined } from '@ant-design/icons';
|
||||
|
||||
const POPOVER_STYLE: React.CSSProperties = {
|
||||
position: 'fixed',
|
||||
zIndex: 1000,
|
||||
background: '#fff',
|
||||
border: '1px solid #e5e7eb',
|
||||
borderRadius: 8,
|
||||
boxShadow: '0 2px 8px rgba(0,0,0,0.12)',
|
||||
whiteSpace: 'nowrap',
|
||||
};
|
||||
|
||||
interface LinkPopoverProps {
|
||||
url: string;
|
||||
rect: DOMRect;
|
||||
onEdit: () => void;
|
||||
onRemove: () => void;
|
||||
}
|
||||
|
||||
export const LinkPopover: FC<LinkPopoverProps> = ({ url, rect, onEdit, onRemove }) => {
|
||||
const { t } = useTranslation();
|
||||
return createPortal(
|
||||
<div
|
||||
style={{ ...POPOVER_STYLE, left: rect.left, top: rect.bottom + 4, padding: '4px 10px', fontSize: 12 }}
|
||||
onMouseDown={e => e.stopPropagation()}
|
||||
>
|
||||
<Flex align="center" gap={8}>
|
||||
<a href={url} target="_blank" rel="noreferrer" style={{ color: '#2563eb', maxWidth: 160, overflow: 'hidden', textOverflow: 'ellipsis', display: 'inline-block' }}>
|
||||
{url}
|
||||
</a>
|
||||
<Button size="small" type="text" icon={<EditOutlined />} onClick={onEdit}>{t('common.edit')}</Button>
|
||||
<Button size="small" type="text" icon={<DisconnectOutlined />} onClick={onRemove}>{t('workflow.config.notes.removeLink')}</Button>
|
||||
</Flex>
|
||||
</div>,
|
||||
document.body
|
||||
);
|
||||
};
|
||||
|
||||
interface EditLinkPopoverProps {
|
||||
rect: DOMRect;
|
||||
initialUrl: string;
|
||||
onConfirm: (url: string) => void;
|
||||
}
|
||||
|
||||
export const EditLinkPopover: FC<EditLinkPopoverProps> = ({ rect, initialUrl, onConfirm }) => {
|
||||
const { t } = useTranslation();
|
||||
const [url, setUrl] = useState(initialUrl);
|
||||
const confirm = () => onConfirm(url);
|
||||
return createPortal(
|
||||
<div
|
||||
style={{ ...POPOVER_STYLE, left: rect.left, top: rect.bottom + 4, padding: '8px' }}
|
||||
onMouseDown={e => e.stopPropagation()}
|
||||
>
|
||||
<Flex gap={8}>
|
||||
<Input
|
||||
size="small"
|
||||
className="rb:w-60!"
|
||||
placeholder={t('workflow.config.notes.enterLink')}
|
||||
value={url}
|
||||
onChange={e => setUrl(e.target.value)}
|
||||
onKeyDown={e => e.stopPropagation()}
|
||||
onPressEnter={confirm}
|
||||
autoFocus
|
||||
/>
|
||||
<Button size="small" type="primary" onClick={confirm}>{t('common.confirm')}</Button>
|
||||
</Flex>
|
||||
</div>,
|
||||
document.body
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,184 @@
|
||||
import { type FC, useState, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { LexicalComposer } from '@lexical/react/LexicalComposer';
|
||||
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin';
|
||||
import { ContentEditable } from '@lexical/react/LexicalContentEditable';
|
||||
import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin';
|
||||
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary';
|
||||
import { ListPlugin } from '@lexical/react/LexicalListPlugin';
|
||||
import { LinkPlugin } from '@lexical/react/LexicalLinkPlugin';
|
||||
import { ListNode, ListItemNode } from '@lexical/list';
|
||||
import { LinkNode } from '@lexical/link';
|
||||
import { OnChangePlugin } from '@lexical/react/LexicalOnChangePlugin';
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||
import { useEffect, useRef } from 'react';
|
||||
import NoteFormatPlugin from './NoteFormatPlugin';
|
||||
import type { FormatState } from './NoteFormatPlugin';
|
||||
import { LinkPopover, EditLinkPopover } from './NoteLinkPopovers';
|
||||
|
||||
const theme = {
|
||||
paragraph: 'editor-paragraph',
|
||||
text: {
|
||||
bold: 'editor-text-bold',
|
||||
italic: 'editor-text-italic',
|
||||
strikethrough: 'note-text-strikethrough',
|
||||
},
|
||||
list: { ul: 'note-list-ul', listitem: 'note-list-item' },
|
||||
link: 'note-link',
|
||||
};
|
||||
|
||||
const NOTE_NODES = [ListNode, ListItemNode, LinkNode];
|
||||
|
||||
const NOTE_STYLES = `
|
||||
.editor-text-bold { font-weight: bold; }
|
||||
.editor-text-italic { font-style: italic; }
|
||||
.note-text-strikethrough { text-decoration: line-through; }
|
||||
.note-list-ul { list-style-type: disc; padding-left: 1.2em; margin: 0; }
|
||||
.note-list-item { margin: 2px 0; }
|
||||
.note-link { color: #2563eb; text-decoration: underline; cursor: pointer; }
|
||||
`;
|
||||
|
||||
const NoteInitPlugin: FC<{ value: string }> = ({ value }) => {
|
||||
const [editor] = useLexicalComposerContext();
|
||||
const initialized = useRef(false);
|
||||
useEffect(() => {
|
||||
if (initialized.current || !value) return;
|
||||
initialized.current = true;
|
||||
try {
|
||||
const parsed = JSON.parse(value);
|
||||
if (parsed?.root) {
|
||||
const state = editor.parseEditorState(JSON.stringify(parsed));
|
||||
editor.setEditorState(state);
|
||||
return;
|
||||
}
|
||||
} catch {}
|
||||
}, [editor, value]);
|
||||
return null;
|
||||
};
|
||||
|
||||
|
||||
interface NoteEditorProps {
|
||||
nodeId: string;
|
||||
value: string;
|
||||
fontSize?: number;
|
||||
onChange: (val: string) => void;
|
||||
onFormatChange?: (state: FormatState) => void;
|
||||
}
|
||||
|
||||
const NoteEditor: FC<NoteEditorProps> = ({ nodeId, value, fontSize = 12, onChange, onFormatChange }) => {
|
||||
const { t } = useTranslation();
|
||||
const [linkState, setLinkState] = useState<{ url: string; rect: DOMRect } | null>(null);
|
||||
const [editLinkRect, setEditLinkRect] = useState<{ url: string; rect: DOMRect } | null>(null);
|
||||
const removingLink = useRef(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!linkState) return;
|
||||
const handler = () => setLinkState(null);
|
||||
window.addEventListener('mousedown', handler);
|
||||
return () => window.removeEventListener('mousedown', handler);
|
||||
}, [!!linkState]);
|
||||
|
||||
useEffect(() => {
|
||||
const handler = (e: Event) => {
|
||||
const { id, url, rect: passedRect } = (e as CustomEvent).detail;
|
||||
if (id !== nodeId) return;
|
||||
if (passedRect) {
|
||||
setEditLinkRect({ url: url || '', rect: passedRect });
|
||||
return;
|
||||
}
|
||||
const sel = window.getSelection();
|
||||
if (sel && sel.rangeCount > 0) {
|
||||
const r = sel.getRangeAt(0).getBoundingClientRect();
|
||||
if (r.width > 0 || r.height > 0) { setEditLinkRect({ url: url || '', rect: r }); return; }
|
||||
}
|
||||
const linkEl = document.querySelector(`[data-note-id="${nodeId}"] a.note-link`) as HTMLElement;
|
||||
const rect = linkEl?.getBoundingClientRect() ?? new DOMRect(window.innerWidth / 2, 200, 0, 0);
|
||||
setEditLinkRect({ url: url || '', rect });
|
||||
};
|
||||
window.addEventListener('note:edit-link', handler);
|
||||
return () => window.removeEventListener('note:edit-link', handler);
|
||||
}, [nodeId]);
|
||||
|
||||
const handleFormatChange = useCallback((state: FormatState) => {
|
||||
onFormatChange?.(state);
|
||||
if (state.linkUrl) {
|
||||
requestAnimationFrame(() => {
|
||||
if (removingLink.current) { removingLink.current = false; return; }
|
||||
const sel = window.getSelection();
|
||||
if (sel && sel.rangeCount > 0) {
|
||||
const rect = sel.getRangeAt(0).getBoundingClientRect();
|
||||
if (rect.width > 0 || rect.height > 0) {
|
||||
setLinkState({ url: state.linkUrl!, rect });
|
||||
return;
|
||||
}
|
||||
}
|
||||
// fallback: find the link element in the correct editor
|
||||
const editorEl = document.querySelector(`[data-note-id="${nodeId}"] a.note-link`) as HTMLElement;
|
||||
if (editorEl) {
|
||||
setLinkState({ url: state.linkUrl!, rect: editorEl.getBoundingClientRect() });
|
||||
}
|
||||
});
|
||||
} else {
|
||||
setLinkState(null);
|
||||
}
|
||||
}, [onFormatChange]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<style>{NOTE_STYLES}</style>
|
||||
<LexicalComposer initialConfig={{ namespace: `note-${nodeId}`, theme, nodes: NOTE_NODES, onError: console.error }}>
|
||||
<div style={{ position: 'relative' }} data-note-id={nodeId}>
|
||||
<RichTextPlugin
|
||||
contentEditable={
|
||||
<ContentEditable
|
||||
style={{ minHeight: 60, outline: 'none', resize: 'none', fontSize: '12px', lineHeight: '18px', color: '#374151', overflow: 'auto', cursor: 'auto' }}
|
||||
/>
|
||||
}
|
||||
placeholder={
|
||||
<div style={{ position: 'absolute', top: 0, left: 0, color: '#9CA3AF', lineHeight: '18px', pointerEvents: 'none' }}>
|
||||
{t('workflow.config.notes.placeholder')}
|
||||
</div>
|
||||
}
|
||||
ErrorBoundary={LexicalErrorBoundary}
|
||||
/>
|
||||
<HistoryPlugin />
|
||||
<ListPlugin />
|
||||
<LinkPlugin />
|
||||
<OnChangePlugin onChange={(editorState) => onChange(JSON.stringify(editorState.toJSON()))} />
|
||||
<NoteInitPlugin value={value} />
|
||||
<NoteFormatPlugin nodeId={nodeId} fontSize={fontSize} onFormatChange={handleFormatChange} />
|
||||
{editLinkRect && (
|
||||
<EditLinkPopover
|
||||
rect={editLinkRect.rect}
|
||||
initialUrl={editLinkRect.url}
|
||||
onConfirm={(url) => {
|
||||
removingLink.current = true;
|
||||
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'link', value: url || null } }));
|
||||
setEditLinkRect(null);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{linkState && (
|
||||
<LinkPopover
|
||||
url={linkState.url}
|
||||
rect={linkState.rect}
|
||||
onEdit={() => {
|
||||
removingLink.current = true;
|
||||
const { rect, url } = linkState;
|
||||
setLinkState(null);
|
||||
setEditLinkRect({ url, rect });
|
||||
}}
|
||||
onRemove={() => {
|
||||
removingLink.current = true;
|
||||
setLinkState(null);
|
||||
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'link', value: null } }));
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</LexicalComposer>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default NoteEditor;
|
||||
@@ -0,0 +1,163 @@
|
||||
import { type FC } from 'react';
|
||||
import { Flex, Dropdown, type MenuProps, Switch, Button, Divider } from 'antd';
|
||||
import { UnorderedListOutlined, BoldOutlined, ItalicOutlined, StrikethroughOutlined, LinkOutlined, DashOutlined } from '@ant-design/icons';
|
||||
import { Node } from '@antv/x6';
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { THEME_MAP } from '../../../constant';
|
||||
const FONT_SIZES = [
|
||||
{ label: '小', value: 12 },
|
||||
{ label: '中', value: 14 },
|
||||
{ label: '大', value: 16 },
|
||||
];
|
||||
|
||||
interface NoteNodeToolbarProps {
|
||||
node: Node;
|
||||
onFormat: (type: string, value?: unknown) => void;
|
||||
toolConfig: Record<string, number | boolean>;
|
||||
nodeId: string;
|
||||
}
|
||||
|
||||
const NoteNodeToolbar: FC<NoteNodeToolbarProps> = ({ node, onFormat, toolConfig, nodeId }) => {
|
||||
const data = node?.getData() || {};
|
||||
const { t } = useTranslation();
|
||||
|
||||
const colorItems: MenuProps['items'] = Object.entries(THEME_MAP).map(([key, theme]) => ({
|
||||
key,
|
||||
label: (
|
||||
<div
|
||||
className="rb:w-5 rb:h-5 rb:rounded-full rb:cursor-pointer rb:border rb:border-gray-200"
|
||||
style={{ background: theme.bg }}
|
||||
onClick={() => onFormat('color', key)}
|
||||
/>
|
||||
),
|
||||
}));
|
||||
|
||||
const fontSizeItems: MenuProps['items'] = FONT_SIZES.map(({ label, value }) => ({
|
||||
key: value,
|
||||
label: <span onClick={() => onFormat('fontSize', value)}>{label}</span>,
|
||||
}));
|
||||
|
||||
const currentFontSize = FONT_SIZES.find(f => f.value === toolConfig.fontSize)?.label ?? '小';
|
||||
|
||||
const handleClick: MenuProps['onClick'] = (e) => {
|
||||
switch (e.key) {
|
||||
case 'delete':
|
||||
node.remove()
|
||||
break;
|
||||
case 'copy':
|
||||
break;
|
||||
}
|
||||
}
|
||||
const handleChange = (type: string) => {
|
||||
let show_author = data.config.show_author.defaultValue
|
||||
if(type === 'showAuth'){
|
||||
show_author = !show_author
|
||||
}
|
||||
node.setData({
|
||||
...data,
|
||||
config: {
|
||||
...data.config,
|
||||
show_author: {
|
||||
...data.config.show_author,
|
||||
defaultValue: show_author
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
align="center"
|
||||
gap={8}
|
||||
className="rb:absolute rb:-top-11 rb:left-1/2 rb:-translate-x-1/2 rb:bg-white rb:z-10 rb:whitespace-nowrap rb:rounded-lg rb:py-1! rb:px-3!"
|
||||
onClick={e => e.stopPropagation()}
|
||||
>
|
||||
{/* Color picker */}
|
||||
<Dropdown menu={{ items: colorItems }} trigger={['click']}>
|
||||
<div
|
||||
className="rb:w-5 rb:h-5 rb:rounded-full rb:cursor-pointer rb:border rb:border-gray-200"
|
||||
style={{ background: THEME_MAP[data.bgColor]?.bg || THEME_MAP.blue.bg }}
|
||||
/>
|
||||
</Dropdown>
|
||||
|
||||
<Divider type="vertical" />
|
||||
|
||||
{/* Font size */}
|
||||
<Dropdown menu={{ items: fontSizeItems }} trigger={['click']}>
|
||||
<Flex align="center" gap={4} className="rb:cursor-pointer rb:text-xs rb:text-gray-600 rb:select-none">
|
||||
<span className="rb:text-xs">Aa</span>
|
||||
<span className="rb:text-xs">{currentFontSize}</span>
|
||||
</Flex>
|
||||
</Dropdown>
|
||||
|
||||
<Divider type="vertical" />
|
||||
|
||||
{/* Bold */}
|
||||
<Button
|
||||
type={toolConfig.bold ? 'primary' : 'text'}
|
||||
icon={<BoldOutlined />}
|
||||
onClick={() => onFormat('bold')}
|
||||
/>
|
||||
|
||||
{/* Italic */}
|
||||
<Button
|
||||
type={toolConfig.italic ? 'primary' : 'text'}
|
||||
icon={<ItalicOutlined />}
|
||||
onClick={() => onFormat('italic')}
|
||||
/>
|
||||
|
||||
{/* Strikethrough */}
|
||||
<Button
|
||||
type={toolConfig.strikethrough ? 'primary' : 'text'}
|
||||
icon={<StrikethroughOutlined />}
|
||||
onClick={() => onFormat('strikethrough')}
|
||||
/>
|
||||
|
||||
{/* Link */}
|
||||
<Button
|
||||
type={toolConfig.link ? 'primary' : 'text'}
|
||||
icon={<LinkOutlined />}
|
||||
onClick={() => {
|
||||
const sel = window.getSelection();
|
||||
const rect = sel && sel.rangeCount > 0 ? sel.getRangeAt(0).getBoundingClientRect() : undefined;
|
||||
window.dispatchEvent(new CustomEvent('note:edit-link', { detail: { id: nodeId, url: '', rect } }));
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* List */}
|
||||
<Button
|
||||
type={toolConfig.list ? 'primary' : 'text'}
|
||||
icon={<UnorderedListOutlined />}
|
||||
onClick={() => onFormat('list')}
|
||||
/>
|
||||
|
||||
<Divider type="vertical" />
|
||||
|
||||
<Dropdown
|
||||
menu={{
|
||||
items: [
|
||||
// { key: 'copy', label: t('common.copy') },
|
||||
{
|
||||
key: 'showAuth',
|
||||
label: <Flex align="center" gap={24}>
|
||||
{t('workflow.config.notes.showAuth')}
|
||||
<Switch
|
||||
size="small"
|
||||
checked={data.config.show_author.defaultValue}
|
||||
onChange={() => handleChange('showAuth')}
|
||||
/>
|
||||
</Flex>
|
||||
},
|
||||
{ key: 'delete', label: <Flex>{t('common.delete')}</Flex> },
|
||||
],
|
||||
onClick: handleClick
|
||||
}}
|
||||
>
|
||||
<DashOutlined />
|
||||
</Dropdown>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default NoteNodeToolbar;
|
||||
155
web/src/views/Workflow/components/Nodes/NoteNode/index.tsx
Normal file
155
web/src/views/Workflow/components/Nodes/NoteNode/index.tsx
Normal file
@@ -0,0 +1,155 @@
|
||||
import { useRef, useState } from 'react';
|
||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||
import { Flex } from 'antd';
|
||||
|
||||
import NoteEditor from './NoteEditor';
|
||||
import NoteNodeToolbar from './NoteNodeToolbar';
|
||||
import { THEME_MAP } from '../../../constant'
|
||||
|
||||
const MIN_W = 240;
|
||||
const MIN_H = 120;
|
||||
|
||||
const NoteNode: ReactShapeConfig['component'] = ({ node }) => {
|
||||
const data = node?.getData() || {};
|
||||
const nodeId = node?.id || '';
|
||||
const startRef = useRef<{ x: number; y: number; w: number; h: number } | null>(null);
|
||||
const [toolConfig, setToolConfig] = useState({
|
||||
fontSize: 12,
|
||||
bold: false,
|
||||
italic: false,
|
||||
strikethrough: false,
|
||||
list: false,
|
||||
})
|
||||
|
||||
const handleFormat = (type: string, value?: unknown) => {
|
||||
console.log('handleFormat', type, value)
|
||||
if (type === 'color') {
|
||||
node?.setData({
|
||||
...data,
|
||||
config: {
|
||||
...data.config,
|
||||
theme: {
|
||||
...data.config.theme,
|
||||
defaultValue: value
|
||||
}
|
||||
}
|
||||
});
|
||||
} else if (type === 'fontSize') {
|
||||
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'fontSize', value } }));
|
||||
} else if (type === 'link') {
|
||||
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'link', value: value || null } }));
|
||||
} else if (type === 'list') {
|
||||
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: 'list', value: !toolConfig.list } }));
|
||||
} else {
|
||||
window.dispatchEvent(new CustomEvent('note:format', { detail: { id: nodeId, format: type } }));
|
||||
}
|
||||
|
||||
setToolConfig(prev => ({ ...prev, [type]: value || !prev[type as unknown as keyof typeof toolConfig] }))
|
||||
};
|
||||
|
||||
const onResizeMouseDown = (e: React.MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
e.preventDefault();
|
||||
const size = node?.getSize();
|
||||
if (!size) return;
|
||||
startRef.current = { x: e.clientX, y: e.clientY, w: size.width, h: size.height };
|
||||
|
||||
const onMouseMove = (ev: MouseEvent) => {
|
||||
if (!startRef.current) return;
|
||||
const w = Math.max(MIN_W, startRef.current.w + ev.clientX - startRef.current.x);
|
||||
const h = Math.max(MIN_H, startRef.current.h + ev.clientY - startRef.current.y);
|
||||
|
||||
node?.setData({
|
||||
...data,
|
||||
config: {
|
||||
...data.config,
|
||||
width: {
|
||||
...data.config.width,
|
||||
defaultValue: w
|
||||
},
|
||||
height: {
|
||||
...data.config.height,
|
||||
defaultValue: h
|
||||
}
|
||||
}
|
||||
});
|
||||
node?.prop('size', { width: w, height: h });
|
||||
};
|
||||
const onMouseUp = () => {
|
||||
startRef.current = null;
|
||||
window.removeEventListener('mousemove', onMouseMove);
|
||||
window.removeEventListener('mouseup', onMouseUp);
|
||||
};
|
||||
window.addEventListener('mousemove', onMouseMove);
|
||||
window.addEventListener('mouseup', onMouseUp);
|
||||
};
|
||||
|
||||
const updateText = (value: string) => {
|
||||
node.setData({
|
||||
...data,
|
||||
config: {
|
||||
...data.config,
|
||||
text: {
|
||||
...data.config.text,
|
||||
defaultValue: value
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const theme = THEME_MAP[data.config?.theme?.defaultValue || 'blue'] || THEME_MAP['blue']
|
||||
|
||||
return (
|
||||
<div
|
||||
className="rb:relative rb:h-full rb:w-full rb:rounded-2xl rb:border"
|
||||
style={{
|
||||
background: theme.bg,
|
||||
borderColor: data.isSelected ? theme.outer : theme.border,
|
||||
}}
|
||||
>
|
||||
<div className="rb:h-4 rb:rounded-tl-2xl rb:rounded-tr-2xl"
|
||||
style={{
|
||||
background: theme.title
|
||||
}}
|
||||
></div>
|
||||
{data.isSelected && <NoteNodeToolbar node={node!} nodeId={nodeId} toolConfig={toolConfig} onFormat={handleFormat} />}
|
||||
|
||||
<div
|
||||
className="rb:w-full rb:h-[calc(100%-36px)] rb:p-2.5 rb:overflow-auto"
|
||||
onMouseDown={e => {
|
||||
e.stopPropagation()
|
||||
node?.setData({ ...node.getData(), isSelected: true })
|
||||
}}
|
||||
onWheel={e => e.stopPropagation()}
|
||||
>
|
||||
<NoteEditor
|
||||
nodeId={nodeId}
|
||||
value={data.config.text.defaultValue || ''}
|
||||
fontSize={toolConfig.fontSize}
|
||||
onChange={updateText}
|
||||
onFormatChange={(state) => setToolConfig(prev => ({ ...prev, ...state }))}
|
||||
/>
|
||||
</div>
|
||||
<Flex align="center" justify="space-between" className="rb:pl-2.5! rb:pr-1!">
|
||||
<div className="rb:text-[12px] rb:text-[#5B6167]">
|
||||
{data.config.show_author.defaultValue
|
||||
? data.config.author.defaultValue
|
||||
: undefined
|
||||
}
|
||||
</div>
|
||||
|
||||
|
||||
{/* <div className="rb:size-4 rb:border-b-[4px] rb:border-r-[4px] rb:border-[#EBEBEB] rb:rounded-2xl"></div> */}
|
||||
<div
|
||||
onMouseDown={onResizeMouseDown}
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 18 18" fill="none">
|
||||
<path fillRule="evenodd" clipRule="evenodd" d="M12 9.75V6H13.5V9.75C13.5 11.8211 11.8211 13.5 9.75 13.5H6V12H9.75C10.9926 12 12 10.9926 12 9.75Z" fill="black" fillOpacity="0.16"></path>
|
||||
</svg>
|
||||
</div>
|
||||
</Flex>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default NoteNode;
|
||||
@@ -2,13 +2,14 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 15:06:18
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-11 12:07:20
|
||||
* @Last Modified time: 2026-03-09 13:41:19
|
||||
*/
|
||||
import LoopNode from './components/Nodes/LoopNode';
|
||||
import NormalNode from './components/Nodes/NormalNode';
|
||||
import ConditionNode from './components/Nodes/ConditionNode';
|
||||
import GroupStartNode from './components/Nodes/GroupStartNode';
|
||||
import AddNode from './components/Nodes/AddNode'
|
||||
import NoteNode from './components/Nodes/NoteNode';
|
||||
import type { PortMetadata, GroupMetadata } from '@antv/x6/lib/model/port';
|
||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||
|
||||
@@ -525,6 +526,73 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
// ]
|
||||
// },
|
||||
];
|
||||
|
||||
export const THEME_MAP: Record<string, { outer: string; title: string; bg: string; border: string }> = {
|
||||
blue: {
|
||||
outer: '#2E90FA',
|
||||
title: '#D1E9FF',
|
||||
bg: '#EFF8FF',
|
||||
border: '#84CAFF',
|
||||
},
|
||||
cyan: {
|
||||
outer: '#06AED4',
|
||||
title: '#CFF9FE',
|
||||
bg: '#ECFDFF',
|
||||
border: '#67E3F9',
|
||||
},
|
||||
green: {
|
||||
outer: '#16B364',
|
||||
title: '#D3F8DF',
|
||||
bg: '#EDFCF2',
|
||||
border: '#73E2A3',
|
||||
},
|
||||
yellow: {
|
||||
outer: '#EAAA08',
|
||||
title: '#FEF7C3',
|
||||
bg: '#FEFBE8',
|
||||
border: '#FDE272',
|
||||
},
|
||||
pink: {
|
||||
outer: '#EE46BC',
|
||||
title: '#FCE7F6',
|
||||
bg: '#FDF2FA',
|
||||
border: '#FAA7E0',
|
||||
},
|
||||
violet: {
|
||||
outer: '#875BF7',
|
||||
title: '#ECE9FE',
|
||||
bg: '#F5F3FF',
|
||||
border: '#C3B5FD',
|
||||
},
|
||||
}
|
||||
|
||||
export const notesConfig = {
|
||||
type: "notes", icon: templateRenderingIcon,
|
||||
config: {
|
||||
text: {
|
||||
type: 'define',
|
||||
},
|
||||
theme: {
|
||||
type: 'define',
|
||||
defaultValue: 'blue',
|
||||
},
|
||||
width: {
|
||||
type: 'define',
|
||||
width: 240,
|
||||
},
|
||||
height: {
|
||||
type: 'define',
|
||||
height: 120,
|
||||
},
|
||||
author: {
|
||||
type: 'define',
|
||||
},
|
||||
show_author: {
|
||||
type: 'define',
|
||||
defaultValue: true
|
||||
}
|
||||
}
|
||||
}
|
||||
export const unknownNode = {
|
||||
type: 'unknown',
|
||||
icon: unknownIcon
|
||||
@@ -576,6 +644,12 @@ export const nodeRegisterLibrary: ReactShapeConfig[] = [
|
||||
height: 44,
|
||||
component: AddNode,
|
||||
},
|
||||
{
|
||||
shape: 'notes-node',
|
||||
width: nodeWidth,
|
||||
height: 120,
|
||||
component: NoteNode,
|
||||
},
|
||||
];
|
||||
|
||||
/**
|
||||
@@ -801,6 +875,11 @@ export const graphNodeLibrary: Record<string, NodeConfig> = {
|
||||
groups: {left: { position: 'left', markup: portMarkup, attrs: portAttrs }},
|
||||
items: [{ group: 'left' }],
|
||||
},
|
||||
},
|
||||
notes: {
|
||||
width: nodeWidth,
|
||||
height: 120,
|
||||
shape: 'notes-node',
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,9 +12,10 @@ import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from '
|
||||
import { register } from '@antv/x6-react-shape';
|
||||
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
||||
|
||||
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, noteNode } from '../constant';
|
||||
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, noteNode, notesConfig } from '../constant';
|
||||
import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types';
|
||||
import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application'
|
||||
import { useUser } from '@/store/user';
|
||||
|
||||
/**
|
||||
* Props for useWorkflowGraph hook
|
||||
@@ -64,6 +65,8 @@ export interface UseWorkflowGraphReturn {
|
||||
chatVariables: ChatVariable[];
|
||||
/** Function to update chat variables */
|
||||
setChatVariables: React.Dispatch<React.SetStateAction<ChatVariable[]>>;
|
||||
|
||||
handleAddNotes: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -80,6 +83,7 @@ export const useWorkflowGraph = ({
|
||||
const { id } = useParams();
|
||||
const { message } = App.useApp();
|
||||
const { t } = useTranslation()
|
||||
const { user } = useUser();
|
||||
|
||||
// Refs
|
||||
const graphRef = useRef<Graph>();
|
||||
@@ -128,7 +132,7 @@ export const useWorkflowGraph = ({
|
||||
if (nodes.length) {
|
||||
const nodeList = nodes.map(node => {
|
||||
const { id, type, name, position, config = {} } = node
|
||||
let nodeLibraryConfig = [...nodeLibrary, { nodes: [unknownNode, noteNode] }]
|
||||
let nodeLibraryConfig = [...nodeLibrary, { nodes: [unknownNode, notesConfig] }]
|
||||
.flatMap(category => category.nodes)
|
||||
.find(n => n.type === type)
|
||||
nodeLibraryConfig = JSON.parse(JSON.stringify({ config: {}, ...nodeLibraryConfig })) as NodeProperties
|
||||
@@ -197,6 +201,13 @@ export const useWorkflowGraph = ({
|
||||
data: { ...node, ...nodeLibraryConfig},
|
||||
...position,
|
||||
}
|
||||
|
||||
if (type === 'notes') {
|
||||
const w = config.width;
|
||||
const h = config.height;
|
||||
if (w) nodeConfig.width = w as number;
|
||||
if (h) nodeConfig.height = h as number;
|
||||
}
|
||||
|
||||
// Generate ports dynamically for if-else node based on cases
|
||||
if (type === 'if-else' && config.cases && Array.isArray(config.cases)) {
|
||||
@@ -461,11 +472,12 @@ export const useWorkflowGraph = ({
|
||||
*/
|
||||
const nodeClick = ({ node }: { node: Node }) => {
|
||||
// Ignore add-node type node clicks
|
||||
if (node.getData()?.type === 'add-node' || node.getData().type === 'break' || node.getData().type === 'cycle-start') {
|
||||
const nodeData = node.getData()
|
||||
if (nodeData?.type === 'add-node' || nodeData.type === 'break' || nodeData.type === 'cycle-start') {
|
||||
setSelectedNode(null)
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const nodes = graphRef.current?.getNodes();
|
||||
|
||||
nodes?.forEach(vo => {
|
||||
@@ -478,10 +490,12 @@ export const useWorkflowGraph = ({
|
||||
}
|
||||
});
|
||||
node.setData({
|
||||
...node.getData(),
|
||||
...nodeData,
|
||||
isSelected: true,
|
||||
});
|
||||
setSelectedNode(node);
|
||||
if (nodeData.type !== 'notes') {
|
||||
setSelectedNode(node);
|
||||
}
|
||||
};
|
||||
/**
|
||||
* Handle edge click event
|
||||
@@ -859,8 +873,31 @@ export const useWorkflowGraph = ({
|
||||
init();
|
||||
|
||||
window.addEventListener('resize', handleResize);
|
||||
|
||||
const handleNoteKeydown = (e: KeyboardEvent) => {
|
||||
if (!graphRef.current) return;
|
||||
const selectedNote = graphRef.current.getNodes().find(n => n.getData()?.isSelected && n.getData()?.type === 'notes');
|
||||
if (!selectedNote) return;
|
||||
const isMeta = e.ctrlKey || e.metaKey;
|
||||
if (e.key === 'Delete' || e.key === 'Backspace') {
|
||||
// Only delete node when editor is not focused on text
|
||||
const active = document.activeElement;
|
||||
if (active && (active as HTMLElement).isContentEditable) return;
|
||||
deleteEvent();
|
||||
} else if (isMeta && e.key === 'c') {
|
||||
copyEvent();
|
||||
} else if (isMeta && e.key === 'v') {
|
||||
parseEvent();
|
||||
} else if (isMeta && e.key === 'd') {
|
||||
e.preventDefault();
|
||||
deleteEvent();
|
||||
}
|
||||
};
|
||||
window.addEventListener('keydown', handleNoteKeydown);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('resize', handleResize);
|
||||
window.removeEventListener('keydown', handleNoteKeydown);
|
||||
graphRef.current?.dispose();
|
||||
};
|
||||
}, []);
|
||||
@@ -884,7 +921,7 @@ export const useWorkflowGraph = ({
|
||||
.flatMap(category => category.nodes)
|
||||
.find(n => n.type === dragData.type);
|
||||
nodeLibraryConfig = JSON.parse(JSON.stringify({ config: {}, ...nodeLibraryConfig })) as NodeProperties
|
||||
|
||||
|
||||
// Create clean node data, only keep necessary fields
|
||||
const cleanNodeData = {
|
||||
id: `${dragData.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
|
||||
@@ -1103,6 +1140,32 @@ export const useWorkflowGraph = ({
|
||||
})
|
||||
}
|
||||
|
||||
const handleAddNotes = () => {
|
||||
if (!graphRef.current) return;
|
||||
const nodeConfig: NodeProperties = JSON.parse(JSON.stringify(notesConfig));
|
||||
nodeConfig.config = {
|
||||
...nodeConfig.config,
|
||||
author: { type: 'define', defaultValue: user?.username || '' },
|
||||
};
|
||||
const cleanNodeData = {
|
||||
id: `notes_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
|
||||
name: t('workflow.notes'),
|
||||
...nodeConfig,
|
||||
};
|
||||
const container = graphRef.current.container;
|
||||
const nodeW = graphNodeLibrary.notes?.width || nodeWidth;
|
||||
const nodeH = graphNodeLibrary.notes?.height || 100;
|
||||
const rect = container.getBoundingClientRect();
|
||||
const center = graphRef.current.clientToLocal(rect.left + rect.width / 2, rect.top + rect.height / 2);
|
||||
graphRef.current.addNode({
|
||||
...(graphNodeLibrary.notes || graphNodeLibrary.default),
|
||||
x: center.x - nodeW / 2,
|
||||
y: center.y - nodeH / 2,
|
||||
id: cleanNodeData.id,
|
||||
data: { ...cleanNodeData },
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
config,
|
||||
setConfig,
|
||||
@@ -1120,6 +1183,7 @@ export const useWorkflowGraph = ({
|
||||
parseEvent,
|
||||
handleSave,
|
||||
chatVariables,
|
||||
setChatVariables
|
||||
setChatVariables,
|
||||
handleAddNotes
|
||||
};
|
||||
};
|
||||
|
||||
@@ -38,7 +38,8 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
|
||||
parseEvent,
|
||||
handleSave,
|
||||
chatVariables,
|
||||
setChatVariables
|
||||
setChatVariables,
|
||||
handleAddNotes
|
||||
} = useWorkflowGraph({ containerRef, miniMapRef });
|
||||
|
||||
const onDragOver = (event: React.DragEvent) => {
|
||||
@@ -95,6 +96,7 @@ const Workflow = forwardRef<WorkflowRef>((_props, ref) => {
|
||||
canRedo={canRedo}
|
||||
onUndo={onUndo}
|
||||
onRedo={onRedo}
|
||||
addNotes={handleAddNotes}
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user