Merge branch 'feature/memory_zy' of github.com:SuanmoSuanyangTechnology/MemoryBear into feature/memory_zy

This commit is contained in:
zhaoying
2026-03-10 13:37:42 +08:00
115 changed files with 4582 additions and 1531 deletions

View File

@@ -226,8 +226,8 @@ REDIS_PORT=6379
REDIS_DB=1 REDIS_DB=1
# Celery (Using Redis as broker) # Celery (Using Redis as broker)
BROKER_URL=redis://127.0.0.1:6379/0 REDIS_DB_CELERY_BROKER=1
RESULT_BACKEND=redis://127.0.0.1:6379/0 REDIS_DB_CELERY_BACKEND=2
# JWT Secret Key (Formation method: openssl rand -hex 32) # JWT Secret Key (Formation method: openssl rand -hex 32)
SECRET_KEY=your-secret-key-here SECRET_KEY=your-secret-key-here

View File

@@ -201,8 +201,8 @@ REDIS_PORT=6379
REDIS_DB=1 REDIS_DB=1
# Celery (使用Redis作为broker) # Celery (使用Redis作为broker)
BROKER_URL=redis://127.0.0.1:6379/0 REDIS_DB_CELERY_BROKER=1
RESULT_BACKEND=redis://127.0.0.1:6379/0 REDIS_DB_CELERY_BACKEND=2
# JWT密钥 (生成方式: openssl rand -hex 32) # JWT密钥 (生成方式: openssl rand -hex 32)
SECRET_KEY=your-secret-key-here SECRET_KEY=your-secret-key-here

View File

@@ -4,7 +4,9 @@ Memory 缓存模块
提供记忆系统相关的缓存功能 提供记忆系统相关的缓存功能
""" """
from .interest_memory import InterestMemoryCache from .interest_memory import InterestMemoryCache
from .activity_stats_cache import ActivityStatsCache
__all__ = [ __all__ = [
"InterestMemoryCache", "InterestMemoryCache",
"ActivityStatsCache",
] ]

View File

@@ -0,0 +1,124 @@
"""
Recent Activity Stats Cache
记忆提取活动统计缓存模块
用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储24小时后释放
查询命令cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
"""
import json
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)
# 缓存过期时间24小时
ACTIVITY_STATS_CACHE_EXPIRE = 86400
class ActivityStatsCache:
"""记忆提取活动统计缓存类"""
PREFIX = "cache:memory:activity_stats"
@classmethod
def _get_key(cls, workspace_id: str) -> str:
"""生成 Redis key
Args:
workspace_id: 工作空间ID
Returns:
完整的 Redis key
"""
return f"{cls.PREFIX}:by_workspace:{workspace_id}"
@classmethod
async def set_activity_stats(
cls,
workspace_id: str,
stats: Dict[str, Any],
expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
) -> bool:
"""设置记忆提取活动统计缓存
Args:
workspace_id: 工作空间ID
stats: 统计数据,格式:
{
"chunk_count": int,
"statements_count": int,
"triplet_entities_count": int,
"triplet_relations_count": int,
"temporal_count": int,
}
expire: 过期时间默认24小时
Returns:
是否设置成功
"""
try:
key = cls._get_key(workspace_id)
payload = {
"stats": stats,
"generated_at": datetime.now().isoformat(),
"workspace_id": workspace_id,
"cached": True,
}
value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
logger.error(f"设置活动统计缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_activity_stats(
cls,
workspace_id: str,
) -> Optional[Dict[str, Any]]:
"""获取记忆提取活动统计缓存
Args:
workspace_id: 工作空间ID
Returns:
统计数据字典,缓存不存在或已过期返回 None
"""
try:
key = cls._get_key(workspace_id)
value = await aio_redis.get(key)
if value:
payload = json.loads(value)
logger.info(f"命中活动统计缓存: {key}")
return payload
logger.info(f"活动统计缓存不存在或已过期: {key}")
return None
except Exception as e:
logger.error(f"获取活动统计缓存失败: {e}", exc_info=True)
return None
@classmethod
async def delete_activity_stats(
cls,
workspace_id: str,
) -> bool:
"""删除记忆提取活动统计缓存
Args:
workspace_id: 工作空间ID
Returns:
是否删除成功
"""
try:
key = cls._get_key(workspace_id)
result = await aio_redis.delete(key)
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:
logger.error(f"删除活动统计缓存失败: {e}", exc_info=True)
return False

View File

@@ -1,27 +1,54 @@
import os import os
import platform import platform
from datetime import timedelta from datetime import timedelta
from celery.schedules import crontab
from urllib.parse import quote from urllib.parse import quote
from celery import Celery from celery import Celery
from celery.schedules import crontab from celery.schedules import crontab
from app.core.config import settings from app.core.config import settings
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# macOS fork() safety - must be set before any Celery initialization # macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin': if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 创建 Celery 应用实例 # 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0 # broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定
# backend: 结果存储(使用 Redis DB 10 # backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
# Build canonical broker/backend URLs and force them into os.environ so that
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
# cannot be overridden by stray env vars.
# See: https://github.com/celery/celery/issues/4284
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
os.environ["CELERY_BROKER_URL"] = _broker_url
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
# integration and accidentally override our canonical URLs.
os.environ.pop("BROKER_URL", None)
os.environ.pop("RESULT_BACKEND", None)
os.environ.pop("CELERY_BROKER", None)
os.environ.pop("CELERY_BACKEND", None)
celery_app = Celery( celery_app = Celery(
"redbear_tasks", "redbear_tasks",
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}", broker=_broker_url,
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}", backend=_backend_url,
) )
logger.info(
"Celery app initialized",
extra={
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
},
)
# Default queue for unrouted tasks # Default queue for unrouted tasks
celery_app.conf.task_default_queue = 'memory_tasks' celery_app.conf.task_default_queue = 'memory_tasks'
@@ -86,6 +113,8 @@ celery_app.conf.update(
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'}, 'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'}, 'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'}, 'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
}, },
) )

View File

@@ -1,10 +1,12 @@
import uuid import uuid
import io
from typing import Optional, Annotated from typing import Optional, Annotated
import yaml import yaml
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from urllib.parse import quote
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
@@ -25,6 +27,7 @@ from app.services.app_service import AppService
from app.services.app_statistics_service import AppStatisticsService from app.services.app_statistics_service import AppStatisticsService
from app.services.workflow_import_service import WorkflowImportService from app.services.workflow_import_service import WorkflowImportService
from app.services.workflow_service import WorkflowService, get_workflow_service from app.services.workflow_service import WorkflowService, get_workflow_service
from app.services.app_dsl_service import AppDslService
router = APIRouter(prefix="/apps", tags=["Apps"]) router = APIRouter(prefix="/apps", tags=["Apps"])
logger = get_business_logger() logger = get_business_logger()
@@ -1010,3 +1013,57 @@ def get_workspace_api_statistics(
) )
return success(data=result) return success(data=result)
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
@cur_workspace_access_guard()
async def export_app(
app_id: uuid.UUID,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
release_id: Optional[uuid.UUID] = None
):
"""导出 agent / multi_agent / workflow 应用配置为 YAML 文件流。
release_id: 指定发布版本id不传则导出当前草稿配置。
"""
yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)
encoded = quote(filename, safe=".")
yaml_bytes = yaml_str.encode("utf-8")
file_stream = io.BytesIO(yaml_bytes)
file_stream.seek(0)
return StreamingResponse(
file_stream,
media_type="application/octet-stream; charset=utf-8",
headers={"Content-Disposition": f"attachment; filename={encoded}",
"Content-Length": str(len(yaml_bytes))}
)
@router.post("/import", summary="从 YAML 文件导入应用")
@cur_workspace_access_guard()
async def import_app(
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
"""
if not file.filename.lower().endswith((".yaml", ".yml")):
return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)
raw = (await file.read()).decode("utf-8")
dsl = yaml.safe_load(raw)
if not dsl or "app" not in dsl:
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
new_app, warnings = AppDslService(db).import_dsl(
dsl=dsl,
workspace_id=current_user.current_workspace_id,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
)
return success(
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
)

View File

@@ -1,28 +1,29 @@
from typing import List, Optional from typing import List, Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.cache.memory.interest_memory import InterestMemoryCache from app.cache.memory.interest_memory import InterestMemoryCache
from app.celery_app import celery_app from app.celery_app import celery_app
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.rag.llm.cv_model import QWenCV from app.core.rag.llm.cv_model import QWenCV
from app.core.response_utils import fail, success from app.core.response_utils import fail, success
from app.db import get_db from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey from app.models import ModelApiKey
from app.models.user_model import User from app.models.user_model import User
from app.core.memory.agent.utils.session_tools import SessionService from app.repositories import knowledge_repository
from app.core.memory.agent.utils.redis_tool import store
from app.repositories import knowledge_repository, WorkspaceRepository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services import task_service, workspace_service from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
load_dotenv() load_dotenv()
api_logger = get_api_logger() api_logger = get_api_logger()
@@ -37,7 +38,7 @@ router = APIRouter(
@router.get("/health/status", response_model=ApiResponse) @router.get("/health/status", response_model=ApiResponse)
async def get_health_status( async def get_health_status(
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Get latest health status written by Celery periodic task Get latest health status written by Celery periodic task
@@ -55,8 +56,9 @@ async def get_health_status(
@router.get("/download_log") @router.get("/download_log")
async def download_log( async def download_log(
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"), log_type: str = Query("file", regex="^(file|transmission)$",
current_user: User = Depends(get_current_user) description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
): ):
""" """
Download or stream agent service log file Download or stream agent service log file
@@ -75,16 +77,16 @@ async def download_log(
- transmission mode: StreamingResponse with SSE - transmission mode: StreamingResponse with SSE
""" """
api_logger.info(f"Log download requested with log_type={log_type}") api_logger.info(f"Log download requested with log_type={log_type}")
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity) # Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
if log_type not in ["file", "transmission"]: if log_type not in ["file", "transmission"]:
api_logger.warning(f"Invalid log_type parameter: {log_type}") api_logger.warning(f"Invalid log_type parameter: {log_type}")
return fail( return fail(
BizCode.BAD_REQUEST, BizCode.BAD_REQUEST,
"无效的log_type参数", "无效的log_type参数",
"log_type必须是'file''transmission'" "log_type必须是'file''transmission'"
) )
# Route to appropriate mode # Route to appropriate mode
if log_type == "file": if log_type == "file":
# File mode: Return complete log file content # File mode: Return complete log file content
@@ -119,10 +121,10 @@ async def download_log(
@router.post("/writer_service", response_model=ApiResponse) @router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def write_server( async def write_server(
user_input: Write_UserInput, user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Write service endpoint - processes write operations synchronously Write service endpoint - processes write operations synchronously
@@ -136,11 +138,11 @@ async def write_server(
""" """
# 使用集中化的语言校验 # 使用集中化的语言校验
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
config_id = user_input.config_id config_id = user_input.config_id
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
# 获取 storage_type如果为 None 则使用默认值 # 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type( storage_type = workspace_service.get_workspace_storage_type(
db=db, db=db,
@@ -149,7 +151,7 @@ async def write_server(
) )
if storage_type is None: storage_type = 'neo4j' if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = '' user_rag_memory_id = ''
# 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id # 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
if storage_type == 'rag': if storage_type == 'rag':
if workspace_id: if workspace_id:
@@ -161,13 +163,15 @@ async def write_server(
if knowledge: if knowledge:
user_rag_memory_id = str(knowledge.id) user_rag_memory_id = str(knowledge.id)
else: else:
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储") api_logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j' storage_type = 'neo4j'
else: else:
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j' storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") api_logger.info(
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try: try:
messages_list = memory_agent_service.get_messages_list(user_input) messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory( result = await memory_agent_service.write_memory(
@@ -175,7 +179,7 @@ async def write_server(
messages_list, messages_list,
config_id, config_id,
db, db,
storage_type, storage_type,
user_rag_memory_id, user_rag_memory_id,
language language
) )
@@ -195,10 +199,10 @@ async def write_server(
@router.post("/writer_service_async", response_model=ApiResponse) @router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def write_server_async( async def write_server_async(
user_input: Write_UserInput, user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Async write service endpoint - enqueues write processing to Celery Async write service endpoint - enqueues write processing to Celery
@@ -213,10 +217,11 @@ async def write_server_async(
""" """
# 使用集中化的语言校验 # 使用集中化的语言校验
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
config_id = user_input.config_id config_id = user_input.config_id
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") api_logger.info(
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
# 获取 storage_type如果为 None 则使用默认值 # 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type( storage_type = workspace_service.get_workspace_storage_type(
@@ -244,7 +249,7 @@ async def write_server_async(
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
) )
api_logger.info(f"Write task queued: {task.id}") api_logger.info(f"Write task queued: {task.id}")
return success(data={"task_id": task.id}, msg="写入任务已提交") return success(data={"task_id": task.id}, msg="写入任务已提交")
except Exception as e: except Exception as e:
api_logger.error(f"Async write operation failed: {str(e)}") api_logger.error(f"Async write operation failed: {str(e)}")
@@ -254,9 +259,9 @@ async def write_server_async(
@router.post("/read_service", response_model=ApiResponse) @router.post("/read_service", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def read_server( async def read_server(
user_input: UserInput, user_input: UserInput,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Read service endpoint - processes read operations synchronously Read service endpoint - processes read operations synchronously
@@ -291,8 +296,9 @@ async def read_server(
) )
if knowledge: if knowledge:
user_rag_memory_id = str(knowledge.id) user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") api_logger.info(
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try: try:
result = await memory_agent_service.read_memory( result = await memory_agent_service.read_memory(
user_input.end_user_id, user_input.end_user_id,
@@ -306,7 +312,8 @@ async def read_server(
) )
if str(user_input.search_switch) == "2": if str(user_input.search_switch) == "2":
retrieve_info = result['answer'] retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id) history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
user_input.end_user_id)
query = user_input.message query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案 # 调用 memory_agent_service 的方法生成最终答案
@@ -319,7 +326,7 @@ async def read_server(
db=db db=db
) )
if "信息不足,无法回答" in result['answer']: if "信息不足,无法回答" in result['answer']:
result['answer']=retrieve_info result['answer'] = retrieve_info
return success(data=result, msg="回复对话消息成功") return success(data=result, msg="回复对话消息成功")
except BaseException as e: except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -335,9 +342,10 @@ async def read_server(
@router.post("/file", response_model=ApiResponse) @router.post("/file", response_model=ApiResponse)
async def file_update( async def file_update(
files: List[UploadFile] = File(..., description="要上传的文件"), files: List[UploadFile] = File(..., description="要上传的文件"),
model_id:str = Form(..., description="模型ID"), model_id: str = Form(..., description="模型ID"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"), metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
): ):
""" """
文件上传接口 - 支持图片识别 文件上传接口 - 支持图片识别
@@ -350,9 +358,6 @@ async def file_update(
Returns: Returns:
文件处理结果 文件处理结果
""" """
db_gen = get_db() # get_db 通常是一个生成器
db = next(db_gen)
api_logger.info(f"File upload requested, file count: {len(files)}") api_logger.info(f"File upload requested, file count: {len(files)}")
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
apiConfig: ModelApiKey = config.api_keys[0] apiConfig: ModelApiKey = config.api_keys[0]
@@ -361,7 +366,7 @@ async def file_update(
for file in files: for file in files:
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}") api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
content = await file.read() content = await file.read()
if file.content_type and file.content_type.startswith("image/"): if file.content_type and file.content_type.startswith("image/"):
vision_model = QWenCV( vision_model = QWenCV(
key=apiConfig.api_key, key=apiConfig.api_key,
@@ -375,12 +380,12 @@ async def file_update(
else: else:
api_logger.warning(f"Unsupported file type: {file.content_type}") api_logger.warning(f"Unsupported file type: {file.content_type}")
file_content.append(f"[不支持的文件类型: {file.content_type}]") file_content.append(f"[不支持的文件类型: {file.content_type}]")
result_text = ';'.join(file_content) result_text = ';'.join(file_content)
api_logger.info(f"File processing completed, result length: {len(result_text)}") api_logger.info(f"File processing completed, result length: {len(result_text)}")
return success(data=result_text, msg="转换文本成功") return success(data=result_text, msg="转换文本成功")
except Exception as e: except Exception as e:
api_logger.error(f"File processing failed: {str(e)}", exc_info=True) api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
@@ -430,8 +435,8 @@ async def read_server_async(
@router.get("/read_result/", response_model=ApiResponse) @router.get("/read_result/", response_model=ApiResponse)
async def get_read_task_result( async def get_read_task_result(
task_id: str, task_id: str,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Get the status and result of an async read task Get the status and result of an async read task
@@ -452,7 +457,7 @@ async def get_read_task_result(
try: try:
result = task_service.get_task_memory_read_result(task_id) result = task_service.get_task_memory_read_result(task_id)
status = result.get("status") status = result.get("status")
if status == "SUCCESS": if status == "SUCCESS":
# 任务成功完成 # 任务成功完成
task_result = result.get("result", {}) task_result = result.get("result", {})
@@ -470,7 +475,7 @@ async def get_read_task_result(
else: else:
# 旧格式:直接返回结果 # 旧格式:直接返回结果
return success(data=task_result, msg="查询任务已完成") return success(data=task_result, msg="查询任务已完成")
elif status == "FAILURE": elif status == "FAILURE":
# 任务失败 # 任务失败
error_info = result.get("result", "Unknown error") error_info = result.get("result", "Unknown error")
@@ -479,7 +484,7 @@ async def get_read_task_result(
else: else:
error_msg = str(error_info) error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg) return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
elif status in ["PENDING", "STARTED"]: elif status in ["PENDING", "STARTED"]:
# 任务进行中 # 任务进行中
return success( return success(
@@ -499,7 +504,7 @@ async def get_read_task_result(
}, },
msg=f"任务状态: {status}" msg=f"任务状态: {status}"
) )
except Exception as e: except Exception as e:
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True) api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@@ -507,8 +512,8 @@ async def get_read_task_result(
@router.get("/write_result/", response_model=ApiResponse) @router.get("/write_result/", response_model=ApiResponse)
async def get_write_task_result( async def get_write_task_result(
task_id: str, task_id: str,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Get the status and result of an async write task Get the status and result of an async write task
@@ -529,7 +534,7 @@ async def get_write_task_result(
try: try:
result = task_service.get_task_memory_write_result(task_id) result = task_service.get_task_memory_write_result(task_id)
status = result.get("status") status = result.get("status")
if status == "SUCCESS": if status == "SUCCESS":
# 任务成功完成 # 任务成功完成
task_result = result.get("result", {}) task_result = result.get("result", {})
@@ -547,7 +552,7 @@ async def get_write_task_result(
else: else:
# 旧格式:直接返回结果 # 旧格式:直接返回结果
return success(data=task_result, msg="写入任务已完成") return success(data=task_result, msg="写入任务已完成")
elif status == "FAILURE": elif status == "FAILURE":
# 任务失败 # 任务失败
error_info = result.get("result", "Unknown error") error_info = result.get("result", "Unknown error")
@@ -556,7 +561,7 @@ async def get_write_task_result(
else: else:
error_msg = str(error_info) error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg) return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
elif status in ["PENDING", "STARTED"]: elif status in ["PENDING", "STARTED"]:
# 任务进行中 # 任务进行中
return success( return success(
@@ -576,7 +581,7 @@ async def get_write_task_result(
}, },
msg=f"任务状态: {status}" msg=f"任务状态: {status}"
) )
except Exception as e: except Exception as e:
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True) api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@@ -584,9 +589,9 @@ async def get_write_task_result(
@router.post("/status_type", response_model=ApiResponse) @router.post("/status_type", response_model=ApiResponse)
async def status_type( async def status_type(
user_input: Write_UserInput, user_input: Write_UserInput,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Determine the type of user message (read or write) Determine the type of user message (read or write)
@@ -629,9 +634,10 @@ async def status_type(
@router.get("/stats/types", response_model=ApiResponse) @router.get("/stats/types", response_model=ApiResponse)
async def get_knowledge_type_stats_api( async def get_knowledge_type_stats_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"), end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"), only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
): ):
""" """
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。 统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
@@ -640,14 +646,9 @@ async def get_knowledge_type_stats_api(
- 知识库类型根据当前用户的 current_workspace_id 过滤 - 知识库类型根据当前用户的 current_workspace_id 过滤
- 如果用户没有当前工作空间,对应的统计返回 0 - 如果用户没有当前工作空间,对应的统计返回 0
""" """
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") api_logger.info(
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
try: try:
from app.db import get_db
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
# 调用service层函数 # 调用service层函数
result = await memory_agent_service.get_knowledge_type_stats( result = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id, end_user_id=end_user_id,
@@ -655,7 +656,7 @@ async def get_knowledge_type_stats_api(
current_workspace_id=current_user.current_workspace_id, current_workspace_id=current_user.current_workspace_id,
db=db db=db
) )
return success(data=result, msg="获取知识库类型统计成功") return success(data=result, msg="获取知识库类型统计成功")
except Exception as e: except Exception as e:
api_logger.error(f"Knowledge type stats failed: {str(e)}") api_logger.error(f"Knowledge type stats failed: {str(e)}")
@@ -664,11 +665,11 @@ async def get_knowledge_type_stats_api(
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse) @router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
async def get_interest_distribution_by_user_api( async def get_interest_distribution_by_user_api(
end_user_id: str = Query(..., description="用户ID必填"), end_user_id: str = Query(..., description="用户ID必填"),
limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"), limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"),
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
""" """
获取指定用户的兴趣分布标签 获取指定用户的兴趣分布标签
@@ -716,9 +717,9 @@ async def get_interest_distribution_by_user_api(
@router.get("/analytics/user_profile", response_model=ApiResponse) @router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api( async def get_user_profile_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"), end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
获取用户详情,包含: 获取用户详情,包含:
@@ -756,17 +757,17 @@ async def get_user_profile_api(
# ): # ):
# """ # """
# Get parsed API documentation (Public endpoint - no authentication required) # Get parsed API documentation (Public endpoint - no authentication required)
# Args: # Args:
# file_path: Optional path to API docs file. If None, uses default path. # file_path: Optional path to API docs file. If None, uses default path.
# Returns: # Returns:
# Parsed API documentation including title, meta info, and sections # Parsed API documentation including title, meta info, and sections
# """ # """
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}") # api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
# try: # try:
# result = await memory_agent_service.get_api_docs(file_path) # result = await memory_agent_service.get_api_docs(file_path)
# if result.get("success"): # if result.get("success"):
# return success(msg=result["msg"], data=result["data"]) # return success(msg=result["msg"], data=result["data"])
# else: # else:
@@ -782,9 +783,9 @@ async def get_user_profile_api(
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse) @router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
async def get_end_user_connected_config( async def get_end_user_connected_config(
end_user_id: str, end_user_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
获取终端用户关联的记忆配置 获取终端用户关联的记忆配置
@@ -803,9 +804,9 @@ async def get_end_user_connected_config(
from app.services.memory_agent_service import ( from app.services.memory_agent_service import (
get_end_user_connected_config as get_config, get_end_user_connected_config as get_config,
) )
api_logger.info(f"Getting connected config for end_user: {end_user_id}") api_logger.info(f"Getting connected config for end_user: {end_user_id}")
try: try:
result = get_config(end_user_id, db) result = get_config(end_user_id, db)
return success(data=result, msg="获取终端用户关联配置成功") return success(data=result, msg="获取终端用户关联配置成功")
@@ -814,4 +815,4 @@ async def get_end_user_connected_config(
return fail(BizCode.NOT_FOUND, str(e)) return fail(BizCode.NOT_FOUND, str(e))
except Exception as e: except Exception as e:
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True) api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))

View File

@@ -149,6 +149,21 @@ async def get_workspace_end_users(
return {uid: {"total": 0} for uid in end_user_ids} return {uid: {"total": 0} for uid in end_user_ids}
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
try:
from app.celery_app import celery_app as _celery_app
_celery_app.send_task(
"app.tasks.init_implicit_emotions_for_users",
kwargs={"end_user_ids": end_user_ids},
)
_celery_app.send_task(
"app.tasks.init_interest_distribution_for_users",
kwargs={"end_user_ids": end_user_ids},
)
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
except Exception as e:
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
# 并发执行配置查询和记忆数量查询 # 并发执行配置查询和记忆数量查询
memory_configs_map, memory_nums_map = await asyncio.gather( memory_configs_map, memory_nums_map = await asyncio.gather(
get_memory_configs(), get_memory_configs(),

View File

@@ -2,7 +2,7 @@ from typing import Optional
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse, JSONResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
@@ -85,6 +85,7 @@ def create_config(
payload: ConfigParamsCreate, payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
@@ -99,7 +100,29 @@ def create_config(
svc = DataConfigService(db) svc = DataConfigService(db)
result = svc.create(payload) result = svc.create(payload)
return success(data=result, msg="创建成功") return success(data=result, msg="创建成功")
except ValueError as e:
err_str = str(e)
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
config_name = err_str.split(":", 1)[1]
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {err_str}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
except Exception as e: except Exception as e:
from sqlalchemy.exc import IntegrityError
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {str(e)}") api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
@@ -521,10 +544,11 @@ async def clear_hot_memory_tags_cache(
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse) @router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api( async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info("Recent activity stats requested") workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
try: try:
result = await analytics_recent_activity_stats() result = await analytics_recent_activity_stats(workspace_id=workspace_id)
return success(data=result, msg="查询成功") return success(data=result, msg="查询成功")
except Exception as e: except Exception as e:
api_logger.error(f"Recent activity stats failed: {str(e)}") api_logger.error(f"Recent activity stats failed: {str(e)}")

View File

@@ -371,6 +371,11 @@ def update_model(
if model_data.type is not None or model_data.provider is not None: if model_data.type is not None or model_data.provider is not None:
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER) raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
if model_data.is_active:
active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, model_config_id=model_id, is_active=model_data.is_active)
if not active_keys:
raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER)
try: try:
api_logger.debug(f"开始更新模型配置: model_id={model_id}") api_logger.debug(f"开始更新模型配置: model_id={model_id}")

View File

@@ -25,7 +25,7 @@ from typing import Dict, Optional, List
from urllib.parse import quote from urllib.parse import quote
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse, JSONResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.config import settings from app.core.config import settings
@@ -289,7 +289,8 @@ async def extract_ontology(
async def create_scene( async def create_scene(
request: SceneCreateRequest, request: SceneCreateRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
): ):
"""创建本体场景 """创建本体场景
@@ -360,8 +361,18 @@ async def create_scene(
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e: except RuntimeError as e:
api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True) err_str = str(e)
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e)) if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str:
api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}")
from app.core.language_utils import get_language_from_header
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str)
except Exception as e: except Exception as e:
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True) api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
@@ -661,7 +672,8 @@ async def get_scenes(
async def create_class( async def create_class(
request: ClassCreateRequest, request: ClassCreateRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
): ):
"""创建本体类型 """创建本体类型
@@ -676,7 +688,7 @@ async def create_class(
ApiResponse: 包含创建的类型信息 ApiResponse: 包含创建的类型信息
""" """
from app.controllers.ontology_secondary_routes import create_class_handler from app.controllers.ontology_secondary_routes import create_class_handler
return await create_class_handler(request, db, current_user) return await create_class_handler(request, db, current_user, x_language_type)
@router.put("/class/{class_id}", response_model=ApiResponse) @router.put("/class/{class_id}", response_model=ApiResponse)

View File

@@ -7,7 +7,7 @@
from uuid import UUID from uuid import UUID
from typing import Optional from typing import Optional
from fastapi import Depends from fastapi import Depends, Header
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
@@ -58,7 +58,7 @@ async def scenes_handler(
workspace_id: Optional[str] = None, workspace_id: Optional[str] = None,
scene_name: Optional[str] = None, scene_name: Optional[str] = None,
page: Optional[int] = None, page: Optional[int] = None,
page_size: Optional[int] = None, pagesize: Optional[int] = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
@@ -71,14 +71,14 @@ async def scenes_handler(
workspace_id: 工作空间ID可选默认当前用户工作空间 workspace_id: 工作空间ID可选默认当前用户工作空间
scene_name: 场景名称关键词(可选,支持模糊匹配) scene_name: 场景名称关键词(可选,支持模糊匹配)
page: 页码可选从1开始仅在全量查询时有效 page: 页码可选从1开始仅在全量查询时有效
page_size: 每页数量(可选,仅在全量查询时有效) pagesize: 每页数量(可选,仅在全量查询时有效)
db: 数据库会话 db: 数据库会话
current_user: 当前用户 current_user: 当前用户
""" """
operation = "search" if scene_name else "list" operation = "search" if scene_name else "list"
api_logger.info( api_logger.info(
f"Scene {operation} requested by user {current_user.id}, " f"Scene {operation} requested by user {current_user.id}, "
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}" f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
) )
try: try:
@@ -105,13 +105,13 @@ async def scenes_handler(
api_logger.warning(f"Invalid page number: {page}") api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0") return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if page_size is not None and page_size < 1: if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid page_size: {page_size}") api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0") return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或page_size中的一个返回错误 # 如果只提供了page或pagesize中的一个返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None): if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}") api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供") return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
# 模糊搜索场景(支持分页) # 模糊搜索场景(支持分页)
@@ -119,17 +119,15 @@ async def scenes_handler(
total = len(scenes) total = len(scenes)
# 如果提供了分页参数,进行分页处理 # 如果提供了分页参数,进行分页处理
if page is not None and page_size is not None: if page is not None and pagesize is not None:
start_idx = (page - 1) * page_size start_idx = (page - 1) * pagesize
end_idx = start_idx + page_size end_idx = start_idx + pagesize
scenes = scenes[start_idx:end_idx] scenes = scenes[start_idx:end_idx]
# 构建响应 # 构建响应
items = [] items = []
for scene in scenes: for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0 type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse( items.append(SceneResponse(
@@ -141,17 +139,16 @@ async def scenes_handler(
workspace_id=scene.workspace_id, workspace_id=scene.workspace_id,
created_at=scene.created_at, created_at=scene.created_at,
updated_at=scene.updated_at, updated_at=scene.updated_at,
classes_count=type_num classes_count=type_num,
is_system_default=scene.is_system_default
)) ))
# 构建响应(包含分页信息) # 构建响应(包含分页信息)
if page is not None and page_size is not None: if page is not None and pagesize is not None:
# 计算是否有下一页 hasnext = (page * pagesize) < total
hasnext = (page * page_size) < total
pagination_info = PaginationInfo( pagination_info = PaginationInfo(
page=page, page=page,
pagesize=page_size, pagesize=pagesize,
total=total, total=total,
hasnext=hasnext hasnext=hasnext
) )
@@ -165,28 +162,25 @@ async def scenes_handler(
) )
else: else:
# 获取所有场景(支持分页) # 获取所有场景(支持分页)
# 验证分页参数
if page is not None and page < 1: if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}") api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0") return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if page_size is not None and page_size < 1: if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid page_size: {page_size}") api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0") return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或page_size中的一个返回错误 # 如果只提供了page或pagesize中的一个返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None): if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}") api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供") return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
scenes, total = service.list_scenes(ws_uuid, page, page_size) scenes, total = service.list_scenes(ws_uuid, page, pagesize)
# 构建响应 # 构建响应
items = [] items = []
for scene in scenes: for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0 type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse( items.append(SceneResponse(
@@ -198,17 +192,16 @@ async def scenes_handler(
workspace_id=scene.workspace_id, workspace_id=scene.workspace_id,
created_at=scene.created_at, created_at=scene.created_at,
updated_at=scene.updated_at, updated_at=scene.updated_at,
classes_count=type_num classes_count=type_num,
is_system_default=scene.is_system_default
)) ))
# 构建响应(包含分页信息) # 构建响应(包含分页信息)
if page is not None and page_size is not None: if page is not None and pagesize is not None:
# 计算是否有下一页 hasnext = (page * pagesize) < total
hasnext = (page * page_size) < total
pagination_info = PaginationInfo( pagination_info = PaginationInfo(
page=page, page=page,
pagesize=page_size, pagesize=pagesize,
total=total, total=total,
hasnext=hasnext hasnext=hasnext
) )
@@ -238,7 +231,8 @@ async def scenes_handler(
async def create_class_handler( async def create_class_handler(
request: ClassCreateRequest, request: ClassCreateRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = None
): ):
"""创建本体类型(统一使用列表形式,支持单个或批量)""" """创建本体类型(统一使用列表形式,支持单个或批量)"""
@@ -271,8 +265,11 @@ async def create_class_handler(
] ]
if count == 1: if count == 1:
# 单个创建 # 单个创建 - 先检查重名
class_data = classes_data[0] class_data = classes_data[0]
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
if existing:
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
ontology_class = service.create_class( ontology_class = service.create_class(
scene_id=request.scene_id, scene_id=request.scene_id,
class_name=class_data["class_name"], class_name=class_data["class_name"],
@@ -330,12 +327,36 @@ async def create_class_handler(
return success(data=response.model_dump(mode='json'), msg="批量创建完成") return success(data=response.model_dump(mode='json'), msg="批量创建完成")
except ValueError as e: except ValueError as e:
api_logger.warning(f"Validation error in class creation: {str(e)}") err_str = str(e)
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) if err_str.startswith("DUPLICATE_CLASS_NAME:"):
class_name = err_str.split(":", 1)[1]
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.warning(f"Validation error in class creation: {err_str}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
except RuntimeError as e: except RuntimeError as e:
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True) err_str = str(e)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e)) if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
class_name = request.classes[0].class_name if request.classes else ""
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
except Exception as e: except Exception as e:
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True) api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
@@ -615,6 +636,7 @@ async def classes_handler(
scene_id=scene_uuid, scene_id=scene_uuid,
scene_name=scene.scene_name, scene_name=scene.scene_name,
scene_description=scene.scene_description, scene_description=scene.scene_description,
is_system_default=scene.is_system_default,
items=items items=items
) )

View File

@@ -97,6 +97,12 @@ async def create_tool(
): ):
"""创建工具""" """创建工具"""
try: try:
# 将 MCP 来源字段合并进 config
if request.tool_type == ToolType.MCP:
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
val = getattr(request, key, None)
if val is not None:
request.config[key] = val
tool_id = service.create_tool( tool_id = service.create_tool(
name=request.name, name=request.name,
tool_type=request.tool_type, tool_type=request.tool_type,

View File

@@ -1,7 +1,6 @@
import json
import os import os
from pathlib import Path from pathlib import Path
from typing import Annotated, Any, Dict, Optional from typing import Annotated, Optional
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import Field, TypeAdapter from pydantic import Field, TypeAdapter
@@ -190,8 +189,12 @@ class Settings:
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal) # Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) # NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) # 详见 docs/celery-env-bug-report.md
# 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
# SMTP Email Configuration # SMTP Email Configuration
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com") SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")

View File

@@ -1,10 +1,10 @@
import os
import json import json
import os
import time import time
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_, PROJECT_ROOT_,
ReadState, ReadState,
@@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
try: try:
# 使用优化的LLM服务 # 使用优化的LLM服务
structured = await problem_service.call_llm_structured( with get_db_context() as db_session:
state=state, structured = await problem_service.call_llm_structured(
db_session=db_session, state=state,
system_prompt=system_prompt, db_session=db_session,
response_model=ProblemExtensionResponse, system_prompt=system_prompt,
fallback_value=[] response_model=ProblemExtensionResponse,
) fallback_value=[]
)
# 添加更详细的日志记录 # 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
@@ -111,7 +111,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"error_type": type(e).__name__, "error_type": type(e).__name__,
"error_message": str(e), "error_message": str(e),
"content_length": len(content), "content_length": len(content),
"llm_model_id": memory_config.llm_model_id if memory_config else None "llm_model_id": str(memory_config.llm_model_id) if memory_config else None
} }
logger.error(f"Split_The_Problem error details: {error_details}") logger.error(f"Split_The_Problem error details: {error_details}")
@@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState:
try: try:
# 使用优化的LLM服务 # 使用优化的LLM服务
response_content = await problem_service.call_llm_structured( with get_db_context() as db_session:
state=state, response_content = await problem_service.call_llm_structured(
db_session=db_session, state=state,
system_prompt=system_prompt, db_session=db_session,
response_model=ProblemExtensionResponse, system_prompt=system_prompt,
fallback_value=[] response_model=ProblemExtensionResponse,
) fallback_value=[]
)
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}") logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
@@ -220,7 +221,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
"error_type": type(e).__name__, "error_type": type(e).__name__,
"error_message": str(e), "error_message": str(e),
"questions_count": len(databasets), "questions_count": len(databasets),
"llm_model_id": memory_config.llm_model_id if memory_config else None "llm_model_id": str(memory_config.llm_model_id) if memory_config else None
} }
logger.error(f"Problem_Extension error details: {error_details}") logger.error(f"Problem_Extension error details: {error_details}")

View File

@@ -6,31 +6,26 @@ import os
# ===== 第三方库 ===== # ===== 第三方库 =====
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.db import get_db, get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
COUNTState,
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.langgraph_graph.tools.tool import ( from app.core.memory.agent.langgraph_graph.tools.tool import (
create_hybrid_retrieval_tool_sync, create_hybrid_retrieval_tool_sync,
create_time_retrieval_tool, create_time_retrieval_tool,
extract_tool_message_content, extract_tool_message_content,
) )
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
db = next(get_db())
async def rag_config(state): async def rag_config(state):
@@ -50,10 +45,12 @@ async def rag_config(state):
"reranker_top_k": 10 "reranker_top_k": 10
} }
return kb_config return kb_config
async def rag_knowledge(state,question):
async def rag_knowledge(state, question):
kb_config = await rag_config(state) kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try: try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
@@ -61,13 +58,13 @@ async def rag_knowledge(state,question):
cleaned_query = question cleaned_query = question
raw_results = clean_content raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception : except Exception:
retrieval_knowledge=[] retrieval_knowledge = []
clean_content = '' clean_content = ''
raw_results = '' raw_results = ''
cleaned_query = question cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge,clean_content,cleaned_query,raw_results return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def llm_infomation(state: ReadState) -> ReadState: async def llm_infomation(state: ReadState) -> ReadState:
@@ -113,7 +110,7 @@ async def clean_databases(data) -> str:
# 收集所有内容 # 收集所有内容
content_list = [] content_list = []
# 处理重排序结果 # 处理重排序结果
reranked = results.get('reranked_results', {}) reranked = results.get('reranked_results', {})
if reranked: if reranked:
@@ -141,7 +138,6 @@ async def clean_databases(data) -> str:
elif isinstance(item, str): elif isinstance(item, str):
text_parts.append(item) text_parts.append(item)
return '\n'.join(text_parts).strip() return '\n'.join(text_parts).strip()
except Exception as e: except Exception as e:
@@ -150,23 +146,23 @@ async def clean_databases(data) -> str:
async def retrieve_nodes(state: ReadState) -> ReadState: async def retrieve_nodes(state: ReadState) -> ReadState:
''' '''
模型信息 模型信息
''' '''
problem_extension=state.get('problem_extension', '')['context'] problem_extension = state.get('problem_extension', '')['context']
storage_type=state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id=state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
end_user_id=state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
original=state.get('data', '') original = state.get('data', '')
problem_list=[] problem_list = []
for key,values in problem_extension.items(): for key, values in problem_extension.items():
for data in values: for data in values:
problem_list.append(data) problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# 创建异步任务处理单个问题 # 创建异步任务处理单个问题
async def process_question_nodes(idx, question): async def process_question_nodes(idx, question):
try: try:
@@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
send_verify = [] send_verify = []
for i, j in zip(keys, val, strict=False): for i, j in zip(keys, val, strict=False):
if j!=['']: if j != ['']:
send_verify.append({ send_verify.append({
"Query_small": i, "Query_small": i,
"Answer_Small": j "Answer_Small": j
@@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
} }
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve':dup_databases} return {'retrieve': dup_databases}
async def retrieve(state: ReadState) -> ReadState: async def retrieve(state: ReadState) -> ReadState:
# 从state中获取end_user_id # 从state中获取end_user_id
import time import time
start=time.time() start = time.time()
problem_extension = state.get('problem_extension', '')['context'] problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
@@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState:
with get_db_context() as db: # 使用同步数据库上下文管理器 with get_db_context() as db: # 使用同步数据库上下文管理器
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
return await llm_infomation(state) return await llm_infomation(state)
llm_config = await get_llm_info() llm_config = await get_llm_info()
api_key_obj = llm_config.api_keys[0] api_key_obj = llm_config.api_keys[0]
api_key = api_key_obj.api_key api_key = api_key_obj.api_key
@@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState:
) )
time_retrieval_tool = create_time_retrieval_tool(end_user_id) time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { "end_user_id": end_user_id, "return_raw_results": True } search_params = {"end_user_id": end_user_id, "return_raw_results": True}
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent( agent = create_agent(
llm, llm,
tools=[time_retrieval_tool,hybrid_retrieval], tools=[time_retrieval_tool, hybrid_retrieval],
system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}" system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
) )
@@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState:
async with SEMAPHORE: # 限制并发 async with SEMAPHORE: # 限制并发
try: try:
if storage_type == "rag" and user_rag_memory_id: if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
question)
else: else:
cleaned_query = question cleaned_query = question
# 使用 asyncio 在线程池中运行同步的 agent.invoke # 使用 asyncio 在线程池中运行同步的 agent.invoke
@@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState:
# json.dump(dup_databases, f, indent=4) # json.dump(dup_databases, f, indent=4)
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases} return {'retrieve': dup_databases}

View File

@@ -1,5 +1,3 @@
import os import os
import time import time
@@ -18,22 +16,24 @@ from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.db import get_db
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
db_session = next(get_db())
class SummaryNodeService(LLMServiceMixin): class SummaryNodeService(LLMServiceMixin):
"""总结节点服务类""" """总结节点服务类"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.template_service = TemplateService(template_root) self.template_service = TemplateService(template_root)
# 创建全局服务实例 # 创建全局服务实例
summary_service = SummaryNodeService() summary_service = SummaryNodeService()
async def rag_config(state): async def rag_config(state):
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = { kb_config = {
@@ -51,10 +51,12 @@ async def rag_config(state):
"reranker_top_k": 10 "reranker_top_k": 10
} }
return kb_config return kb_config
async def rag_knowledge(state,question):
async def rag_knowledge(state, question):
kb_config = await rag_config(state) kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try: try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
@@ -62,25 +64,28 @@ async def rag_knowledge(state,question):
cleaned_query = question cleaned_query = question
raw_results = clean_content raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception : except Exception:
retrieval_knowledge=[] retrieval_knowledge = []
clean_content = '' clean_content = ''
raw_results = '' raw_results = ''
cleaned_query = question cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge,clean_content,cleaned_query,raw_results return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def summary_history(state: ReadState) -> ReadState: async def summary_history(state: ReadState) -> ReadState:
end_user_id = state.get("end_user_id", '') end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
search_mode) -> str:
""" """
增强的summary_llm函数包含更好的错误处理和数据验证 增强的summary_llm函数包含更好的错误处理和数据验证
""" """
data = state.get("data", '') data = state.get("data", '')
# 构建系统提示词 # 构建系统提示词
if str(search_mode) == "0": if str(search_mode) == "0":
system_prompt = await summary_service.template_service.render_template( system_prompt = await summary_service.template_service.render_template(
@@ -99,18 +104,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
) )
try: try:
# 使用优化的LLM服务进行结构化输出 # 使用优化的LLM服务进行结构化输出
structured = await summary_service.call_llm_structured( with get_db_context() as db_session:
state=state, structured = await summary_service.call_llm_structured(
db_session=db_session, state=state,
system_prompt=system_prompt, db_session=db_session,
response_model=response_model, system_prompt=system_prompt,
fallback_value=None response_model=response_model,
) fallback_value=None
)
# 验证结构化响应 # 验证结构化响应
if structured is None: if structured is None:
logger.warning("LLM返回None使用默认回答") logger.warning("LLM返回None使用默认回答")
return "信息不足,无法回答" return "信息不足,无法回答"
# 根据操作类型提取答案 # 根据操作类型提取答案
if operation_name == "summary": if operation_name == "summary":
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答" aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
@@ -121,16 +127,16 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
else: else:
logger.warning("结构化响应缺少data字段") logger.warning("结构化响应缺少data字段")
aimessages = "信息不足,无法回答" aimessages = "信息不足,无法回答"
# 验证答案不为空 # 验证答案不为空
if not aimessages or aimessages.strip() == "": if not aimessages or aimessages.strip() == "":
aimessages = "信息不足,无法回答" aimessages = "信息不足,无法回答"
return aimessages return aimessages
except Exception as e: except Exception as e:
logger.error(f"结构化输出失败: {e}", exc_info=True) logger.error(f"结构化输出失败: {e}", exc_info=True)
# 尝试非结构化输出作为fallback # 尝试非结构化输出作为fallback
try: try:
logger.info("尝试非结构化输出作为fallback") logger.info("尝试非结构化输出作为fallback")
@@ -140,7 +146,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
system_prompt=system_prompt, system_prompt=system_prompt,
fallback_message="信息不足,无法回答" fallback_message="信息不足,无法回答"
) )
if response and response.strip(): if response and response.strip():
# 简单清理响应 # 简单清理响应
cleaned_response = response.strip() cleaned_response = response.strip()
@@ -148,16 +154,17 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
if cleaned_response.startswith('```'): if cleaned_response.startswith('```'):
lines = cleaned_response.split('\n') lines = cleaned_response.split('\n')
cleaned_response = '\n'.join(lines[1:-1]) cleaned_response = '\n'.join(lines[1:-1])
return cleaned_response return cleaned_response
else: else:
return "信息不足,无法回答" return "信息不足,无法回答"
except Exception as fallback_error: except Exception as fallback_error:
logger.error(f"Fallback也失败: {fallback_error}") logger.error(f"Fallback也失败: {fallback_error}")
return "信息不足,无法回答" return "信息不足,无法回答"
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
data = state.get("data", '') data = state.get("data", '')
end_user_id = state.get("end_user_id", '') end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session( await SessionService(store).save_session(
@@ -169,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
) )
await SessionService(store).cleanup_duplicates() await SessionService(store).cleanup_duplicates()
logger.info(f"sessionid: {aimessages} 写入成功") logger.info(f"sessionid: {aimessages} 写入成功")
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
storage_type=state.get("storage_type",'')
user_rag_memory_id=state.get("user_rag_memory_id",'') async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
data=state.get("data", '') storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
input_summary = { input_summary = {
"status": "success", "status": "success",
"summary_result": aimessages, "summary_result": aimessages,
@@ -189,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
"user_rag_memory_id": user_rag_memory_id "user_rag_memory_id": user_rag_memory_id
} }
} }
retrieve={ retrieve = {
"status": "success", "status": "success",
"summary_result": aimessages, "summary_result": aimessages,
"storage_type": storage_type, "storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id, "user_rag_memory_id": user_rag_memory_id,
"_intermediate": { "_intermediate": {
"type": "retrieval_summary", "type": "retrieval_summary",
"title":"快速检索", "title": "快速检索",
"summary": aimessages, "summary": aimessages,
"query": data, "query": data,
"storage_type": storage_type, "storage_type": storage_type,
@@ -204,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
} }
} }
return input_summary,retrieve return input_summary, retrieve
async def Input_Summary(state: ReadState) -> ReadState: async def Input_Summary(state: ReadState) -> ReadState:
start=time.time() start = time.time()
storage_type=state.get("storage_type",'') storage_type = state.get("storage_type", '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id = state.get("user_rag_memory_id", '')
data=state.get("data", '') data = state.get("data", '')
end_user_id=state.get("end_user_id", '') end_user_id = state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history( state) history = await summary_history(state)
search_params = { search_params = {
"end_user_id": end_user_id, "end_user_id": end_user_id,
"question": data, "question": data,
@@ -223,12 +233,13 @@ async def Input_Summary(state: ReadState) -> ReadState:
} }
try: try:
if storage_type!="rag": if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
memory_config=memory_config)
else: else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data) retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e: except Exception as e:
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True ) logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
retrieve_info, question, raw_results = "", data, [] retrieve_info, question, raw_results = "", data, []
try: try:
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2', # aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
@@ -237,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
summary_result = await summary_prompt(state, retrieve_info, retrieve_info) summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
summary = summary_result[0] summary = summary_result[0]
except Exception as e: except Exception as e:
logger.error( f"Input_Summary failed: {e}", exc_info=True ) logger.error(f"Input_Summary failed: {e}", exc_info=True)
summary= { summary = {
"status": "fail", "status": "fail",
"summary_result": "信息不足,无法回答", "summary_result": "信息不足,无法回答",
"storage_type": storage_type, "storage_type": storage_type,
@@ -251,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState:
except Exception: except Exception:
duration = 0.0 duration = 0.0
log_time('检索', duration) log_time('检索', duration)
return {"summary":summary} return {"summary": summary}
async def Retrieve_Summary(state: ReadState)-> ReadState:
retrieve=state.get("retrieve", '') async def Retrieve_Summary(state: ReadState) -> ReadState:
history = await summary_history( state) retrieve = state.get("retrieve", '')
history = await summary_history(state)
import json import json
with open("检索.json","w",encoding='utf-8') as f: with open("检索.json", "w", encoding='utf-8') as f:
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False)) f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
retrieve=retrieve.get("Expansion_issue", []) retrieve = retrieve.get("Expansion_issue", [])
start=time.time() start = time.time()
retrieve_info_str=[] retrieve_info_str = []
for data in retrieve: for data in retrieve:
if data=='': if data == '':
retrieve_info_str='' retrieve_info_str = ''
else: else:
for key, value in data.items(): for key, value in data.items():
if key=='Answer_Small': if key == 'Answer_Small':
for i in value: for i in value:
retrieve_info_str.append(i) retrieve_info_str.append(i)
retrieve_info_str=list(set(retrieve_info_str)) retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str='\n'.join(retrieve_info_str) retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages=await summary_llm(state,history,retrieve_info_str, aimessages = await summary_llm(state, history, retrieve_info_str,
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1") 'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages) await summary_redis_save(state, aimessages)
if aimessages == '': if aimessages == '':
@@ -286,33 +298,33 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
except Exception: except Exception:
duration = 0.0 duration = 0.0
log_time('Retrieval summary', duration) log_time('Retrieval summary', duration)
# 修复协程调用 - 先await然后访问返回值 # 修复协程调用 - 先await然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1] summary = summary_result[1]
return {"summary":summary} return {"summary": summary}
async def Summary(state: ReadState)-> ReadState: async def Summary(state: ReadState) -> ReadState:
start=time.time() start = time.time()
query = state.get("data", '') query = state.get("data", '')
verify=state.get("verify", '') verify = state.get("verify", '')
verify_expansion_issue=verify.get("verified_data", '') verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str='' retrieve_info_str = ''
for data in verify_expansion_issue: for data in verify_expansion_issue:
for key, value in data.items(): for key, value in data.items():
if key=='answer_small': if key == 'answer_small':
for i in value: for i in value:
retrieve_info_str+=i+'\n' retrieve_info_str += i + '\n'
history=await summary_history(state) history = await summary_history(state)
data = { data = {
"query": query, "query": query,
"history": history, "history": history,
"retrieve_info": retrieve_info_str "retrieve_info": retrieve_info_str
} }
aimessages=await summary_llm(state,history,data, aimessages = await summary_llm(state, history, data,
'summary_prompt.jinja2','summary',SummaryResponse,0) 'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages) await summary_redis_save(state, aimessages)
@@ -327,10 +339,12 @@ async def Summary(state: ReadState)-> ReadState:
# 修复协程调用 - 先await然后访问返回值 # 修复协程调用 - 先await然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1] summary = summary_result[1]
return {"summary":summary} return {"summary": summary}
async def Summary_fails(state: ReadState)-> ReadState:
storage_type=state.get("storage_type", '')
user_rag_memory_id=state.get("user_rag_memory_id", '') async def Summary_fails(state: ReadState) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
history = await summary_history(state) history = await summary_history(state)
query = state.get("data", '') query = state.get("data", '')
verify = state.get("verify", '') verify = state.get("verify", '')
@@ -346,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState:
"history": history, "history": history,
"retrieve_info": retrieve_info_str "retrieve_info": retrieve_info_str
} }
aimessages = await summary_llm(state, history, data, aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0) 'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result= { result = {
"status": "success", "status": "success",
"summary_result": aimessages, "summary_result": aimessages,
"storage_type": storage_type, "storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id "user_rag_memory_id": user_rag_memory_id
} }
return {"summary":result} return {"summary": result}

View File

@@ -1,8 +1,9 @@
import asyncio
import os import os
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.verification_models import VerificationResult from app.core.memory.agent.models.verification_models import VerificationResult
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_, PROJECT_ROOT_,
ReadState, ReadState,
@@ -10,28 +11,30 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
class VerificationNodeService(LLMServiceMixin): class VerificationNodeService(LLMServiceMixin):
"""验证节点服务类""" """验证节点服务类"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.template_service = TemplateService(template_root) self.template_service = TemplateService(template_root)
# 创建全局服务实例 # 创建全局服务实例
verification_service = VerificationNodeService() verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""处理验证结果并生成输出格式""" """处理验证结果并生成输出格式"""
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
data = state.get('data', '') data = state.get('data', '')
# 将 VerificationItem 对象转换为字典列表 # 将 VerificationItem 对象转换为字典列表
verified_data = [] verified_data = []
if messages_deal.expansion_issue: if messages_deal.expansion_issue:
@@ -40,7 +43,7 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
verified_data.append(item.model_dump()) verified_data.append(item.model_dump())
elif isinstance(item, dict): elif isinstance(item, dict):
verified_data.append(item) verified_data.append(item)
Verify_result = { Verify_result = {
"status": messages_deal.split_result, "status": messages_deal.split_result,
"verified_data": verified_data, "verified_data": verified_data,
@@ -58,34 +61,37 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
} }
} }
return Verify_result return Verify_result
async def Verify(state: ReadState): async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===") logger.info("=== Verify 节点开始执行 ===")
try: try:
content = state.get('data', '') content = state.get('data', '')
end_user_id = state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}") logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}") logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {}) retrieve = state.get("retrieve", {})
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}") logger.info(
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else [] retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}") logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
messages = { messages = {
"Query": content, "Query": content,
"Expansion_issue": retrieve_expansion "Expansion_issue": retrieve_expansion
} }
logger.info("Verify: 开始渲染模板") logger.info("Verify: 开始渲染模板")
# 生成 JSON schema 以指导 LLM 输出正确格式 # 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = VerificationResult.model_json_schema() json_schema = VerificationResult.model_json_schema()
system_prompt = await verification_service.template_service.render_template( system_prompt = await verification_service.template_service.render_template(
template_name='split_verify_prompt.jinja2', template_name='split_verify_prompt.jinja2',
operation_name='split_verify_prompt', operation_name='split_verify_prompt',
@@ -94,29 +100,30 @@ async def Verify(state: ReadState):
json_schema=json_schema json_schema=json_schema
) )
logger.info(f"Verify: 模板渲染完成prompt length={len(system_prompt)}") logger.info(f"Verify: 模板渲染完成prompt length={len(system_prompt)}")
# 使用优化的LLM服务添加超时保护 # 使用优化的LLM服务添加超时保护
logger.info("Verify: 开始调用 LLM") logger.info("Verify: 开始调用 LLM")
try: try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待 # 添加 asyncio.wait_for 超时包裹,防止无限等待
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长) # 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
import asyncio
structured = await asyncio.wait_for( with get_db_context() as db_session:
verification_service.call_llm_structured( structured = await asyncio.wait_for(
state=state, verification_service.call_llm_structured(
db_session=db_session, state=state,
system_prompt=system_prompt, db_session=db_session,
response_model=VerificationResult, system_prompt=system_prompt,
fallback_value={ response_model=VerificationResult,
"query": content, fallback_value={
"history": history if isinstance(history, list) else [], "query": content,
"expansion_issue": [], "history": history if isinstance(history, list) else [],
"split_result": "failed", "expansion_issue": [],
"reason": "验证失败或超时" "split_result": "failed",
} "reason": "验证失败或超时"
), }
timeout=150.0 # 150秒超时 ),
) timeout=150.0 # 150秒超时
)
logger.info(f"Verify: LLM 调用完成result={structured}") logger.info(f"Verify: LLM 调用完成result={structured}")
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.error("Verify: LLM 调用超时150秒使用 fallback 值") logger.error("Verify: LLM 调用超时150秒使用 fallback 值")
@@ -127,11 +134,11 @@ async def Verify(state: ReadState):
split_result="failed", split_result="failed",
reason="LLM调用超时" reason="LLM调用超时"
) )
result = await Verify_prompt(state, structured) result = await Verify_prompt(state, structured)
logger.info("=== Verify 节点执行完成 ===") logger.info("=== Verify 节点执行完成 ===")
return {"verify": result} return {"verify": result}
except Exception as e: except Exception as e:
logger.error(f"Verify 节点执行失败: {e}", exc_info=True) logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
# 返回失败的验证结果 # 返回失败的验证结果
@@ -152,4 +159,4 @@ async def Verify(state: ReadState):
"user_rag_memory_id": state.get('user_rag_memory_id', '') "user_rag_memory_id": state.get('user_rag_memory_id', '')
} }
} }
} }

View File

@@ -1,3 +1,4 @@
from app.cache.memory.interest_memory import InterestMemoryCache
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.utils.write_tools import write from app.core.memory.agent.utils.write_tools import write
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
@@ -40,6 +41,15 @@ async def write_node(state: WriteState) -> WriteState:
) )
logger.info(f"Write completed successfully! Config: {memory_config.config_name}") logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
# 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成
for lang in ["zh", "en"]:
deleted = await InterestMemoryCache.delete_interest_distribution(
end_user_id=end_user_id,
language=lang,
)
if deleted:
logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
write_result = { write_result = {
"status": "success", "status": "success",
"data": structured_messages, "data": structured_messages,

View File

@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
from langgraph.constants import START, END from langgraph.constants import START, END
from langgraph.graph import StateGraph from langgraph.graph import StateGraph
from app.db import get_db from app.db import get_db
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
@@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
) )
@asynccontextmanager @asynccontextmanager
async def make_read_graph(): async def make_read_graph():
"""创建并返回 LangGraph 工作流""" """创建并返回 LangGraph 工作流"""
@@ -49,7 +47,7 @@ async def make_read_graph():
workflow.add_node("Retrieve_Summary", Retrieve_Summary) workflow.add_node("Retrieve_Summary", Retrieve_Summary)
workflow.add_node("Summary", Summary) workflow.add_node("Summary", Summary)
workflow.add_node("Summary_fails", Summary_fails) workflow.add_node("Summary_fails", Summary_fails)
# 添加边 # 添加边
workflow.add_edge(START, "content_input") workflow.add_edge(START, "content_input")
workflow.add_conditional_edges("content_input", Split_continue) workflow.add_conditional_edges("content_input", Split_continue)
@@ -62,20 +60,20 @@ async def make_read_graph():
workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END) workflow.add_edge("Summary", END)
'''-----''' '''-----'''
# workflow.add_edge("Retrieve", END) # workflow.add_edge("Retrieve", END)
# 编译工作流 # 编译工作流
graph = workflow.compile() graph = workflow.compile()
yield graph yield graph
except Exception as e: except Exception as e:
print(f"创建工作流失败: {e}") print(f"创建工作流失败: {e}")
raise raise
finally: finally:
print("工作流创建完成") print("工作流创建完成")
async def main(): async def main():
"""主函数 - 运行工作流""" """主函数 - 运行工作流"""
message = "昨天有什么好看的电影" message = "昨天有什么好看的电影"
@@ -92,17 +90,19 @@ async def main():
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
import time import time
start=time.time() start = time.time()
try: try:
async with make_read_graph() as graph: async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} "end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息 # 获取节点更新信息
_intermediate_outputs = [] _intermediate_outputs = []
summary = '' summary = ''
async for update_event in graph.astream( async for update_event in graph.astream(
initial_state, initial_state,
stream_mode="updates", stream_mode="updates",
@@ -110,7 +110,7 @@ async def main():
): ):
for node_name, node_data in update_event.items(): for node_name, node_data in update_event.items():
print(f"处理节点: {node_name}") print(f"处理节点: {node_name}")
# 处理不同Summary节点的返回结构 # 处理不同Summary节点的返回结构
if 'Summary' in node_name: if 'Summary' in node_name:
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
@@ -125,23 +125,22 @@ async def main():
spit_data = node_data.get('spit_data', {}).get('_intermediate', None) spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
if spit_data and spit_data != [] and spit_data != {}: if spit_data and spit_data != [] and spit_data != {}:
_intermediate_outputs.append(spit_data) _intermediate_outputs.append(spit_data)
# Problem_Extension 节点 # Problem_Extension 节点
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None) problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
if problem_extension and problem_extension != [] and problem_extension != {}: if problem_extension and problem_extension != [] and problem_extension != {}:
_intermediate_outputs.append(problem_extension) _intermediate_outputs.append(problem_extension)
# Retrieve 节点 # Retrieve 节点
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
if retrieve_node and retrieve_node != [] and retrieve_node != {}: if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node) _intermediate_outputs.extend(retrieve_node)
# Verify 节点 # Verify 节点
verify_n = node_data.get('verify', {}).get('_intermediate', None) verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}: if verify_n and verify_n != [] and verify_n != {}:
_intermediate_outputs.append(verify_n) _intermediate_outputs.append(verify_n)
# Summary 节点 # Summary 节点
summary_n = node_data.get('summary', {}).get('_intermediate', None) summary_n = node_data.get('summary', {}).get('_intermediate', None)
if summary_n and summary_n != [] and summary_n != {}: if summary_n and summary_n != [] and summary_n != {}:
@@ -161,17 +160,20 @@ async def main():
# #
print(f"=== 最终摘要 ===") print(f"=== 最终摘要 ===")
print(summary) print(summary)
except Exception as e: except Exception as e:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
finally:
db_session.close()
end=time.time() end = time.time()
print(100*'y') print(100 * 'y')
print(f"总耗时: {end-start}s") print(f"总耗时: {end - start}s")
print(100*'y') print(100 * 'y')
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main()) asyncio.run(main())

View File

@@ -82,7 +82,9 @@ async def get_chunked_dialogs(
pruning_config = PruningConfig( pruning_config = PruningConfig(
pruning_switch=memory_config.pruning_enabled, pruning_switch=memory_config.pruning_enabled,
pruning_scene=memory_config.pruning_scene or "education", pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=memory_config.pruning_threshold pruning_threshold=memory_config.pruning_threshold,
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
ontology_classes=memory_config.ontology_classes,
) )
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}") logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")

View File

@@ -1,56 +0,0 @@
import asyncio
from typing import Dict, Optional
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
from app.db import get_db
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
class LLMClientPool:
"""LLM客户端连接池"""
def __init__(self, max_size: int = 5):
self.max_size = max_size
self.pools: Dict[str, asyncio.Queue] = {}
self.active_clients: Dict[str, int] = {}
async def get_client(self, llm_model_id: str):
"""获取LLM客户端"""
if llm_model_id not in self.pools:
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
self.active_clients[llm_model_id] = 0
pool = self.pools[llm_model_id]
try:
# 尝试从池中获取客户端
client = pool.get_nowait()
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
return client
except asyncio.QueueEmpty:
# 池为空,创建新客户端
if self.active_clients[llm_model_id] < self.max_size:
db_session = next(get_db())
client = get_llm_client_fast(llm_model_id, db_session)
self.active_clients[llm_model_id] += 1
logger.debug(f"创建新LLM客户端: {llm_model_id}")
return client
else:
# 等待可用客户端
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
return await pool.get()
async def return_client(self, llm_model_id: str, client):
"""归还LLM客户端到池中"""
if llm_model_id in self.pools:
try:
self.pools[llm_model_id].put_nowait(client)
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
except asyncio.QueueFull:
# 池已满,丢弃客户端
self.active_clients[llm_model_id] -= 1
logger.debug(f"池已满丢弃LLM客户端: {llm_model_id}")
# 全局客户端池
llm_client_pool = LLMClientPool()

View File

@@ -225,5 +225,24 @@ async def write(
with open(log_file, "a", encoding="utf-8") as f: with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n") f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
# 将提取统计写入 Redis按 workspace_id 存储
try:
from app.cache.memory.activity_stats_cache import ActivityStatsCache
stats_to_cache = {
"chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0,
"statements_count": len(all_statement_nodes) if all_statement_nodes else 0,
"triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0,
"triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0,
"temporal_count": 0,
}
await ActivityStatsCache.set_activity_stats(
workspace_id=str(memory_config.workspace_id),
stats=stats_to_cache,
)
logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
except Exception as cache_err:
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
logger.info("=== Pipeline Complete ===") logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds") logger.info(f"Total execution time: {total_time:.2f} seconds")

View File

@@ -10,7 +10,7 @@ Classes:
TemporalSearchParams: Parameters for temporal search queries TemporalSearchParams: Parameters for temporal search queries
""" """
from typing import Optional from typing import Optional, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -55,17 +55,26 @@ class PruningConfig(BaseModel):
Attributes: Attributes:
pruning_switch: Enable or disable semantic pruning pruning_switch: Enable or disable semantic pruning
pruning_scene: Scene type for pruning ('education', 'online_service', 'outbound') pruning_scene: Scene name for pruning, either a built-in key
('education', 'online_service', 'outbound') or a custom scene_name
from ontology_scene table
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal) pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
scene_id: Optional ontology scene UUID, used to load custom ontology classes
ontology_classes: List of class_name strings from ontology_class table,
injected into the prompt when pruning_scene is not a built-in scene
""" """
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.") pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
pruning_scene: str = Field( pruning_scene: str = Field(
"education", "education",
description="Scene for pruning: one of 'education', 'online_service', 'outbound'.", description="Scene for pruning: built-in key or custom scene_name from ontology_scene.",
) )
pruning_threshold: float = Field( pruning_threshold: float = Field(
0.5, ge=0.0, le=0.9, 0.5, ge=0.0, le=0.9,
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).") description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
ontology_classes: Optional[List[str]] = Field(
None, description="Class names from ontology_class table for custom scenes."
)
class TemporalSearchParams(BaseModel): class TemporalSearchParams(BaseModel):

View File

@@ -86,19 +86,26 @@ class SemanticPruner:
self._detailed_prune_logging = True # 是否启用详细日志 self._detailed_prune_logging = True # 是否启用详细日志
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志 self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
# 加载场景特定配置 # 加载场景特定配置(内置场景走专门规则,自定义场景 fallback 到通用规则)
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config( self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
self.config.pruning_scene, self.config.pruning_scene,
fallback_to_generic=True fallback_to_generic=True
) )
# 检查场景是否有专门支持 # 判断是否为内置专门场景
is_supported = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene) self._is_builtin_scene = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
if is_supported:
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用专门配置") # 自定义场景的本体类型列表(用于注入提示词)
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
if self._is_builtin_scene:
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用内置专门配置")
else: else:
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 未预定义,使用通用配置(保守策略)") self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 为自定义场景,使用通用规则 + 本体类型提示词注入")
self._log(f"[剪枝-初始化] 支持的场景: {SceneConfigRegistry.get_all_scenes()}") if self._ontology_classes:
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
else:
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
# Load Jinja2 template # Load Jinja2 template
self.template = prompt_env.get_template("extracat_Pruning.jinja2") self.template = prompt_env.get_template("extracat_Pruning.jinja2")
@@ -424,12 +431,16 @@ class SemanticPruner:
self._log(f"[剪枝-缓存] LRU缓存已满删除最旧条目") self._log(f"[剪枝-缓存] LRU缓存已满删除最旧条目")
rendered = self.template.render( rendered = self.template.render(
pruning_scene=self.config.pruning_scene, pruning_scene=self.config.pruning_scene,
is_builtin_scene=self._is_builtin_scene,
ontology_classes=self._ontology_classes,
dialog_text=dialog_text, dialog_text=dialog_text,
language=self.language language=self.language
) )
log_template_rendering("extracat_Pruning.jinja2", { log_template_rendering("extracat_Pruning.jinja2", {
"pruning_scene": self.config.pruning_scene, "pruning_scene": self.config.pruning_scene,
"is_builtin_scene": self._is_builtin_scene,
"ontology_classes_count": len(self._ontology_classes),
"language": self.language "language": self.language
}) })
log_prompt_rendering("pruning-extract", rendered) log_prompt_rendering("pruning-extract", rendered)

View File

@@ -1,6 +1,6 @@
{# {#
对话级抽取与相关性判定模板(用于剪枝加速) 对话级抽取与相关性判定模板(用于剪枝加速)
输入pruning_scene, dialog_text 输入pruning_scene, is_builtin_scene, ontology_classes, dialog_text, language
输出:严格 JSON不要包含任何多余文本字段 输出:严格 JSON不要包含任何多余文本字段
- is_related: bool是否与所选场景相关 - is_related: bool是否与所选场景相关
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等) - times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
@@ -16,7 +16,8 @@
- 仅输出上述键;避免多余解释或字段。 - 仅输出上述键;避免多余解释或字段。
#} #}
{% set scene_instructions = { {# ── 内置场景的固定说明 ── #}
{% set builtin_scene_instructions = {
'education': { 'education': {
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。', 'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.' 'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
@@ -31,16 +32,40 @@
} }
} %} } %}
{% set scene_key = pruning_scene %} {# ── 确定最终使用的场景说明 ── #}
{% if scene_key not in scene_instructions %} {% if is_builtin_scene %}
{% set scene_key = 'education' %} {# 内置专门场景:使用固定说明 #}
{% set scene_key = pruning_scene %}
{% if scene_key not in builtin_scene_instructions %}{% set scene_key = 'education' %}{% endif %}
{% set instruction = builtin_scene_instructions[scene_key][language] if language in ['zh', 'en'] else builtin_scene_instructions[scene_key]['zh'] %}
{% set custom_types_str = '' %}
{% else %}
{# 自定义场景:使用场景名称 + 本体类型列表构建说明 #}
{% if ontology_classes and ontology_classes | length > 0 %}
{% if language == 'en' %}
{% set custom_types_str = ontology_classes | join(', ') %}
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
{% else %}
{% set custom_types_str = ontology_classes | join('、') %}
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
{% endif %}
{% else %}
{# 无本体类型时退化为通用说明 #}
{% if language == 'en' %}
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
{% else %}
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
{% endif %}
{% set custom_types_str = '' %}
{% endif %}
{% endif %} {% endif %}
{% set instruction = scene_instructions[scene_key][language] if language in ['zh', 'en'] else scene_instructions[scene_key]['zh'] %}
{% if language == "zh" %} {% if language == "zh" %}
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性: 请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
场景说明:{{ instruction }} 场景说明:{{ instruction }}
{% if not is_builtin_scene and custom_types_str %}
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }}相关的内容即判定为相关is_related=true
{% endif %}
对话全文: 对话全文:
""" """
@@ -60,6 +85,9 @@
{% else %} {% else %}
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario: Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
Scenario Description: {{ instruction }} Scenario Description: {{ instruction }}
{% if not is_builtin_scene and custom_types_str %}
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
{% endif %}
Full Dialogue: Full Dialogue:
""" """

View File

@@ -8,34 +8,60 @@ from typing import Any
from urllib.parse import quote from urllib.parse import quote
from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \ from app.core.workflow.adapters.errors import (
UnsupportVariableType,
UnknowModelWarning,
ExceptionDefineition,
ExceptionType ExceptionType
from app.core.workflow.nodes.assigner import AssignerNodeConfig )
from app.core.workflow.nodes.assigner.config import AssignmentItem from app.core.workflow.nodes.assigner.config import AssignmentItem
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
from app.core.workflow.nodes.code import CodeNodeConfig
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig from app.core.workflow.nodes.configs import (
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig StartNodeConfig,
from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \ LLMNodeConfig,
AssignerNodeConfig,
CodeNodeConfig,
LoopNodeConfig,
IterationNodeConfig,
EndNodeConfig,
HttpRequestNodeConfig,
IfElseNodeConfig,
JinjaRenderNodeConfig,
KnowledgeRetrievalNodeConfig,
NoteNodeConfig,
ParameterExtractorNodeConfig,
QuestionClassifierNodeConfig,
VariableAggregatorNodeConfig
)
from app.core.workflow.nodes.cycle_graph.config import (
ConditionDetail as LoopConditionDetail,
ConditionsConfig,
CycleVariable CycleVariable
from app.core.workflow.nodes.end import EndNodeConfig )
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \ from app.core.workflow.nodes.enums import (
HttpContentType, HttpErrorHandle ValueInputType,
from app.core.workflow.nodes.http_request import HttpRequestNodeConfig ComparisonOperator,
from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \ AssignmentOperator,
HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig HttpAuthType,
from app.core.workflow.nodes.if_else import IfElseNodeConfig HttpContentType,
HttpErrorHandle,
NodeType
)
from app.core.workflow.nodes.http_request.config import (
HttpAuthConfig,
HttpContentTypeConfig,
HttpFormData,
HttpTimeOutConfig,
HttpRetryConfig,
HttpErrorDefaultTamplete,
HttpErrorHandleConfig
)
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
@@ -48,24 +74,24 @@ class DifyConverter(BaseConverter):
def __init__(self): def __init__(self):
self.CONFIG_CONVERT_MAP = { self.CONFIG_CONVERT_MAP = {
"start": self.convert_start_node_config, NodeType.START: self.convert_start_node_config,
"llm": self.convert_llm_node_config, NodeType.LLM: self.convert_llm_node_config,
"answer": self.convert_end_node_config, NodeType.END: self.convert_end_node_config,
"if-else": self.convert_if_else_node_config, NodeType.IF_ELSE: self.convert_if_else_node_config,
"loop": self.convert_loop_node_config, NodeType.LOOP: self.convert_loop_node_config,
"iteration": self.convert_iteration_node_config, NodeType.ITERATION: self.convert_iteration_node_config,
"assigner": self.convert_assigner_node_config, NodeType.ASSIGNER: self.convert_assigner_node_config,
"code": self.convert_code_node_config, NodeType.CODE: self.convert_code_node_config,
"http-request": self.convert_http_node_config, NodeType.HTTP_REQUEST: self.convert_http_node_config,
"template-transform": self.convert_jinja_render_node_config, NodeType.JINJARENDER: self.convert_jinja_render_node_config,
"knowledge-retrieval": self.convert_knowledge_node_config, NodeType.KNOWLEDGE_RETRIEVAL: self.convert_knowledge_node_config,
"parameter-extractor": self.convert_parameter_extractor_node_config, NodeType.PARAMETER_EXTRACTOR: self.convert_parameter_extractor_node_config,
"question-classifier": self.convert_question_classifier_node_config, NodeType.QUESTION_CLASSIFIER: self.convert_question_classifier_node_config,
"variable-aggregator": self.convert_variable_aggregator_node_config, NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
"tool": self.convert_tool_node_config, NodeType.TOOL: self.convert_tool_node_config,
"loop-start": lambda x: {}, NodeType.NOTES: self.convert_notes_config,
"iteration-start": lambda x: {}, NodeType.CYCLE_START: lambda x: {},
"loop-end": lambda x: {}, NodeType.BREAK: lambda x: {},
} }
def get_node_convert(self, node_type): def get_node_convert(self, node_type):
@@ -129,11 +155,11 @@ class DifyConverter(BaseConverter):
@staticmethod @staticmethod
def _convert_file(var): def _convert_file(var):
pass return None
@staticmethod @staticmethod
def _convert_array_file(var): def _convert_array_file(var):
pass return []
@staticmethod @staticmethod
def variable_type_map(source_type) -> VariableType | None: def variable_type_map(source_type) -> VariableType | None:
@@ -185,6 +211,9 @@ class DifyConverter(BaseConverter):
"not empty": ComparisonOperator.NOT_EMPTY, "not empty": ComparisonOperator.NOT_EMPTY,
"start with": ComparisonOperator.START_WITH, "start with": ComparisonOperator.START_WITH,
"end with": ComparisonOperator.END_WITH, "end with": ComparisonOperator.END_WITH,
"not contains": ComparisonOperator.NOT_CONTAINS,
"exists": ComparisonOperator.NOT_EMPTY,
"not exists": ComparisonOperator.EMPTY
} }
return operator_map.get(operator, operator) return operator_map.get(operator, operator)
@@ -198,7 +227,7 @@ class DifyConverter(BaseConverter):
"over-write": AssignmentOperator.COVER, "over-write": AssignmentOperator.COVER,
"remove-last": AssignmentOperator.REMOVE_LAST, "remove-last": AssignmentOperator.REMOVE_LAST,
"remove-first": AssignmentOperator.REMOVE_FIRST, "remove-first": AssignmentOperator.REMOVE_FIRST,
"set": AssignmentOperator.ASSIGN,
} }
return operator_map.get(operator, operator) return operator_map.get(operator, operator)
@@ -267,10 +296,10 @@ class DifyConverter(BaseConverter):
type=var_type, type=var_type,
required=var["required"], required=var["required"],
default=self.convert_variable_type( default=self.convert_variable_type(
var_type, var["default"] var_type, var.get("default")
), ),
description=var["label"], description=var["label"],
max_length=var.get("max_length"), max_length=var.get("max_length", 50),
) )
start_vars.append(var_def) start_vars.append(var_def)
result = StartNodeConfig.model_construct( result = StartNodeConfig.model_construct(
@@ -333,7 +362,7 @@ class DifyConverter(BaseConverter):
MessageConfig( MessageConfig(
role="user", role="user",
content=self.trans_variable_format( content=self.trans_variable_format(
node_data["memory"].get("query_prompt_template", "{{#sys.query#}}") node_data["memory"].get("query_prompt_template") or "{{#sys.query#}}"
) )
) )
) )
@@ -364,7 +393,7 @@ class DifyConverter(BaseConverter):
node_data = node["data"] node_data = node["data"]
cases = [] cases = []
for case in node_data["cases"]: for case in node_data["cases"]:
case_id = case["id"] case_id = case.get("id") or case.get("case_id")
logical_operator = case["logical_operator"] logical_operator = case["logical_operator"]
conditions = [] conditions = []
for condition in case["conditions"]: for condition in case["conditions"]:
@@ -540,7 +569,8 @@ class DifyConverter(BaseConverter):
] = self.trans_variable_format(content["value"]) ] = self.trans_variable_format(content["value"])
else: else:
if node_data["body"]["data"]: if node_data["body"]["data"]:
body_content = node_data["body"]["data"][0]["value"] body_content = (node_data["body"]["data"][0].get("value") or
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file")))
else: else:
body_content = "" body_content = ""
@@ -612,7 +642,7 @@ class DifyConverter(BaseConverter):
), ),
headers=headers, headers=headers,
params=params, params=params,
verify_ssl=node_data["ssl_verify"], verify_ssl=node_data.get("ssl_verify", False),
timeouts=HttpTimeOutConfig.model_construct( timeouts=HttpTimeOutConfig.model_construct(
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5, connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
read_timeout=node_data["timeout"]["max_read_timeout"] or 5, read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
@@ -696,7 +726,7 @@ class DifyConverter(BaseConverter):
group_variables = {} group_variables = {}
group_type = {} group_type = {}
if not advanced_settings or not advanced_settings["group_enabled"]: if not advanced_settings or not advanced_settings["group_enabled"]:
group_variables["output"] = [ group_variables = [
self._process_list_variable_litearl(variable) self._process_list_variable_litearl(variable)
for variable in node_data["variables"] for variable in node_data["variables"]
] ]
@@ -728,3 +758,16 @@ class DifyConverter(BaseConverter):
detail=f"Please reconfigure the tool node.", detail=f"Please reconfigure the tool node.",
)) ))
return {} return {}
@staticmethod
def convert_notes_config(node: dict):
node_data = node["data"]
result = NoteNodeConfig.model_construct(
author=node_data.get("author", ""),
text=node_data.get("text", ""),
width=node_data.get("width", 80),
height=node_data.get("height", 80),
theme=node_data.get("theme", "blue"),
show_author=node_data.get("showAuthor", True)
).model_dump()
return result

View File

@@ -44,12 +44,13 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR, "parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
"question-classifier": NodeType.QUESTION_CLASSIFIER, "question-classifier": NodeType.QUESTION_CLASSIFIER,
"variable-aggregator": NodeType.VAR_AGGREGATOR, "variable-aggregator": NodeType.VAR_AGGREGATOR,
"tool": NodeType.TOOL "tool": NodeType.TOOL,
"": NodeType.NOTES
} }
def __init__(self, config: dict[str, Any]): def __init__(self, config: dict[str, Any]):
DifyConverter.__init__(self) DifyConverter.__init__(self)
BasePlatformAdapter.__init__(self, config) BasePlatformAdapter.__init__(self, config)
def get_metadata(self) -> PlatformMetadata: def get_metadata(self) -> PlatformMetadata:
return PlatformMetadata( return PlatformMetadata(
@@ -58,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
support_node_types=list(self.NODE_TYPE_MAPPING.keys()) support_node_types=list(self.NODE_TYPE_MAPPING.keys())
) )
def map_node_type(self, platform_node_type) -> str: def map_node_type(self, platform_node_type) -> NodeType:
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN) return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
@property @property
@@ -83,6 +84,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
require_fields = frozenset({'app', 'kind', 'version', 'workflow'}) require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
if not all(field in self.config for field in require_fields): if not all(field in self.config for field in require_fields):
return False return False
if self.config.get("app", {}).get("mode") == "workflow":
self.errors.append(ExceptionDefineition(
type=ExceptionType.PLATFORM,
detail="workflow mode is not supported"
))
return False
for node in self.origin_nodes: for node in self.origin_nodes:
if not self._valid_nodes(node): if not self._valid_nodes(node):
@@ -134,6 +141,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
for node in self.origin_nodes: for node in self.origin_nodes:
if self.map_node_type(node["data"]["type"]) == NodeType.LLM: if self.map_node_type(node["data"]["type"]) == NodeType.LLM:
self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output" self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output"
elif self.map_node_type(node["data"]["type"]) == NodeType.KNOWLEDGE_RETRIEVAL:
self.node_output_map[f"{node['id']}.result"] = f"{node['id']}.output"
def _convert_cycle_node_position(self, node_id: str, position: dict): def _convert_cycle_node_position(self, node_id: str, position: dict):
for node in self.origin_nodes: for node in self.origin_nodes:
@@ -154,13 +163,14 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None: def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
node_data = node["data"] node_data = node["data"]
try: try:
node_type = self.map_node_type(node_data["type"])
return NodeDefinition( return NodeDefinition(
id=node["id"], id=node["id"],
type=self.map_node_type(node_data["type"]), type=node_type,
name=node_data.get("title"), name=node_data.get("title") or "notes",
cycle=node.get("parentId"), cycle=node.get("parentId"),
description=None, description=None,
config=self._convert_node_config(node), config=self._convert_node_config(node_type, node),
position={ position={
"x": node["position"]["x"], "x": node["position"]["x"],
"y": node["position"]["y"] "y": node["position"]["y"]
@@ -174,17 +184,16 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
except Exception as e: except Exception as e:
logger.debug(f"convert node error - {e}", exc_info=True) logger.debug(f"convert node error - {e}", exc_info=True)
def _convert_node_config(self, node: dict): def _convert_node_config(self, node_type: NodeType, node: dict):
node_data = node["data"]
node_type = node_data["type"]
try: try:
node_data = node["data"]
converter = self.get_node_convert(node_type) converter = self.get_node_convert(node_type)
if node_type not in self.CONFIG_CONVERT_MAP: if node_type == NodeType.UNKNOWN:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node["id"], node_id=node["id"],
node_name=node["data"]["title"], node_name=node["data"]["title"],
detail=f"node type {node_type} is unsupported", detail=f"node type {node_data.get('type')} is unsupported",
)) ))
return converter(node) return converter(node)
except Exception as e: except Exception as e:
@@ -201,16 +210,15 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
source = edge["source"] source = edge["source"]
target = edge["target"] target = edge["target"]
edge_id = edge["id"]
label = None label = None
if source in self.branch_node_cache: if source in self.branch_node_cache:
case_id = "-".join(edge_id.split("-")[1:-2]) case_id = edge["sourceHandle"]
if case_id == "false": if case_id == "false":
label = f'CASE{len(self.branch_node_cache[source])+1}' label = f'CASE{len(self.branch_node_cache[source]) + 1}'
else: else:
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}' label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
if source in self.error_branch_node_cache: if source in self.error_branch_node_cache:
case_id = "-".join(edge_id.split("-")[1:-2]) case_id = edge["sourceHandle"]
if case_id == "source": if case_id == "source":
label = "SUCCESS" label = "SUCCESS"
else: else:
@@ -235,6 +243,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
name=variable["name"], name=variable["name"],
default=variable["value"], default=variable["value"],
type=self.variable_type_map(variable["value_type"]), type=self.variable_type_map(variable["value_type"]),
description=variable.get("description")
) )
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefineition(
@@ -248,5 +257,3 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig: def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
return ExecutionConfig() return ExecutionConfig()

View File

@@ -292,6 +292,8 @@ class GraphBuilder:
""" """
for node in self.nodes: for node in self.nodes:
node_type = node.get("type") node_type = node.get("type")
if node_type == NodeType.NOTES:
continue
node_id = node.get("id") node_id = node.get("id")
cycle_node = node.get("cycle") cycle_node = node.get("cycle")
if cycle_node: if cycle_node:
@@ -320,7 +322,7 @@ class GraphBuilder:
# Used later to determine which branch to take based on the node's output # Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label # Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'"
if node_instance: if node_instance:
# Wrap node's run method to avoid closure issues # Wrap node's run method to avoid closure issues

View File

@@ -158,18 +158,36 @@ class WorkflowExecutor:
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False) full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
# Append messages for user and assistant # Append messages for user and assistant
result["messages"].extend( if input_data.get("files"):
[ result["messages"].extend(
{ [
"role": "user", {
"content": input_data.get("message", '') "role": "user",
}, "content": input_data.get("message", '')
{ },
"role": "assistant", {
"content": full_content "role": "user",
} "content": input_data.get("files")
] },
) {
"role": "assistant",
"content": full_content
}
]
)
else:
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "assistant",
"content": full_content
}
]
)
# Calculate elapsed time # Calculate elapsed time
end_time = datetime.datetime.now() end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds() elapsed_time = (end_time - start_time).total_seconds()
@@ -308,18 +326,36 @@ class WorkflowExecutor:
elapsed_time = (end_time - start_time).total_seconds() elapsed_time = (end_time - start_time).total_seconds()
# Append messages for user and assistant # Append messages for user and assistant
result["messages"].extend( if input_data.get("files"):
[ result["messages"].extend(
{ [
"role": "user", {
"content": input_data.get("message", '') "role": "user",
}, "content": input_data.get("message", '')
{ },
"role": "assistant", {
"content": full_content "role": "user",
} "content": input_data.get("files")
] },
) {
"role": "assistant",
"content": full_content
}
]
)
else:
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "assistant",
"content": full_content
}
]
)
logger.info( logger.info(
f"Workflow execution completed (streaming), " f"Workflow execution completed (streaming), "
f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}" f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}"

View File

@@ -14,7 +14,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db from app.db import get_db_context
from app.models import AppRelease from app.models import AppRelease
from app.services.draft_run_service import AgentRunService from app.services.draft_run_service import AgentRunService
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING} return {"output": VariableType.STRING}
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]: def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]:
"""准备 Agent公共逻辑 """准备 Agent公共逻辑
Args: Args:
@@ -57,17 +57,17 @@ class AgentNode(BaseNode):
if not agent_id: if not agent_id:
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置") raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
db = next(get_db()) with get_db_context() as db:
release = db.query(AppRelease).filter( release = db.query(AppRelease).filter(
AppRelease.id == agent_id AppRelease.id == agent_id
).first() ).first()
if not release: if not release:
raise ValueError(f"Agent 不存在: {agent_id}") raise ValueError(f"Agent 不存在: {agent_id}")
draft_service = AgentRunService(db)
return draft_service, release, message return release, message
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""非流式执行 """非流式执行
@@ -79,19 +79,21 @@ class AgentNode(BaseNode):
Returns: Returns:
状态更新字典 状态更新字典
""" """
draft_service, release, message = self._prepare_agent(variable_pool) release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)") logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
with get_db_context() as db:
# 执行 Agent非流式 draft_service = AgentRunService(db)
result = await draft_service.run(
agent_config=release.config, # 执行 Agent非流式
model_config=None, result = await draft_service.run(
message=message, agent_config=release.config,
workspace_id=variable_pool.get_value("sys.workspace_id"), model_config=None,
user_id=state.get("user_id"), message=message,
variables=variable_pool.get_all_conversation_vars() workspace_id=variable_pool.get_value("sys.workspace_id"),
) user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars()
)
response = result.get("response", "") response = result.get("response", "")
@@ -118,34 +120,35 @@ class AgentNode(BaseNode):
Yields: Yields:
流式事件字典 流式事件字典
""" """
draft_service, release, message = self._prepare_agent(variable_pool) release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)") logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
# 累积完整响应 # 累积完整响应
full_response = "" full_response = ""
with get_db_context() as db:
draft_service = AgentRunService(db)
# 执行 Agent流式 # 执行 Agent流式
async for chunk in draft_service.run_stream( async for chunk in draft_service.run_stream(
agent_config=release.config, agent_config=release.config,
model_config=None, model_config=None,
message=message, message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"), workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"), user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars() variables=variable_pool.get_all_conversation_vars()
): ):
# 提取内容 # 提取内容
content = chunk.get("content", "") content = chunk.get("content", "")
full_response += content full_response += content
# 流式返回每个 chunk # 流式返回每个 chunk
yield { yield {
"type": "chunk", "type": "chunk",
"node_id": self.node_id, "node_id": self.node_id,
"content": content, "content": content,
"full_content": full_response, "full_content": full_response,
"meta_data": chunk.get("meta_data", {}) "meta_data": chunk.get("meta_data", {})
} }
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}") logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")

View File

@@ -85,20 +85,20 @@ class BaseNodeConfig(BaseModel):
- tags: 节点标签(用于分类和搜索) - tags: 节点标签(用于分类和搜索)
""" """
name: str | None = Field( # name: str | None = Field(
default=None, # default=None,
description="节点名称(显示名称),如果不设置则使用节点 ID" # description="节点名称(显示名称),如果不设置则使用节点 ID"
) # )
#
description: str | None = Field( # description: str | None = Field(
default=None, # default=None,
description="节点描述,说明节点的作用" # description="节点描述,说明节点的作用"
) # )
#
tags: list[str] = Field( # tags: list[str] = Field(
default_factory=list, # default_factory=list,
description="节点标签,用于分类和搜索" # description="节点标签,用于分类和搜索"
) # )
class Config: class Config:
"""Pydantic 配置""" """Pydantic 配置"""

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime
from functools import cached_property from functools import cached_property
from typing import Any, AsyncGenerator from typing import Any, AsyncGenerator
@@ -13,6 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.enums import BRANCH_NODES from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable.base_variable import VariableType, FileObject from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read from app.db import get_db_read
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
from app.schemas import FileInput from app.schemas import FileInput
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
@@ -617,17 +618,31 @@ class BaseNode(ABC):
return variable_pool.has(selector) return variable_pool.has(selector)
@staticmethod @staticmethod
async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None: async def process_message(
provider: str,
is_omni: bool,
content: str | dict | FileObject,
enable_file=False
) -> list | str | None:
if isinstance(content, dict):
content = FileObject(
type=content.get("type"),
url=content.get("url"),
transfer_method=content.get("transfer_method"),
origin_file_type=content.get("origin_file_type"),
file_id=content.get("file_id"),
is_file=True
)
if isinstance(content, str): if isinstance(content, str):
if enable_file: if enable_file:
return {"text": content} return [{"type": "text", "text": content}]
return content return content
elif isinstance(content, FileObject): elif isinstance(content, FileObject):
if content.content_cache.get(provider): if content.content_cache.get(provider):
return content.content_cache[provider] return content.content_cache[provider]
with get_db_read() as db: with get_db_read() as db:
multimodel_service = MultimodalService(db, provider) multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
message = await multimodel_service.process_files( message = await multimodel_service.process_files(
[FileInput.model_construct( [FileInput.model_construct(
type=content.type, type=content.type,
@@ -637,10 +652,9 @@ class BaseNode(ABC):
upload_file_id=content.file_id upload_file_id=content.file_id
)] )]
) )
if message: if message:
content.content_cache[provider] = message[0] content.content_cache[provider] = message
return message[0] return message
return None return None
raise TypeError(f'Unexpect input value type - {type(content)}') raise TypeError(f'Unexpect input value type - {type(content)}')
@@ -658,3 +672,12 @@ class BaseNode(ABC):
elif isinstance(content, str): elif isinstance(content, str):
return content return content
return result return result
@staticmethod
def model_balance(model_config: ModelConfig) -> ModelApiKey:
api_keys = [key for key in model_config.api_keys if key.is_active]
if not api_keys:
raise ValueError("No active API keys available for model")
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
return api_keys[0]

View File

@@ -23,6 +23,7 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie
from app.core.workflow.nodes.start.config import StartNodeConfig from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.notes.config import NoteNodeConfig
__all__ = [ __all__ = [
# 基础类 # 基础类
@@ -47,5 +48,6 @@ __all__ = [
"ToolNodeConfig", "ToolNodeConfig",
"MemoryReadNodeConfig", "MemoryReadNodeConfig",
"MemoryWriteNodeConfig", "MemoryWriteNodeConfig",
"CodeNodeConfig" "CodeNodeConfig",
"NoteNodeConfig"
] ]

View File

@@ -25,6 +25,7 @@ class NodeType(StrEnum):
MEMORY_WRITE = "memory-write" MEMORY_WRITE = "memory-write"
UNKNOWN = "unknown" UNKNOWN = "unknown"
NOTES = "notes"
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER] BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]

View File

@@ -180,6 +180,8 @@ class KnowledgeRetrievalNode(BaseNode):
RuntimeError: If no valid knowledge base is found or access is denied. RuntimeError: If no valid knowledge base is found or access is denied.
""" """
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
if not self.typed_config.knowledge_bases:
return []
query = self._render_template(self.typed_config.query, variable_pool) query = self._render_template(self.typed_config.query, variable_pool)
with get_db_read() as db: with get_db_read() as db:
knowledge_bases = self.typed_config.knowledge_bases knowledge_bases = self.typed_config.knowledge_bases

View File

@@ -112,11 +112,12 @@ class LLMNode(BaseNode):
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
# 在 Session 关闭前提取所有需要的数据 # 在 Session 关闭前提取所有需要的数据
api_config = config.api_keys[0] api_config = self.model_balance(config)
model_name = api_config.model_name model_name = api_config.model_name
provider = api_config.provider provider = api_config.provider
api_key = api_config.api_key api_key = api_config.api_key
api_base = api_config.api_base api_base = api_config.api_base
is_omni = api_config.is_omni
model_type = config.type model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据) # 4. 创建 LLM 实例(使用已提取的数据)
@@ -129,7 +130,8 @@ class LLMNode(BaseNode):
provider=provider, provider=provider,
api_key=api_key, api_key=api_key,
base_url=api_base, base_url=api_base,
extra_params=extra_params extra_params=extra_params,
is_omni=is_omni
), ),
type=ModelType(model_type) type=ModelType(model_type)
) )
@@ -151,39 +153,53 @@ class LLMNode(BaseNode):
if role == "system": if role == "system":
messages.append({ messages.append({
"role": "system", "role": "system",
"content": content "content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
}) })
elif role in ["user", "human"]: elif role in ["user", "human"]:
messages.append({ messages.append({
"role": "user", "role": "user",
"content": content "content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
}) })
elif role in ["ai", "assistant"]: elif role in ["ai", "assistant"]:
messages.append({ messages.append({
"role": "assistant", "role": "assistant",
"content": content "content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
}) })
else: else:
logger.warning(f"未知的消息角色: {role},默认使用 user") logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({ messages.append({
"role": "user", "role": "user",
"content": content "content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
}) })
if self.typed_config.vision_input and self.typed_config.vision: if self.typed_config.vision_input and self.typed_config.vision:
file_content = [] file_content = []
files = variable_pool.get_instance(self.typed_config.vision_input) files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value: for file in files.value:
content = await self.process_message(provider, file.value, self.typed_config.vision) content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
if content: if content:
file_content.append(content) file_content.extend(content)
if messages and messages[-1]["role"] == 'user': if messages and messages[-1]["role"] == 'user':
messages[-1]['content'] = [messages[-1]["content"]] + file_content messages[-1]['content'] = messages[-1]["content"] + file_content
else: else:
messages.append({"role": "user", "content": file_content}) messages.append({"role": "user", "content": file_content})
if self.typed_config.memory.enable: if self.typed_config.memory.enable:
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:] history_message = []
for message in state["messages"][-self.typed_config.memory.window_size:]:
if isinstance(message["content"], list):
file_content = []
for file in message["content"]:
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
if content:
file_content.extend(content)
history_message.append(
{"role": message["role"], "content": file_content}
)
else:
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
history_message.append(message)
messages = messages[:-1] + history_message + messages[-1:]
self.messages = messages self.messages = messages
else: else:
# 使用简单的 prompt 格式(向后兼容) # 使用简单的 prompt 格式(向后兼容)

View File

@@ -0,0 +1,12 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
class NoteNodeConfig(BaseNodeConfig):
author: str = Field(default="", description="author")
text: str = Field(default="", description="note content")
width: int = Field(default=80)
height: int = Field(default=80)
theme: str = Field(default="blue")
show_author: bool = Field(default=True)

View File

@@ -95,11 +95,12 @@ class ParameterExtractorNode(BaseNode):
if not config.api_keys or len(config.api_keys) == 0: if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER) raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
api_config = config.api_keys[0] api_config = self.model_balance(config)
model_name = api_config.model_name model_name = api_config.model_name
provider = api_config.provider provider = api_config.provider
api_key = api_config.api_key api_key = api_config.api_key
api_base = api_config.api_base api_base = api_config.api_base
is_omni = api_config.is_omni
model_type = config.type model_type = config.type
llm = RedBearLLM( llm = RedBearLLM(
@@ -108,6 +109,7 @@ class ParameterExtractorNode(BaseNode):
provider=provider, provider=provider,
api_key=api_key, api_key=api_key,
base_url=api_base, base_url=api_base,
is_omni=is_omni
), ),
type=ModelType(model_type) type=ModelType(model_type)
) )

View File

@@ -56,11 +56,12 @@ class QuestionClassifierNode(BaseNode):
if not config.api_keys or len(config.api_keys) == 0: if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
api_config = config.api_keys[0] api_config = self.model_balance(config)
model_name = api_config.model_name model_name = api_config.model_name
provider = api_config.provider provider = api_config.provider
api_key = api_config.api_key api_key = api_config.api_key
base_url = api_config.api_base base_url = api_config.api_base
is_omni = api_config.is_omni
model_type = config.type model_type = config.type
return RedBearLLM( return RedBearLLM(
@@ -69,6 +70,7 @@ class QuestionClassifierNode(BaseNode):
provider=provider, provider=provider,
api_key=api_key, api_key=api_key,
base_url=base_url, base_url=base_url,
is_omni=is_omni
), ),
type=ModelType(model_type) type=ModelType(model_type)
) )

View File

@@ -138,7 +138,7 @@ class WorkflowValidator:
errors.append("工作流必须至少有一个 end 节点") errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性 # 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes] node_ids = [n.get("id") for n in nodes if n.get("type") != NodeType.NOTES]
if len(node_ids) != len(set(node_ids)): if len(node_ids) != len(set(node_ids)):
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1] duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}") errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")

View File

@@ -3,7 +3,7 @@ import uuid
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean, text
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@@ -163,6 +163,17 @@ class CustomToolConfig(Base):
return f"<CustomToolConfig(id={self.id}, auth_type={self.auth_type})>" return f"<CustomToolConfig(id={self.id}, auth_type={self.auth_type})>"
class MCPSourceChannel(StrEnum):
"""MCP来源渠道枚举"""
ALIYUN_BAILIAN = "aliyun_bailian" # 阿里云百炼
MODELSCOPE = "modelscope" # ModelScope
TOKENFLUX = "tokenflux" # TokenFlux
LANGENG = "langeng" # 蓝耕科技
AI_302 = "302ai" # 302.AI
MCP_ROUTER = "mcp_router" # MCP Router
SELF_HOSTED = "self_hosted" # 自建
class MCPToolConfig(Base): class MCPToolConfig(Base):
"""MCP工具配置模型""" """MCP工具配置模型"""
__tablename__ = "mcp_tool_configs" __tablename__ = "mcp_tool_configs"
@@ -170,6 +181,13 @@ class MCPToolConfig(Base):
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True) id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
server_url = Column(String(1000), nullable=False) # MCP服务器URL server_url = Column(String(1000), nullable=False) # MCP服务器URL
connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息) connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息)
# 来源渠道
source_channel = Column(String(50), default=MCPSourceChannel.SELF_HOSTED,
server_default=text(f"'{MCPSourceChannel.SELF_HOSTED}'"), nullable=False, comment="来源渠道")
market_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场id")
market_config_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场配置id")
mcp_service_id = Column(String(255), nullable=True, comment="mcp服务id")
# 服务状态 # 服务状态
last_health_check = Column(DateTime) last_health_check = Column(DateTime)

View File

@@ -1,10 +1,11 @@
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid import uuid
from typing import List
from app.models.app_model import App from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger from app.core.logging_config import get_db_logger
from app.models.app_model import App
# 获取数据库专用日志器 # 获取数据库专用日志器
db_logger = get_db_logger() db_logger = get_db_logger()
@@ -35,11 +36,27 @@ class AppRepository:
except Exception as e: except Exception as e:
raise raise
def get_apps_by_name(self, app_name: str, app_type: str, workspace_id: uuid.UUID) -> List[App]:
try:
stmt = select(App).where(
App.name == app_name,
App.workspace_id == workspace_id,
App.type == app_type,
App.is_active.is_(True),
)
apps = self.db.execute(stmt).scalars().all()
return list(apps)
except Exception as e:
db_logger.error(f"查询名称 {app_name} 应用异常: {str(e)}")
raise
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]: def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
"""根据工作空间ID查询应用""" """根据工作空间ID查询应用"""
repo = AppRepository(db) repo = AppRepository(db)
return repo.get_apps_by_workspace_id(workspace_id) return repo.get_apps_by_workspace_id(workspace_id)
def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App: def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App:
"""根据工作空间ID查询应用""" """根据工作空间ID查询应用"""
repo = AppRepository(db) repo = AppRepository(db)

View File

@@ -5,13 +5,22 @@ Implicit Emotions Storage Repository
事务由调用方控制,仓储层只使用 flush/refresh 事务由调用方控制,仓储层只使用 flush/refresh
""" """
import logging import logging
from datetime import datetime, date, timezone, timedelta from datetime import date, datetime, timedelta, timezone
from typing import Optional, Generator from typing import Generator, Optional
from sqlalchemy.orm import Session
from sqlalchemy import select, not_, exists
class TimeFilterUnavailableError(Exception):
"""redis_client 不可用,无法执行时间轴筛选。
调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。
"""
import redis
from sqlalchemy import exists, not_, select
from sqlalchemy.orm import Session
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -111,6 +120,88 @@ class ImplicitEmotionsStorageRepository:
logger.error(f"分批获取用户ID失败: offset={offset}, error={e}") logger.error(f"分批获取用户ID失败: offset={offset}, error={e}")
break break
def get_users_needing_refresh(self, redis_client: redis.StrictRedis, batch_size: int = 100) -> Generator[str, None, None]:
"""分批次获取需要刷新隐性记忆/情绪数据的存量用户ID。
筛选逻辑:
- 查询 implicit_emotions_storage 中所有用户的 end_user_id 和 updated_at
- 从 Redis 读取 write_message:last_done:{end_user_id} 的时间戳
- 若 Redis 中无记录(该用户从未写入过记忆),跳过
- 若 last_done > updated_at说明上次刷新后又有新记忆写入需要刷新
- 若 last_done <= updated_at说明已是最新跳过
Args:
redis_client: 同步 redis.StrictRedis 实例(连接 CELERY_BACKEND DB
batch_size: 每批次加载的数量
Raises:
TimeFilterUnavailableError: redis_client 为 None 时抛出,调用方可捕获并回退到 get_all_user_ids
Yields:
需要刷新的用户ID字符串
"""
if redis_client is None:
raise TimeFilterUnavailableError("redis_client 不可用,无法执行时间轴筛选")
from redis.exceptions import RedisError
offset = 0
while True:
try:
stmt = (
select(ImplicitEmotionsStorage.end_user_id, ImplicitEmotionsStorage.updated_at)
.order_by(ImplicitEmotionsStorage.end_user_id)
.limit(batch_size)
.offset(offset)
)
batch = self.db.execute(stmt).all()
if not batch:
break
# 批量获取当前批次所有用户的 last_done 时间戳(一次网络往返)
keys = [f"write_message:last_done:{end_user_id}" for end_user_id, _ in batch]
try:
raw_values = redis_client.mget(keys)
except RedisError as e:
logger.error(
f"Redis mget 操作失败: {e},当前批次降级为处理所有用户",
extra={"offset": offset, "batch_size": len(batch)}
)
# Redis 操作失败,降级为返回当前批次所有用户
yield from (end_user_id for end_user_id, _ in batch)
offset += batch_size
continue
for (end_user_id, updated_at), raw in zip(batch, raw_values):
if raw is None:
continue
try:
CST = timezone(timedelta(hours=8))
last_done = datetime.fromisoformat(raw)
# last_done 写入时已是 CST naive直接使用无需转换
if last_done.tzinfo is not None:
last_done = last_done.astimezone(CST).replace(tzinfo=None)
if updated_at is None:
yield end_user_id
continue
# updated_at 数据库存的是 UTC naive转为 CST naive 再比较
if updated_at.tzinfo is None:
updated_at_cst = updated_at.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None)
else:
updated_at_cst = updated_at.astimezone(CST).replace(tzinfo=None)
if last_done > updated_at_cst:
yield end_user_id
except Exception as e:
logger.warning(f"解析 last_done 时间戳失败: end_user_id={end_user_id}, raw={raw}, error={e}")
offset += batch_size
except Exception as e:
logger.error(f"get_users_needing_refresh 分批查询失败: offset={offset}, error={e}")
break
def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]: def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]:
"""分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID """分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID
@@ -124,7 +215,8 @@ class ImplicitEmotionsStorageRepository:
Yields: Yields:
用户ID字符串 用户ID字符串
""" """
from sqlalchemy import cast, String as SAString from sqlalchemy import String as SAString
from sqlalchemy import cast
CST = timezone(timedelta(hours=8)) CST = timezone(timedelta(hours=8))
now_cst = datetime.now(CST) now_cst = datetime.now(CST)
today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None) today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None)

View File

@@ -233,6 +233,7 @@ class MemoryConfigRepository:
config_desc=params.config_desc, config_desc=params.config_desc,
workspace_id=params.workspace_id, workspace_id=params.workspace_id,
scene_id=params.scene_id, scene_id=params.scene_id,
pruning_scene=params.pruning_scene,
llm_id=params.llm_id, llm_id=params.llm_id,
embedding_id=params.embedding_id, embedding_id=params.embedding_id,
rerank_id=params.rerank_id, rerank_id=params.rerank_id,

View File

@@ -374,7 +374,7 @@ class OntologySceneRepository:
count = self.db.query(OntologyScene).filter( count = self.db.query(OntologyScene).filter(
OntologyScene.scene_id == scene_id, OntologyScene.scene_id == scene_id,
OntologyScene.workspace_id == workspace_id (OntologyScene.workspace_id == workspace_id) | (OntologyScene.is_system_default == True)
).count() ).count()
is_owner = count > 0 is_owner = count > 0

View File

@@ -1,10 +1,13 @@
from sqlalchemy.orm import Session, joinedload
from app.models.user_model import User
from typing import List, Optional
import uuid import uuid
from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole from typing import List, Optional
from app.schemas.workspace_schema import WorkspaceCreate, WorkspaceUpdate
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import select
from app.core.logging_config import get_db_logger from app.core.logging_config import get_db_logger
from app.models.user_model import User
from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole
from app.schemas.workspace_schema import WorkspaceCreate
# 获取数据库专用日志器 # 获取数据库专用日志器
db_logger = get_db_logger() db_logger = get_db_logger()
@@ -19,7 +22,7 @@ class WorkspaceRepository:
def create_workspace(self, workspace_data: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace: def create_workspace(self, workspace_data: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace:
"""创建工作空间""" """创建工作空间"""
db_logger.debug(f"创建工作空间记录: name={workspace_data.name}, tenant_id={tenant_id}") db_logger.debug(f"创建工作空间记录: name={workspace_data.name}, tenant_id={tenant_id}")
try: try:
db_workspace = Workspace( db_workspace = Workspace(
name=workspace_data.name, name=workspace_data.name,
@@ -34,7 +37,8 @@ class WorkspaceRepository:
) )
self.db.add(db_workspace) self.db.add(db_workspace)
self.db.flush() self.db.flush()
db_logger.info(f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}") db_logger.info(
f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}")
return db_workspace return db_workspace
except Exception as e: except Exception as e:
db_logger.error(f"创建工作空间记录失败: name={workspace_data.name} - {str(e)}") db_logger.error(f"创建工作空间记录失败: name={workspace_data.name} - {str(e)}")
@@ -43,7 +47,7 @@ class WorkspaceRepository:
def get_workspace_by_id(self, workspace_id: uuid.UUID) -> Optional[Workspace]: def get_workspace_by_id(self, workspace_id: uuid.UUID) -> Optional[Workspace]:
"""根据ID获取工作空间""" """根据ID获取工作空间"""
db_logger.debug(f"根据ID查询工作空间: workspace_id={workspace_id}") db_logger.debug(f"根据ID查询工作空间: workspace_id={workspace_id}")
try: try:
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first() workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
if workspace: if workspace:
@@ -65,7 +69,7 @@ class WorkspaceRepository:
包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None 包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None
""" """
db_logger.debug(f"查询工作空间模型配置: workspace_id={workspace_id}") db_logger.debug(f"查询工作空间模型配置: workspace_id={workspace_id}")
try: try:
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first() workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
if workspace: if workspace:
@@ -89,7 +93,7 @@ class WorkspaceRepository:
def get_workspaces_by_user(self, user_id: uuid.UUID) -> List[Workspace]: def get_workspaces_by_user(self, user_id: uuid.UUID) -> List[Workspace]:
"""获取用户参与的所有工作空间(包括用户创建的和作为成员的)""" """获取用户参与的所有工作空间(包括用户创建的和作为成员的)"""
db_logger.debug(f"查询用户参与的工作空间: user_id={user_id}") db_logger.debug(f"查询用户参与的工作空间: user_id={user_id}")
try: try:
# 首先获取用户信息以获取 tenant_id # 首先获取用户信息以获取 tenant_id
from app.models.user_model import User from app.models.user_model import User
@@ -97,7 +101,7 @@ class WorkspaceRepository:
if not user: if not user:
db_logger.warning(f"用户不存在: user_id={user_id}") db_logger.warning(f"用户不存在: user_id={user_id}")
return [] return []
if user.is_superuser: if user.is_superuser:
# 超级用户获取对应tenantid所有工作空间 # 超级用户获取对应tenantid所有工作空间
workspaces = ( workspaces = (
@@ -109,7 +113,7 @@ class WorkspaceRepository:
) )
db_logger.debug(f"超用户查询所有工作空间: user_id={user_id}, 数量={len(workspaces)}") db_logger.debug(f"超用户查询所有工作空间: user_id={user_id}, 数量={len(workspaces)}")
return workspaces return workspaces
# 获取用户作为成员的工作空间 # 获取用户作为成员的工作空间
member_workspaces = ( member_workspaces = (
self.db.query(Workspace) self.db.query(Workspace)
@@ -120,7 +124,7 @@ class WorkspaceRepository:
.order_by(Workspace.updated_at.desc()) .order_by(Workspace.updated_at.desc())
.all() .all()
) )
db_logger.debug(f"用户工作空间查询成功: user_id={user_id}, 数量={len(member_workspaces)}") db_logger.debug(f"用户工作空间查询成功: user_id={user_id}, 数量={len(member_workspaces)}")
return member_workspaces return member_workspaces
except Exception as e: except Exception as e:
@@ -130,7 +134,7 @@ class WorkspaceRepository:
def get_workspaces_by_tenant(self, tenant_id: uuid.UUID) -> List[Workspace]: def get_workspaces_by_tenant(self, tenant_id: uuid.UUID) -> List[Workspace]:
"""获取租户的所有工作空间""" """获取租户的所有工作空间"""
db_logger.debug(f"查询租户的工作空间: tenant_id={tenant_id}") db_logger.debug(f"查询租户的工作空间: tenant_id={tenant_id}")
try: try:
workspaces = ( workspaces = (
self.db.query(Workspace) self.db.query(Workspace)
@@ -144,14 +148,32 @@ class WorkspaceRepository:
db_logger.error(f"查询租户工作空间失败: tenant_id={tenant_id} - {str(e)}") db_logger.error(f"查询租户工作空间失败: tenant_id={tenant_id} - {str(e)}")
raise raise
def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember: def get_workspaces_by_name(self, tenant_id: uuid.UUID, workspace_name: str) -> List[Workspace]:
try:
stmt = (
select(Workspace)
.where(
Workspace.tenant_id == tenant_id,
Workspace.name == workspace_name,
Workspace.is_active.is_(True)
)
)
workspaces = self.db.execute(stmt).scalars().all()
return list(workspaces)
except Exception as e:
db_logger.error(f"查询工作空间失败: workspace_name={workspace_name} - {str(e)}")
raise
def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID,
role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember:
"""添加工作空间成员""" """添加工作空间成员"""
db_logger.debug(f"添加工作空间成员: user_id={user_id}, workspace_id={workspace_id}, role={role}") db_logger.debug(f"添加工作空间成员: user_id={user_id}, workspace_id={workspace_id}, role={role}")
try: try:
db_member = WorkspaceMember( db_member = WorkspaceMember(
user_id=user_id, user_id=user_id,
workspace_id=workspace_id, workspace_id=workspace_id,
role=role role=role
) )
self.db.add(db_member) self.db.add(db_member)
@@ -165,7 +187,7 @@ class WorkspaceRepository:
def get_member(self, user_id: uuid.UUID, workspace_id: uuid.UUID) -> Optional[WorkspaceMember]: def get_member(self, user_id: uuid.UUID, workspace_id: uuid.UUID) -> Optional[WorkspaceMember]:
"""获取工作空间成员""" """获取工作空间成员"""
db_logger.debug(f"查询工作空间成员: user_id={user_id}, workspace_id={workspace_id}") db_logger.debug(f"查询工作空间成员: user_id={user_id}, workspace_id={workspace_id}")
try: try:
member = self.db.query(WorkspaceMember).filter( member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.user_id == user_id, WorkspaceMember.user_id == user_id,
@@ -173,7 +195,8 @@ class WorkspaceRepository:
WorkspaceMember.is_active.is_(True), WorkspaceMember.is_active.is_(True),
).first() ).first()
if member: if member:
db_logger.debug(f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}") db_logger.debug(
f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}")
else: else:
db_logger.debug(f"工作空间成员不存在: user_id={user_id}, workspace_id={workspace_id}") db_logger.debug(f"工作空间成员不存在: user_id={user_id}, workspace_id={workspace_id}")
return member return member
@@ -199,7 +222,7 @@ class WorkspaceRepository:
except Exception as e: except Exception as e:
db_logger.error(f"查询成员列表失败: workspace_id={workspace_id} - {str(e)}") db_logger.error(f"查询成员列表失败: workspace_id={workspace_id} - {str(e)}")
raise raise
def get_member_by_id(self, member_id: uuid.UUID) -> WorkspaceMember: def get_member_by_id(self, member_id: uuid.UUID) -> WorkspaceMember:
"""按成员ID获取工作空间成员并预加载 user 与 workspace 关系""" """按成员ID获取工作空间成员并预加载 user 与 workspace 关系"""
db_logger.debug(f"查询成员的工作空间: member_id={member_id}") db_logger.debug(f"查询成员的工作空间: member_id={member_id}")
@@ -214,7 +237,8 @@ class WorkspaceRepository:
.first() .first()
) )
if member: if member:
db_logger.debug(f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}") db_logger.debug(
f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}")
else: else:
db_logger.debug(f"成员不存在: member_id={member_id}") db_logger.debug(f"成员不存在: member_id={member_id}")
return member return member
@@ -222,7 +246,8 @@ class WorkspaceRepository:
db_logger.error(f"查询成员列表失败: member_id={member_id} - {str(e)}") db_logger.error(f"查询成员列表失败: member_id={member_id} - {str(e)}")
raise raise
def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]: def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[
WorkspaceMember]:
try: try:
member = self.db.query(WorkspaceMember).filter( member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.workspace_id == workspace_id, WorkspaceMember.workspace_id == workspace_id,
@@ -255,7 +280,7 @@ class WorkspaceRepository:
except Exception as e: except Exception as e:
db_logger.error(f"删除成员失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}") db_logger.error(f"删除成员失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}")
raise raise
def delete_member_by_id(self, member_id: uuid.UUID) -> Optional[WorkspaceMember]: def delete_member_by_id(self, member_id: uuid.UUID) -> Optional[WorkspaceMember]:
try: try:
member = self.db.query(WorkspaceMember).filter( member = self.db.query(WorkspaceMember).filter(
@@ -271,7 +296,7 @@ class WorkspaceRepository:
except Exception as e: except Exception as e:
db_logger.error(f"删除成员失败: id={member_id} - {str(e)}") db_logger.error(f"删除成员失败: id={member_id} - {str(e)}")
raise raise
def update_member_role_by_id(self, id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]: def update_member_role_by_id(self, id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
try: try:
member = self.db.query(WorkspaceMember).filter( member = self.db.query(WorkspaceMember).filter(
@@ -288,12 +313,18 @@ class WorkspaceRepository:
db_logger.error(f"更新成员角色失败: id={id} - {str(e)}") db_logger.error(f"更新成员角色失败: id={id} - {str(e)}")
raise raise
# 保持向后兼容的函数 # 保持向后兼容的函数
def get_workspace_by_id(db: Session, workspace_id: uuid.UUID) -> Workspace | None: def get_workspace_by_id(db: Session, workspace_id: uuid.UUID) -> Workspace | None:
repo = WorkspaceRepository(db) repo = WorkspaceRepository(db)
return repo.get_workspace_by_id(workspace_id) return repo.get_workspace_by_id(workspace_id)
def get_workspaces_by_name(db: Session, tenant_id: uuid.UUID, name: str) -> List[Workspace]:
repo = WorkspaceRepository(db)
return repo.get_workspaces_by_name(tenant_id, name)
def get_workspaces_by_user(db: Session, user_id: uuid.UUID) -> List[Workspace]: def get_workspaces_by_user(db: Session, user_id: uuid.UUID) -> List[Workspace]:
repo = WorkspaceRepository(db) repo = WorkspaceRepository(db)
return repo.get_workspaces_by_user(user_id) return repo.get_workspaces_by_user(user_id)
@@ -315,7 +346,7 @@ def create_workspace(db: Session, workspace: WorkspaceCreate, tenant_id: uuid.UU
def add_member_to_workspace( def add_member_to_workspace(
db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole
) -> WorkspaceMember: ) -> WorkspaceMember:
repo = WorkspaceRepository(db) repo = WorkspaceRepository(db)
return repo.add_member(workspace_id, user_id, role) return repo.add_member(workspace_id, user_id, role)
@@ -325,39 +356,43 @@ def get_members_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[Works
repo = WorkspaceRepository(db) repo = WorkspaceRepository(db)
return repo.get_members_by_workspace(workspace_id) return repo.get_members_by_workspace(workspace_id)
def get_member_by_id(db: Session, member_id: uuid.UUID) -> WorkspaceMember | None: def get_member_by_id(db: Session, member_id: uuid.UUID) -> WorkspaceMember | None:
repo = WorkspaceRepository(db) repo = WorkspaceRepository(db)
return repo.get_member_by_id(member_id) return repo.get_member_by_id(member_id)
def update_member_role_in_workspace( def update_member_role_in_workspace(
db: Session, db: Session,
user_id: uuid.UUID, user_id: uuid.UUID,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
role: WorkspaceRole, role: WorkspaceRole,
) -> Optional[WorkspaceMember]: ) -> Optional[WorkspaceMember]:
repo = WorkspaceRepository(db) repo = WorkspaceRepository(db)
return repo.update_member_role(workspace_id, user_id, role) return repo.update_member_role(workspace_id, user_id, role)
def remove_member_from_workspace( def remove_member_from_workspace(
db: Session, db: Session,
user_id: uuid.UUID, user_id: uuid.UUID,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
) -> Optional[WorkspaceMember]: ) -> Optional[WorkspaceMember]:
repo = WorkspaceRepository(db) repo = WorkspaceRepository(db)
return repo.deactivate_member(workspace_id, user_id) return repo.deactivate_member(workspace_id, user_id)
def remove_member_from_workspace_by_id( def remove_member_from_workspace_by_id(
db: Session, db: Session,
member_id: uuid.UUID, member_id: uuid.UUID,
) -> Optional[WorkspaceMember]: ) -> Optional[WorkspaceMember]:
repo = WorkspaceRepository(db) repo = WorkspaceRepository(db)
return repo.delete_member_by_id(member_id) return repo.delete_member_by_id(member_id)
def update_member_role_by_id( def update_member_role_by_id(
db: Session, db: Session,
id: uuid.UUID, id: uuid.UUID,
role: WorkspaceRole, role: WorkspaceRole,
) -> Optional[WorkspaceMember]: ) -> Optional[WorkspaceMember]:
repo = WorkspaceRepository(db) repo = WorkspaceRepository(db)
return repo.update_member_role_by_id(id, role) return repo.update_member_role_by_id(id, role)

View File

@@ -15,7 +15,7 @@ class ApiKeyCreate(BaseModel):
type: ApiKeyType = Field(..., description="API Key 类型") type: ApiKeyType = Field(..., description="API Key 类型")
scopes: List[str] = Field(default_factory=list, description="权限范围列表") scopes: List[str] = Field(default_factory=list, description="权限范围列表")
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID") resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS限制请求/秒)") rate_limit: Optional[int] = Field(100, ge=1, le=1000, description="QPS限制请求/秒)")
daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1) daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1)
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1) quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间") expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")

View File

@@ -86,6 +86,7 @@ class ChatResponse(BaseModel):
"""聊天响应(非流式)""" """聊天响应(非流式)"""
conversation_id: uuid.UUID conversation_id: uuid.UUID
message: str message: str
message_id: str
usage: Optional[Dict[str, Any]] = None usage: Optional[Dict[str, Any]] = None
elapsed_time: Optional[float] = None elapsed_time: Optional[float] = None

View File

@@ -417,6 +417,7 @@ class MemoryConfig:
# Ontology scene association # Ontology scene association
scene_id: Optional[UUID] = None scene_id: Optional[UUID] = None
ontology_classes: Optional[list] = field(default=None)
def __post_init__(self): def __post_init__(self):
"""Validate configuration after initialization.""" """Validate configuration after initialization."""

View File

@@ -232,14 +232,15 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body
# 本体场景关联(可选) # 本体场景关联(可选)
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景IDUUID关联ontology_scene表") scene_id: Optional[uuid.UUID] = Field(None, description="本体场景IDUUID关联ontology_scene表")
# 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name前端无需传入
pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充")
# 模型配置字段(可选,用于手动指定或自动填充) # 模型配置字段(可选,用于手动指定或自动填充)
llm_id: Optional[str] = Field(None, description="LLM模型配置ID") llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致") reflection_model_id: Optional[str] = Field(None, description="反思模型ID默认与llm_id一致")
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致") emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID默认与llm_id一致")
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
# config_name: str = Field("配置名称", description="配置名称(字符串)") # config_name: str = Field("配置名称", description="配置名称(字符串)")
@@ -274,8 +275,8 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
# 剪枝配置:与 runtime.json 中 pruning 段对应 # 剪枝配置:与 runtime.json 中 pruning 段对应
pruning_enabled: Optional[bool] = Field(None, description="是否启动智能语义剪枝") pruning_enabled: Optional[bool] = Field(None, description="是否启动智能语义剪枝")
pruning_scene: Optional[Literal["education", "online_service", "outbound"]] = Field( pruning_scene: Optional[str] = Field(
None, description="智能剪枝场景education/online_service/outbound" None, description="智能剪枝场景education/online_service/outbound 或本体工程自定义场景"
) )
pruning_threshold: Optional[float] = Field( pruning_threshold: Optional[float] = Field(
None, ge=0.0, le=0.9, description="智能语义剪枝阈值0-0.9" None, ge=0.0, le=0.9, description="智能语义剪枝阈值0-0.9"

View File

@@ -23,6 +23,7 @@ class ModelConfigBase(BaseModel):
load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略") load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略")
capability: List[str] = Field(default_factory=list, description="模型能力列表") capability: List[str] = Field(default_factory=list, description="模型能力列表")
is_omni: bool = Field(False, description="是否为Omni模型") is_omni: bool = Field(False, description="是否为Omni模型")
model_id: Optional[uuid.UUID] = Field(None, description="基础模型ID")
class ApiKeyCreateNested(BaseModel): class ApiKeyCreateNested(BaseModel):
@@ -116,8 +117,8 @@ class ModelApiKeyBase(BaseModel):
provider: ModelProvider = Field(..., description="API Key提供商") provider: ModelProvider = Field(..., description="API Key提供商")
api_key: str = Field(..., description="API密钥", max_length=500) api_key: str = Field(..., description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
capability: List[str] = Field(default_factory=list, description="模型能力列表") capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: bool = Field(False, description="是否为Omni模型") is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
is_active: bool = Field(True, description="是否激活") is_active: bool = Field(True, description="是否激活")
priority: str = Field("1", description="优先级", max_length=10) priority: str = Field("1", description="优先级", max_length=10)

View File

@@ -241,6 +241,7 @@ class SceneResponse(BaseModel):
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)") created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)") updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
classes_count: int = Field(0, description="类型数量") classes_count: int = Field(0, description="类型数量")
is_system_default: bool = Field(False, description="是否为系统默认场景")
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime): def _serialize_created_at(self, dt: datetime.datetime):
@@ -462,6 +463,7 @@ class ClassListResponse(BaseModel):
scene_id: UUID = Field(..., description="所属场景ID") scene_id: UUID = Field(..., description="所属场景ID")
scene_name: str = Field(..., description="场景名称") scene_name: str = Field(..., description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述") scene_description: Optional[str] = Field(None, description="场景描述")
is_system_default: bool = Field(False, description="是否为系统默认场景")
items: List[ClassResponse] = Field(..., description="类型列表") items: List[ClassResponse] = Field(..., description="类型列表")

View File

@@ -155,6 +155,10 @@ class MCPToolConfigSchema(BaseModel):
health_status: str = "unknown" health_status: str = "unknown"
error_message: Optional[str] = None error_message: Optional[str] = None
available_tools: List[Dict[str, Dict[str, Any]]] = Field(default_factory=list, description="工具列表,格式: [{'tool_name': str, 'arguments': dict}]") available_tools: List[Dict[str, Dict[str, Any]]] = Field(default_factory=list, description="工具列表,格式: [{'tool_name': str, 'arguments': dict}]")
source_channel: Optional[str] = Field(None, description="来源渠道")
market_id: Optional[str] = Field(None, description="渠道市场id")
market_config_id: Optional[str] = Field(None, description="渠道市场配置id")
mcp_service_id: Optional[str] = Field(None, description="mcp服务id")
class Config: class Config:
from_attributes = True from_attributes = True
@@ -192,6 +196,10 @@ class ToolCreateRequest(BaseModel):
tool_type: ToolType tool_type: ToolType
config: Dict[str, Any] = Field(default_factory=dict) config: Dict[str, Any] = Field(default_factory=dict)
tags: List[str] = Field(default_factory=list) tags: List[str] = Field(default_factory=list)
source_channel: Optional[str] = Field(None, description="来源渠道仅MCP工具")
market_id: Optional[str] = Field(None, description="渠道市场id仅MCP工具")
market_config_id: Optional[str] = Field(None, description="渠道市场配置id仅MCP工具")
mcp_service_id: Optional[str] = Field(None, description="mcp服务id仅MCP工具")
class ToolUpdateRequest(BaseModel): class ToolUpdateRequest(BaseModel):

View File

@@ -144,7 +144,7 @@ class AppChatService:
) )
# 保存消息 # 保存消息
self.conversation_service.save_conversation_messages( message_id = self.conversation_service.save_conversation_messages(
conversation_id=conversation_id, conversation_id=conversation_id,
user_message=message, user_message=message,
assistant_message=result["content"], assistant_message=result["content"],
@@ -163,6 +163,7 @@ class AppChatService:
return { return {
"conversation_id": conversation_id, "conversation_id": conversation_id,
"message_id": str(message_id),
"message": result["content"], "message": result["content"],
"usage": result.get("usage", { "usage": result.get("usage", {
"prompt_tokens": 0, "prompt_tokens": 0,
@@ -191,7 +192,11 @@ class AppChatService:
try: try:
start_time = time.time() start_time = time.time()
config_id = None config_id = None
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" message_id = uuid.uuid4()
yield f"event: start\ndata: {json.dumps({
'conversation_id': str(conversation_id),
"message_id": str(message_id)
}, ensure_ascii=False)}\n\n"
variables = self.agent_service.prepare_variables(variables, config.variables) variables = self.agent_service.prepare_variables(variables, config.variables)
# 获取模型配置ID # 获取模型配置ID
@@ -296,6 +301,7 @@ class AppChatService:
) )
self.conversation_service.add_message( self.conversation_service.add_message(
message_id=message_id,
conversation_id=conversation_id, conversation_id=conversation_id,
role="assistant", role="assistant",
content=full_content, content=full_content,
@@ -373,7 +379,7 @@ class AppChatService:
content=message content=message
) )
self.conversation_service.add_message( ai_message = self.conversation_service.add_message(
conversation_id=conversation_id, conversation_id=conversation_id,
role="assistant", role="assistant",
content=result.get("message", ""), content=result.get("message", ""),
@@ -391,6 +397,7 @@ class AppChatService:
return { return {
"conversation_id": conversation_id, "conversation_id": conversation_id,
"message": result.get("message", ""), "message": result.get("message", ""),
"message_id": str(ai_message.id),
"usage": { "usage": {
"prompt_tokens": 0, "prompt_tokens": 0,
"completion_tokens": 0, "completion_tokens": 0,
@@ -419,9 +426,9 @@ class AppChatService:
variables = {} variables = {}
try: try:
message_id = uuid.uuid4()
# 发送开始事件 # 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), "message_id": str(message_id)}, ensure_ascii=False)}\n\n"
full_content = "" full_content = ""
total_tokens = 0 total_tokens = 0
@@ -429,6 +436,7 @@ class AppChatService:
# 2. 创建编排器 # 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config) orchestrator = MultiAgentOrchestrator(self.db, config)
# 3. 流式执行任务 # 3. 流式执行任务
async for event in orchestrator.execute_stream( async for event in orchestrator.execute_stream(
message=message, message=message,
@@ -472,6 +480,7 @@ class AppChatService:
) )
self.conversation_service.add_message( self.conversation_service.add_message(
message_id=message_id,
conversation_id=conversation_id, conversation_id=conversation_id,
role="assistant", role="assistant",
content=full_content, content=full_content,

View File

@@ -0,0 +1,390 @@
"""应用 DSL 导入导出服务"""
import uuid
import datetime
from typing import Optional
import yaml
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.models import AgentConfig, MultiAgentConfig
from app.models.app_model import App, AppType
from app.models.app_release_model import AppRelease
from app.models.knowledge_model import Knowledge
from app.models.models_model import ModelConfig
from app.models.tool_model import ToolConfig as ToolConfigModel
from app.models.workflow_model import WorkflowConfig
from app.services.workflow_service import WorkflowService
class AppDslService:
def __init__(self, db: Session):
self.db = db
# ==================== 导出 ====================
def export_dsl(self, app_id: uuid.UUID, release_id: Optional[uuid.UUID] = None) -> tuple[str, str]:
"""构建应用 DSL yaml 字符串,返回 (yaml_str, filename)"""
app = self.db.query(App).filter(App.id == app_id, App.is_active.is_(True)).first()
if not app:
raise ResourceNotFoundException("应用", str(app_id))
meta = {
"version": settings.SYSTEM_VERSION,
"platform": "MemoryBear",
"exported_at": datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S"),
}
app_meta = {
"name": app.name,
"description": app.description,
"icon": app.icon,
"icon_type": app.icon_type,
"type": app.type,
"tags": app.tags or [],
}
if release_id is not None:
return self._export_release(app, release_id, meta, app_meta)
return self._export_draft(app, meta, app_meta)
def _export_release(self, app: App, release_id: uuid.UUID, meta: dict, app_meta: dict) -> tuple[str, str]:
release = self.db.query(AppRelease).filter(
AppRelease.app_id == app.id,
AppRelease.id == release_id,
AppRelease.is_active.is_(True)
).first()
if not release:
raise ResourceNotFoundException("版本", str(release_id))
meta["release_version"] = release.version
meta["release_name"] = release.version_name
app_meta["name"] = release.name
app_meta["description"] = release.description
config_key = {
AppType.AGENT: "agent_config",
AppType.MULTI_AGENT: "multi_agent_config",
AppType.WORKFLOW: "workflow"
}.get(app.type, "config")
config_data = self._enrich_release_config(app.type, release.config or {})
dsl = {**meta, "app": app_meta, config_key: config_data}
return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{release.name}_v{release.version_name}.yaml"
def _enrich_release_config(self, app_type: str, cfg: dict) -> dict:
if app_type == AppType.AGENT:
enriched = {**cfg}
if "default_model_config_id" in cfg:
enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"])
if "knowledge_retrieval" in cfg:
enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"])
if "tools" in cfg:
enriched["tools"] = self._enrich_tools(cfg["tools"])
return enriched
if app_type == AppType.MULTI_AGENT:
enriched = {**cfg}
if "default_model_config_id" in cfg:
enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"])
if "master_agent_id" in cfg:
enriched["master_agent_ref"] = self._release_ref(cfg["master_agent_id"])
if "sub_agents" in cfg:
enriched["sub_agents"] = self._enrich_sub_agents(cfg["sub_agents"])
if "routing_rules" in cfg:
enriched["routing_rules"] = [
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
]
return enriched
return cfg
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
if app.type == AppType.WORKFLOW:
config = self.db.query(WorkflowConfig).filter(WorkflowConfig.app_id == app.id).first()
config_data = {
"variables": config.variables if config else [],
"edges": config.edges if config else [],
"nodes": config.nodes if config else [],
"execution_config": config.execution_config if config else {},
"triggers": config.triggers if config else [],
} if config else {}
dsl = {**meta, "app": app_meta, "workflow": config_data}
elif app.type == AppType.AGENT:
config = self.db.query(AgentConfig).filter(AgentConfig.app_id == app.id).first()
config_data = {
"system_prompt": config.system_prompt if config else None,
"model_parameters": self._to_dict(config.model_parameters) if config else None,
"default_model_config_ref": self._model_ref(config.default_model_config_id) if config else None,
"knowledge_retrieval": self._enrich_knowledge_retrieval(config.knowledge_retrieval) if config else None,
"memory": config.memory if config else None,
"variables": config.variables if config else [],
"tools": self._enrich_tools(config.tools) if config else [],
"skills": config.skills if config else {},
} if config else {}
dsl = {**meta, "app": app_meta, "agent_config": config_data}
elif app.type == AppType.MULTI_AGENT:
config = self.db.query(MultiAgentConfig).filter(MultiAgentConfig.app_id == app.id).first()
config_data = {
"orchestration_mode": config.orchestration_mode if config else None,
"master_agent_name": config.master_agent_name if config else None,
"model_parameters": self._to_dict(config.model_parameters) if config else None,
"default_model_config_ref": self._model_ref(config.default_model_config_id) if config else None,
"master_agent_ref": self._release_ref(config.master_agent_id) if config else None,
"sub_agents": self._enrich_sub_agents(config.sub_agents) if config else [],
"routing_rules": [
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (config.routing_rules or [])
] if config else [],
"execution_config": config.execution_config if config else {},
"aggregation_strategy": config.aggregation_strategy if config else "merge",
} if config else {}
dsl = {**meta, "app": app_meta, "multi_agent_config": config_data}
else:
raise BusinessException(f"不支持的应用类型: {app.type}", BizCode.BAD_REQUEST)
return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{app.name}.yaml"
def _to_dict(self, value):
"""将 Pydantic 对象转为普通 dict供 yaml.dump 安全序列化"""
if value is None:
return None
if hasattr(value, "model_dump"):
return value.model_dump()
return value
def _model_ref(self, model_config_id) -> Optional[dict]:
if not model_config_id:
return None
m = self.db.query(ModelConfig).filter(ModelConfig.id == model_config_id).first()
return {"id": str(model_config_id), "name": m.name, "provider": m.provider, "type": m.type} if m else {"id": str(model_config_id)}
def _kb_ref(self, kb_id) -> Optional[dict]:
if not kb_id:
return None
kb = self.db.query(Knowledge).filter(Knowledge.id == kb_id).first()
return {"id": str(kb_id), "name": kb.name} if kb else {"id": str(kb_id)}
def _tool_ref(self, tool_id) -> Optional[dict]:
if not tool_id:
return None
t = self.db.query(ToolConfigModel).filter(ToolConfigModel.id == tool_id).first()
return {"id": str(tool_id), "name": t.name, "tool_type": t.tool_type} if t else {"id": str(tool_id)}
def _enrich_knowledge_retrieval(self, kr: Optional[dict]) -> Optional[dict]:
if not kr:
return kr
kbs = [{**kb, "_ref": self._kb_ref(kb.get("kb_id"))} for kb in kr.get("knowledge_bases", [])]
return {**kr, "knowledge_bases": kbs}
def _enrich_tools(self, tools: list) -> list:
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
def _agent_ref(self, agent_id) -> Optional[dict]:
if not agent_id:
return None
a = self.db.query(App).filter(App.id == agent_id).first()
return {"id": str(agent_id), "name": a.name} if a else {"id": str(agent_id)}
def _release_ref(self, release_id) -> Optional[dict]:
if not release_id:
return None
r = self.db.query(AppRelease).filter(AppRelease.id == release_id).first()
return {"id": str(release_id), "name": r.name, "version": r.version, "app_id": str(r.app_id)} if r else {"id": str(release_id)}
def _enrich_sub_agents(self, sub_agents: list) -> list:
return [{**s, "_ref": self._agent_ref(s.get("agent_id"))} for s in (sub_agents or [])]
# ==================== 导入 ====================
def import_dsl(
self,
dsl: dict,
workspace_id: uuid.UUID,
tenant_id: uuid.UUID,
user_id: uuid.UUID,
) -> tuple[App, list[str]]:
"""解析 DSL创建应用及配置返回 (new_app, warnings)"""
app_meta = dsl.get("app", {})
app_type = app_meta.get("type")
if app_type not in (AppType.AGENT, AppType.MULTI_AGENT, AppType.WORKFLOW):
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.BAD_REQUEST)
warnings: list[str] = []
now = datetime.datetime.now()
new_app = App(
id=uuid.uuid4(),
workspace_id=workspace_id,
created_by=user_id,
name=app_meta.get("name", "导入应用"),
description=app_meta.get("description"),
icon=app_meta.get("icon"),
icon_type=app_meta.get("icon_type"),
type=app_type,
visibility="private",
status="draft",
tags=app_meta.get("tags", []),
is_active=True,
created_at=now,
updated_at=now,
)
self.db.add(new_app)
self.db.flush()
if app_type == AppType.AGENT:
cfg = dsl.get("agent_config") or {}
self.db.add(AgentConfig(
id=uuid.uuid4(),
app_id=new_app.id,
system_prompt=cfg.get("system_prompt"),
model_parameters=cfg.get("model_parameters"),
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
knowledge_retrieval=self._resolve_knowledge_retrieval(cfg.get("knowledge_retrieval"), workspace_id, warnings),
memory=cfg.get("memory"),
variables=cfg.get("variables", []),
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
skills=cfg.get("skills", {}),
is_active=True,
created_at=now,
updated_at=now,
))
elif app_type == AppType.MULTI_AGENT:
cfg = dsl.get("multi_agent_config") or {}
self.db.add(MultiAgentConfig(
id=uuid.uuid4(),
app_id=new_app.id,
orchestration_mode=cfg.get("orchestration_mode", "collaboration"),
master_agent_name=cfg.get("master_agent_name"),
model_parameters=cfg.get("model_parameters"),
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
master_agent_id=self._resolve_release(cfg.get("master_agent_ref"), warnings),
sub_agents=self._resolve_sub_agents(cfg.get("sub_agents", []), warnings),
routing_rules=self._resolve_routing_rules(cfg.get("routing_rules"), warnings),
execution_config=cfg.get("execution_config", {}),
aggregation_strategy=cfg.get("aggregation_strategy", "merge"),
is_active=True,
created_at=now,
updated_at=now,
))
elif app_type == AppType.WORKFLOW:
wf = dsl.get("workflow") or {}
WorkflowService(self.db).create_workflow_config(
app_id=new_app.id,
nodes=wf.get("nodes", []),
edges=wf.get("edges", []),
variables=wf.get("variables", []),
execution_config=wf.get("execution_config", {}),
triggers=wf.get("triggers", []),
validate=False,
)
self.db.commit()
self.db.refresh(new_app)
return new_app, warnings
def _resolve_model(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[uuid.UUID]:
if not ref:
return None
q = self.db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.name == ref.get("name"),
ModelConfig.is_active.is_(True)
)
if ref.get("provider"):
q = q.filter(ModelConfig.provider == ref["provider"])
if ref.get("type"):
q = q.filter(ModelConfig.type == ref["type"])
m = q.first()
if not m:
warnings.append(f"模型 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
return m.id if m else None
def _resolve_kb(self, ref: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[str]:
if not ref:
return None
kb = self.db.query(Knowledge).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.name == ref.get("name")
).first()
if not kb:
warnings.append(f"知识库 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
return str(kb.id) if kb else None
def _resolve_tool(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]:
if not ref:
return None
q = self.db.query(ToolConfigModel).filter(
ToolConfigModel.tenant_id == tenant_id,
ToolConfigModel.name == ref.get("name")
)
if ref.get("tool_type"):
q = q.filter(ToolConfigModel.tool_type == ref["tool_type"])
t = q.first()
if not t:
warnings.append(f"工具 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
return str(t.id) if t else None
def _resolve_release(self, ref: Optional[dict], warnings: list) -> Optional[uuid.UUID]:
if not ref:
return None
r = self.db.query(AppRelease).filter(
AppRelease.app_id == ref.get("app_id"),
AppRelease.version == ref.get("version"),
AppRelease.is_active.is_(True)
).first()
if not r:
warnings.append(f"主 Agent 发布版本 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
return r.id if r else None
def _resolve_sub_agents(self, sub_agents: list, warnings: list) -> list:
result = []
for s in (sub_agents or []):
ref = s.get("_ref")
entry = {k: v for k, v in s.items() if k != "_ref"}
if ref:
a = self.db.query(App).filter(App.name == ref.get("name"), App.is_active.is_(True)).first()
if not a:
warnings.append(f"子 Agent '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
entry["agent_id"] = str(a.id) if a else None
result.append(entry)
return result
def _resolve_routing_rules(self, rules: Optional[list], warnings: list) -> Optional[list]:
if rules is None:
return None
result = []
for r in rules:
ref = r.get("_ref")
entry = {k: v for k, v in r.items() if k != "_ref"}
if ref:
a = self.db.query(App).filter(App.name == ref.get("name"), App.is_active.is_(True)).first()
if not a:
warnings.append(f"路由目标 Agent '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
entry["target_agent_id"] = str(a.id) if a else None
result.append(entry)
return result
def _resolve_knowledge_retrieval(self, kr: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]:
if not kr:
return kr
resolved_kbs = []
for kb in kr.get("knowledge_bases", []):
ref = kb.get("_ref") or ({"name": kb.get("kb_id")} if kb.get("kb_id") else None)
entry = {k: v for k, v in kb.items() if k != "_ref"}
entry["kb_id"] = self._resolve_kb(ref, workspace_id, warnings)
resolved_kbs.append(entry)
return {k: v for k, v in kr.items() if k != "knowledge_bases"} | {"knowledge_bases": resolved_kbs}
def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list:
result = []
for t in (tools or []):
ref = t.get("_ref") or ({"name": t.get("tool_id")} if t.get("tool_id") else None)
entry = {k: v for k, v in t.items() if k != "_ref"}
entry["tool_id"] = self._resolve_tool(ref, tenant_id, warnings)
result.append(entry)
return result

View File

@@ -33,7 +33,7 @@ from app.models import (
Workspace, Workspace,
) )
from app.models.app_model import AppStatus, AppType from app.models.app_model import AppStatus, AppType
from app.repositories.app_repository import get_apps_by_id from app.repositories.app_repository import get_apps_by_id, AppRepository
from app.repositories.workflow_repository import WorkflowConfigRepository from app.repositories.workflow_repository import WorkflowConfigRepository
from app.schemas import app_schema from app.schemas import app_schema
from app.schemas.workflow_schema import WorkflowConfigUpdate from app.schemas.workflow_schema import WorkflowConfigUpdate
@@ -59,6 +59,7 @@ class AppService:
db: 数据库会话 db: 数据库会话
""" """
self.db = db self.db = db
self.app_repo = AppRepository(self.db)
# ==================== 私有辅助方法 ==================== # ==================== 私有辅助方法 ====================
@@ -521,6 +522,9 @@ class AppService:
"创建应用", "创建应用",
extra={"app_name": data.name, "type": data.type, "workspace_id": str(workspace_id)} extra={"app_name": data.name, "type": data.type, "workspace_id": str(workspace_id)}
) )
apps = self.app_repo.get_apps_by_name(data.name, data.type, workspace_id)
if apps:
raise BusinessException(message="已存在同名应用", code=BizCode.RESOURCE_ALREADY_EXISTS)
try: try:
now = datetime.datetime.now() now = datetime.datetime.now()
@@ -703,7 +707,7 @@ class AppService:
self.db.flush() self.db.flush()
# 如果是 agent 类型,复制 AgentConfig # 如果是 agent 类型,复制 AgentConfig
if source_app.type == "agent": if source_app.type == AppType.AGENT:
source_config = self.db.query(AgentConfig).filter( source_config = self.db.query(AgentConfig).filter(
AgentConfig.app_id == source_app.id AgentConfig.app_id == source_app.id
).first() ).first()
@@ -725,6 +729,50 @@ class AppService:
) )
self.db.add(new_config) self.db.add(new_config)
elif source_app.type == AppType.WORKFLOW:
source_config = self.db.query(WorkflowConfig).filter(
WorkflowConfig.app_id == source_app.id
).first()
if source_config:
new_config = WorkflowConfig(
id=uuid.uuid4(),
app_id=new_app.id,
nodes=source_config.nodes.copy() if source_config.nodes else [],
edges=source_config.edges.copy() if source_config.edges else [],
variables=source_config.variables.copy() if source_config.variables else [],
execution_config=source_config.execution_config.copy() if source_config.execution_config else {},
triggers=source_config.triggers.copy() if source_config.triggers else [],
is_active=True,
created_at=now,
updated_at=now,
)
self.db.add(new_config)
elif source_app.type == AppType.MULTI_AGENT:
source_config = self.db.query(MultiAgentConfig).filter(
MultiAgentConfig.app_id == source_app.id
).first()
if source_config:
new_config = MultiAgentConfig(
id=uuid.uuid4(),
app_id=new_app.id,
master_agent_id=source_config.master_agent_id,
master_agent_name=source_config.master_agent_name,
default_model_config_id=source_config.default_model_config_id,
model_parameters=source_config.model_parameters,
orchestration_mode=source_config.orchestration_mode,
sub_agents=source_config.sub_agents.copy() if source_config.sub_agents else [],
routing_rules=source_config.routing_rules.copy() if source_config.routing_rules else None,
execution_config=source_config.execution_config.copy() if source_config.execution_config else {},
aggregation_strategy=source_config.aggregation_strategy,
is_active=True,
created_at=now,
updated_at=now,
)
self.db.add(new_config)
self.db.commit() self.db.commit()
self.db.refresh(new_app) self.db.refresh(new_app)
@@ -1324,6 +1372,15 @@ class AppService:
if not agent_cfg: if not agent_cfg:
raise BusinessException("Agent 应用缺少配置,无法发布", BizCode.AGENT_CONFIG_MISSING) raise BusinessException("Agent 应用缺少配置,无法发布", BizCode.AGENT_CONFIG_MISSING)
miss_params = []
if agent_cfg.default_model_config_id is None:
miss_params.append("model config")
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
miss_params.append("memory config")
if miss_params:
raise BusinessException(f"{', '.join(miss_params)} is required")
config = { config = {
"system_prompt": agent_cfg.system_prompt, "system_prompt": agent_cfg.system_prompt,
"model_parameters": model_parameters_to_dict(agent_cfg.model_parameters), "model_parameters": model_parameters_to_dict(agent_cfg.model_parameters),

View File

@@ -178,7 +178,8 @@ class ConversationService:
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
role: str, role: str,
content: str, content: str,
meta_data: Optional[dict] = None meta_data: Optional[dict] = None,
message_id: Optional[uuid.UUID] = None,
) -> Message: ) -> Message:
""" """
Add a message to a conversation using UnitOfWork. Add a message to a conversation using UnitOfWork.
@@ -188,6 +189,7 @@ class ConversationService:
role (str): Role of the message sender ('user' or 'assistant'). role (str): Role of the message sender ('user' or 'assistant').
content (str): Message content. content (str): Message content.
meta_data (Optional[dict]): Optional metadata. meta_data (Optional[dict]): Optional metadata.
message_id (Optional[uuid.UUID]): Optional custom message UUID.
Returns: Returns:
Message: Newly created Message instance. Message: Newly created Message instance.
@@ -198,6 +200,7 @@ class ConversationService:
) )
message = Message( message = Message(
id=message_id if message_id else uuid.uuid4(),
conversation_id=conversation_id, conversation_id=conversation_id,
role=role, role=role,
content=content, content=content,
@@ -317,7 +320,7 @@ class ConversationService:
content=user_message content=user_message
) )
self.add_message( ai_message = self.add_message(
conversation_id=conversation_id, conversation_id=conversation_id,
role="assistant", role="assistant",
content=assistant_message, content=assistant_message,
@@ -332,6 +335,7 @@ class ConversationService:
"assistant_message_length": len(assistant_message) "assistant_message_length": len(assistant_message)
} }
) )
return ai_message.id
def delete_conversation( def delete_conversation(
self, self,

View File

@@ -22,6 +22,7 @@ from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput from app.schemas.app_schema import FileInput
@@ -103,9 +104,7 @@ def create_long_term_memory_tool(
""" """
logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}") logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
try: try:
from app.db import get_db with get_db_context() as db:
db = next(get_db())
try:
memory_content = asyncio.run( memory_content = asyncio.run(
MemoryAgentService().read_memory( MemoryAgentService().read_memory(
end_user_id=end_user_id, end_user_id=end_user_id,
@@ -127,9 +126,6 @@ def create_long_term_memory_tool(
logger.info(f"读取任务状态:{status}") logger.info(f"读取任务状态:{status}")
if memory_content: if memory_content:
memory_content = memory_content['answer'] memory_content = memory_content['answer']
finally:
db.close()
logger.info(f'用户IDAgent:{end_user_id}') logger.info(f'用户IDAgent:{end_user_id}')
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id}) logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})

View File

@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
""" """
import json import json
import os import os
import re
import time import time
import uuid import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
reorder_output_results, reorder_output_results,
) )
from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError from app.schemas.memory_config_schema import ConfigurationError
@@ -69,7 +66,8 @@ class MemoryAgentService:
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作 # 记录成功的操作
if audit_logger: if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True, audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=True,
duration=duration, details={"message_length": len(message)}) duration=duration, details={"message_length": len(message)})
return context return context
else: else:
@@ -88,8 +86,6 @@ class MemoryAgentService:
raise ValueError(f"写入失败: {messages}") raise ValueError(f"写入失败: {messages}")
def extract_tool_call_info(self, event: Dict) -> bool: def extract_tool_call_info(self, event: Dict) -> bool:
"""Extract tool call information from event""" """Extract tool call information from event"""
last_message = event["messages"][-1] last_message = event["messages"][-1]
@@ -271,7 +267,8 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources") logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic # LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
""" """
Process write operation with config_id Process write operation with config_id
@@ -300,7 +297,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None: if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e: except Exception as e:
if "No memory configuration found" in str(e): if "No memory configuration found" in str(e):
raise # Re-raise our specific error raise # Re-raise our specific error
@@ -331,7 +329,8 @@ class MemoryAgentService:
# Log failed operation # Log failed operation
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
@@ -351,9 +350,9 @@ class MemoryAgentService:
langchain_messages.append(HumanMessage(content=msg['content'])) langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant': elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content'])) langchain_messages.append(AIMessage(content=msg['content']))
print(100*'-') print(100 * '-')
print(langchain_messages) print(langchain_messages)
print(100*'-') print(100 * '-')
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = { initial_state = {
"messages": langchain_messages, "messages": langchain_messages,
@@ -375,29 +374,28 @@ class MemoryAgentService:
contents = massages.get('write_result') contents = massages.get('write_result')
# Convert messages back to string for logging # Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
contents)
except Exception as e: except Exception as e:
# Ensure proper error handling and logging # Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}" error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
async def read_memory( async def read_memory(
self, self,
end_user_id: str, end_user_id: str,
message: str, message: str,
history: List[Dict], history: List[Dict],
search_switch: str, search_switch: str,
config_id: Optional[uuid.UUID]|int, config_id: Optional[uuid.UUID] | int,
db: Session, db: Session,
storage_type: str, storage_type: str,
user_rag_memory_id: str) -> Dict: user_rag_memory_id: str) -> Dict:
""" """
Process read operation with config_id Process read operation with config_id
@@ -425,7 +423,7 @@ class MemoryAgentService:
import time import time
start_time = time.time() start_time = time.time()
ori_message= message ori_message = message
# Resolve config_id and workspace_id # Resolve config_id and workspace_id
# Always get workspace_id from end_user for fallback, even if config_id is provided # Always get workspace_id from end_user for fallback, even if config_id is provided
@@ -437,7 +435,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None: if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e: except Exception as e:
if "No memory configuration found" in str(e): if "No memory configuration found" in str(e):
raise # Re-raise our specific error raise # Re-raise our specific error
@@ -454,7 +453,6 @@ class MemoryAgentService:
except ImportError: except ImportError:
audit_logger = None audit_logger = None
config_load_start = time.time() config_load_start = time.time()
try: try:
# Use a separate database session to avoid transaction failures # Use a separate database session to avoid transaction failures
@@ -562,34 +560,35 @@ class MemoryAgentService:
from app.repositories.memory_short_repository import ( from app.repositories.memory_short_repository import (
ShortTermMemoryRepository, ShortTermMemoryRepository,
) )
retrieved_content = [] retrieved_content = []
repo = ShortTermMemoryRepository(db) repo = ShortTermMemoryRepository(db)
if str(search_switch) != "2": if str(search_switch) != "2":
for intermediate in _intermediate_outputs: for intermediate in _intermediate_outputs:
logger.debug(f"处理中间结果: {intermediate}") logger.debug(f"处理中间结果: {intermediate}")
intermediate_type = intermediate.get('type', '') intermediate_type = intermediate.get('type', '')
if intermediate_type == "search_result": if intermediate_type == "search_result":
query = intermediate.get('query', '') query = intermediate.get('query', '')
raw_results = intermediate.get('raw_results', {}) raw_results = intermediate.get('raw_results', {})
try: try:
reranked_results = raw_results.get('reranked_results', []) reranked_results = raw_results.get('reranked_results', [])
statements = [statement['statement'] for statement in reranked_results.get('statements', [])] statements = [statement['statement'] for statement in
reranked_results.get('statements', [])]
except Exception: except Exception:
statements = [] statements = []
# 去重 # 去重
statements = list(set(statements)) statements = list(set(statements))
if query and statements: if query and statements:
retrieved_content.append({query: statements}) retrieved_content.append({query: statements})
# 如果 retrieved_content 为空,设置为空字符串 # 如果 retrieved_content 为空,设置为空字符串
if retrieved_content == []: if retrieved_content == []:
retrieved_content = '' retrieved_content = ''
# 只有当回答不是"信息不足"且不是快速检索时才保存 # 只有当回答不是"信息不足"且不是快速检索时才保存
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2": if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
# 使用 upsert 方法 # 使用 upsert 方法
@@ -602,15 +601,17 @@ class MemoryAgentService:
) )
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}") logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
else: else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") logger.debug(
f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
except Exception as save_error: except Exception as save_error:
# 保存失败不应该影响主流程,只记录错误 # 保存失败不应该影响主流程,只记录错误
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation # Log successful operation
total_time = time.time() - start_time total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") logger.info(
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(
@@ -641,7 +642,6 @@ class MemoryAgentService:
) )
raise ValueError(error_msg) raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
""" """
Get standardized message list from user input. Get standardized message list from user input.
@@ -657,41 +657,43 @@ class MemoryAgentService:
""" """
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
logger = get_api_logger() logger = get_api_logger()
if len(user_input.messages) == 0: if len(user_input.messages) == 0:
logger.error("Validation failed: Message list cannot be empty") logger.error("Validation failed: Message list cannot be empty")
raise ValueError("Message list cannot be empty") raise ValueError("Message list cannot be empty")
for idx, msg in enumerate(user_input.messages): for idx, msg in enumerate(user_input.messages):
if not isinstance(msg, dict): if not isinstance(msg, dict):
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}") logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}") raise ValueError(
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
if 'role' not in msg: if 'role' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}") logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}") raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
if 'content' not in msg: if 'content' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}") logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}") raise ValueError(
f"Message format error: Message must contain 'content' field. Error message index: {idx}")
if msg['role'] not in ['user', 'assistant']: if msg['role'] not in ['user', 'assistant']:
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}") logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}") raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
if not msg['content'] or not msg['content'].strip(): if not msg['content'] or not msg['content'].strip():
logger.error(f"Validation failed: Message {idx} content is empty") logger.error(f"Validation failed: Message {idx} content is empty")
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}") raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}") logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages return user_input.messages
async def classify_message_type( async def classify_message_type(
self, self,
message: str, message: str,
config_id: UUID, config_id: UUID,
db: Session, db: Session,
workspace_id: Optional[UUID] = None workspace_id: Optional[UUID] = None
) -> Dict: ) -> Dict:
""" """
Determine the type of user message (read or write) Determine the type of user message (read or write)
@@ -719,14 +721,15 @@ class MemoryAgentService:
status = await status_typle(message, memory_config.llm_model_id) status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}") logger.debug(f"Message type: {status}")
return status return status
async def generate_summary_from_retrieve( async def generate_summary_from_retrieve(
self, self,
end_user_id: str, end_user_id: str,
retrieve_info: str, retrieve_info: str,
history: List[Dict], history: List[Dict],
query: str, query: str,
config_id: str, config_id: str,
db: Session db: Session
) -> str: ) -> str:
""" """
基于检索信息、历史对话和查询生成最终答案 基于检索信息、历史对话和查询生成最终答案
@@ -761,9 +764,9 @@ class MemoryAgentService:
if config_id is None: if config_id is None:
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
# If config_id was provided, continue without workspace_id fallback # If config_id was provided, continue without workspace_id fallback
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...") logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
try: try:
# 加载配置 # 加载配置
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
@@ -772,7 +775,7 @@ class MemoryAgentService:
workspace_id=workspace_id, workspace_id=workspace_id,
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
# 导入必要的模块 # 导入必要的模块
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
summary_llm, summary_llm,
@@ -780,13 +783,13 @@ class MemoryAgentService:
from app.core.memory.agent.models.summary_models import ( from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse, RetrieveSummaryResponse,
) )
# 构建状态对象 # 构建状态对象
state = { state = {
"data": query, "data": query,
"memory_config": memory_config "memory_config": memory_config
} }
# 直接调用 summary_llm 函数 # 直接调用 summary_llm 函数
answer = await summary_llm( answer = await summary_llm(
state=state, state=state,
@@ -797,21 +800,20 @@ class MemoryAgentService:
response_model=RetrieveSummaryResponse, response_model=RetrieveSummaryResponse,
search_mode="1" search_mode="1"
) )
logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...") logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
return answer if answer else "信息不足,无法回答。" return answer if answer else "信息不足,无法回答。"
except Exception as e: except Exception as e:
logger.error(f"生成摘要失败: {str(e)}", exc_info=True) logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
return "信息不足,无法回答。" return "信息不足,无法回答。"
async def get_knowledge_type_stats( async def get_knowledge_type_stats(
self, self,
end_user_id: Optional[str] = None, db: Session,
only_active: bool = True, end_user_id: Optional[str] = None,
current_workspace_id: Optional[uuid.UUID] = None, only_active: bool = True,
db: Session = None current_workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
统计知识库类型分布,包含: 统计知识库类型分布,包含:
@@ -837,11 +839,6 @@ class MemoryAgentService:
# 1. 统计 PostgreSQL 中的知识库类型 # 1. 统计 PostgreSQL 中的知识库类型
try: try:
if db is None:
from app.db import get_db
db_gen = get_db()
db = next(db_gen)
# 初始化所有标准类型为 0 # 初始化所有标准类型为 0
for kb_type in KnowledgeType: for kb_type in KnowledgeType:
result[kb_type.value] = 0 result[kb_type.value] = 0
@@ -881,21 +878,19 @@ class MemoryAgentService:
# 3. 计算知识库类型总和(不包括 memory # 3. 计算知识库类型总和(不包括 memory
result["total"] = ( result["total"] = (
result.get("General", 0) + result.get("General", 0) +
result.get("Web", 0) + result.get("Web", 0) +
result.get("Third-party", 0) + result.get("Third-party", 0) +
result.get("Folder", 0) result.get("Folder", 0)
) )
return result return result
async def get_interest_distribution_by_user( async def get_interest_distribution_by_user(
self, self,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 5, limit: int = 5,
language: str = "zh" language: str = "zh"
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
获取指定用户的兴趣分布标签。 获取指定用户的兴趣分布标签。
@@ -921,13 +916,12 @@ class MemoryAgentService:
logger.error(f"兴趣分布标签查询失败: {e}") logger.error(f"兴趣分布标签查询失败: {e}")
raise Exception(f"兴趣分布标签查询失败: {e}") raise Exception(f"兴趣分布标签查询失败: {e}")
async def get_user_profile( async def get_user_profile(
self, self,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None, current_user_id: Optional[str] = None,
llm_id: Optional[str] = None, llm_id: Optional[str] = None,
db: Session = None db: Session = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
获取用户详情,包含: 获取用户详情,包含:
@@ -1017,7 +1011,8 @@ class MemoryAgentService:
# 定义标签提取的结构 # 定义标签提取的结构
class UserTags(BaseModel): class UserTags(BaseModel):
tags: list[str] = Field(..., description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友") tags: list[str] = Field(...,
description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友")
messages = [ messages = [
{ {
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
ValueError: 当终端用户不存在或应用未发布时 ValueError: 当终端用户不存在或应用未发布时
""" """
import json as json_module import json as json_module
import uuid
from sqlalchemy import select from sqlalchemy import select
@@ -1171,6 +1165,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
logger.info(f"Getting connected config for end_user: {end_user_id}") logger.info(f"Getting connected config for end_user: {end_user_id}")
# TODO: check sources for enduserid, should be one of these three: chat, draft, apikey
# 1. 获取 end_user 及其 app_id # 1. 获取 end_user 及其 app_id
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user: if not end_user:
@@ -1185,21 +1180,21 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
if not app: if not app:
logger.warning(f"App not found: {app_id}") logger.warning(f"App not found: {app_id}")
raise ValueError(f"应用不存在: {app_id}") raise ValueError(f"应用不存在: {app_id}")
# TODO: temp fix for draft run
if not app.current_release_id: # if not app.current_release_id:
logger.warning(f"No current release for app: {app_id}") # logger.warning(f"No current release for app: {app_id}")
raise ValueError(f"应用未发布: {app_id}") # raise ValueError(f"应用未发布: {app_id}")
# 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填 # 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填
memory_config_id_to_use = end_user.memory_config_id memory_config_id_to_use = end_user.memory_config_id
# 如果已有 memory_config_id直接使用 # 如果已有 memory_config_id直接使用
# 如果新创建enduserenduser.memory_config_id 必定为none # 如果新创建enduserenduser.memory_config_id 必定为none
# 那么使用从release中获取memory_config_id为预期行为并且回填到 # 那么使用从release中获取memory_config_id为预期行为并且回填到
# end_user.memory_config_id # end_user.memory_config_id
if not memory_config_id_to_use: if not memory_config_id_to_use:
logger.info(f"end_user.memory_config_id is None, migrating from AppRelease.config") logger.info(f"end_user.memory_config_id is None, migrating from AppRelease.config")
# 获取最新发布版本 # 获取最新发布版本
stmt = ( stmt = (
select(AppRelease) select(AppRelease)
@@ -1208,10 +1203,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
) )
# TODO: change to current_release_id # TODO: change to current_release_id
latest_release = db.scalars(stmt).first() latest_release = db.scalars(stmt).first()
if latest_release: if latest_release:
config = latest_release.config or {} config = latest_release.config or {}
# 如果 config 是字符串,解析为字典 # 如果 config 是字符串,解析为字典
if isinstance(config, str): if isinstance(config, str):
try: try:
@@ -1219,22 +1214,24 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
except json_module.JSONDecodeError: except json_module.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}") logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
config = {} config = {}
# 使用 MemoryConfigService 的提取方法 # 使用 MemoryConfigService 的提取方法
memory_config_service = MemoryConfigService(db) memory_config_service = MemoryConfigService(db)
legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id( legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id(
app_type=app.type, app_type=app.type,
config=config config=config
) )
if legacy_config_id: if legacy_config_id:
# 验证提取的 config_id 是否存在于数据库中 # 验证提取的 config_id 是否存在于数据库中
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel from app.models.memory_config_model import (
MemoryConfig as MemoryConfigModel,
)
existing_config = db.get(MemoryConfigModel, legacy_config_id) existing_config = db.get(MemoryConfigModel, legacy_config_id)
if existing_config: if existing_config:
memory_config_id_to_use = legacy_config_id memory_config_id_to_use = legacy_config_id
# 回填到 end_user 表lazy update # 回填到 end_user 表lazy update
end_user.memory_config_id = memory_config_id_to_use end_user.memory_config_id = memory_config_id_to_use
db.commit() db.commit()
@@ -1263,12 +1260,13 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
result = { result = {
"end_user_id": str(end_user_id), "end_user_id": str(end_user_id),
"app_id": str(app_id), "app_id": str(app_id),
"release_id": str(app.current_release_id), "release_id": str(app.current_release_id) if app.current_release_id else None,
"memory_config_id": memory_config_id, "memory_config_id": memory_config_id,
"workspace_id": str(app.workspace_id) "workspace_id": str(app.workspace_id)
} }
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}") logger.info(
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
return result return result
@@ -1312,7 +1310,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id # 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all() end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
# 创建映射 - 保留 EndUser 对象引用以便回填 # 创建映射 - 保留 EndUser 对象引用以便回填
end_user_map = {str(eu.id): eu for eu in end_users} end_user_map = {str(eu.id): eu for eu in end_users}
user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users} user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users}
@@ -1336,15 +1334,15 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 3. 对于没有 memory_config_id 的用户,尝试从 AppRelease.config 提取 # 3. 对于没有 memory_config_id 的用户,尝试从 AppRelease.config 提取
users_needing_migration = [ users_needing_migration = [
(end_user_id, data["app_id"]) (end_user_id, data["app_id"])
for end_user_id, data in user_data.items() for end_user_id, data in user_data.items()
if not data["memory_config_id"] if not data["memory_config_id"]
] ]
if users_needing_migration: if users_needing_migration:
# 批量获取相关应用的最新发布版本 # 批量获取相关应用的最新发布版本
migration_app_ids = list(set(app_id for _, app_id in users_needing_migration)) migration_app_ids = list(set(app_id for _, app_id in users_needing_migration))
# 查询每个应用的最新活跃发布版本 # 查询每个应用的最新活跃发布版本
app_latest_releases = {} app_latest_releases = {}
for app_id in migration_app_ids: for app_id in migration_app_ids:
@@ -1357,18 +1355,18 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
latest_release = db.scalars(stmt).first() latest_release = db.scalars(stmt).first()
if latest_release: if latest_release:
app_latest_releases[app_id] = latest_release app_latest_releases[app_id] = latest_release
# 为每个需要迁移的用户提取 memory_config_id # 为每个需要迁移的用户提取 memory_config_id
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
users_to_backfill = [] # [(end_user, memory_config_id), ...] users_to_backfill = [] # [(end_user, memory_config_id), ...]
for end_user_id, app_id in users_needing_migration: for end_user_id, app_id in users_needing_migration:
latest_release = app_latest_releases.get(app_id) latest_release = app_latest_releases.get(app_id)
if not latest_release: if not latest_release:
continue continue
config = latest_release.config or {} config = latest_release.config or {}
# 如果 config 是字符串,解析为字典 # 如果 config 是字符串,解析为字典
if isinstance(config, str): if isinstance(config, str):
try: try:
@@ -1376,21 +1374,21 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
except json_module.JSONDecodeError: except json_module.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}") logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
continue continue
# 使用 MemoryConfigService 的提取方法 # 使用 MemoryConfigService 的提取方法
app = app_map.get(app_id) app = app_map.get(app_id)
if not app: if not app:
continue continue
legacy_config_id, is_legacy_int = config_service.extract_memory_config_id( legacy_config_id, is_legacy_int = config_service.extract_memory_config_id(
app_type=app.type, app_type=app.type,
config=config config=config
) )
if legacy_config_id: if legacy_config_id:
# 更新 user_data 中的 memory_config_id # 更新 user_data 中的 memory_config_id
user_data[end_user_id]["memory_config_id"] = legacy_config_id user_data[end_user_id]["memory_config_id"] = legacy_config_id
# 记录需要回填的用户(稍后验证配置存在后再回填) # 记录需要回填的用户(稍后验证配置存在后再回填)
end_user = end_user_map.get(end_user_id) end_user = end_user_map.get(end_user_id)
if end_user: if end_user:
@@ -1399,7 +1397,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
logger.info( logger.info(
f"Legacy int config detected for end_user {end_user_id}, will use workspace default" f"Legacy int config detected for end_user {end_user_id}, will use workspace default"
) )
# 验证提取的 config_id 是否存在于数据库中 # 验证提取的 config_id 是否存在于数据库中
if users_to_backfill: if users_to_backfill:
config_ids_to_validate = list(set(cid for _, cid in users_to_backfill)) config_ids_to_validate = list(set(cid for _, cid in users_to_backfill))
@@ -1407,17 +1405,17 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
MemoryConfig.config_id.in_(config_ids_to_validate) MemoryConfig.config_id.in_(config_ids_to_validate)
).all() ).all()
valid_config_ids = {mc.config_id for mc in existing_configs} valid_config_ids = {mc.config_id for mc in existing_configs}
# 只回填存在的配置 # 只回填存在的配置
valid_backfills = [ valid_backfills = [
(eu, cid) for eu, cid in users_to_backfill (eu, cid) for eu, cid in users_to_backfill
if cid in valid_config_ids if cid in valid_config_ids
] ]
invalid_backfills = [ invalid_backfills = [
(eu, cid) for eu, cid in users_to_backfill (eu, cid) for eu, cid in users_to_backfill
if cid not in valid_config_ids if cid not in valid_config_ids
] ]
if invalid_backfills: if invalid_backfills:
invalid_ids = [str(cid) for _, cid in invalid_backfills] invalid_ids = [str(cid) for _, cid in invalid_backfills]
logger.warning( logger.warning(
@@ -1426,7 +1424,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 清除 user_data 中无效的 config_id # 清除 user_data 中无效的 config_id
for eu, cid in invalid_backfills: for eu, cid in invalid_backfills:
user_data[str(eu.id)]["memory_config_id"] = None user_data[str(eu.id)]["memory_config_id"] = None
# 批量回填 end_user.memory_config_id # 批量回填 end_user.memory_config_id
if valid_backfills: if valid_backfills:
for end_user, memory_config_id in valid_backfills: for end_user, memory_config_id in valid_backfills:
@@ -1437,7 +1435,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 4. 收集需要查询的 memory_config_id 和需要回退的 workspace_id # 4. 收集需要查询的 memory_config_id 和需要回退的 workspace_id
direct_config_ids = [] direct_config_ids = []
workspace_fallback_users = [] # [(end_user_id, workspace_id), ...] workspace_fallback_users = [] # [(end_user_id, workspace_id), ...]
for end_user_id, data in user_data.items(): for end_user_id, data in user_data.items():
if data["memory_config_id"]: if data["memory_config_id"]:
direct_config_ids.append(data["memory_config_id"]) direct_config_ids.append(data["memory_config_id"])
@@ -1455,7 +1453,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 6. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑) # 6. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑)
workspace_default_configs = {} workspace_default_configs = {}
unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users)) unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users))
if unique_workspace_ids: if unique_workspace_ids:
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
for workspace_id in unique_workspace_ids: for workspace_id in unique_workspace_ids:
@@ -1466,11 +1464,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 7. 构建最终结果 # 7. 构建最终结果
for end_user_id, data in user_data.items(): for end_user_id, data in user_data.items():
memory_config = None memory_config = None
# 优先使用 end_user 直接分配的配置 # 优先使用 end_user 直接分配的配置
if data["memory_config_id"]: if data["memory_config_id"]:
memory_config = config_id_to_config.get(data["memory_config_id"]) memory_config = config_id_to_config.get(data["memory_config_id"])
# 回退到工作空间默认配置 # 回退到工作空间默认配置
if not memory_config: if not memory_config:
workspace_id = app_to_workspace.get(data["app_id"]) workspace_id = app_to_workspace.get(data["app_id"])
@@ -1486,4 +1484,4 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None} result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
logger.info(f"Successfully retrieved {len(result)} connected configs") logger.info(f"Successfully retrieved {len(result)} connected configs")
return result return result

View File

@@ -107,6 +107,40 @@ def _validate_config_id(config_id, db: Session = None):
) )
# 专门场景的内置 key 集合,直接从 SceneConfigRegistry 派生,避免重复维护
# 使用懒加载函数避免模块级循环导入
def _get_builtin_pruning_scenes() -> set:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import SceneConfigRegistry
return set(SceneConfigRegistry.get_all_scenes())
def _load_ontology_classes(db: Session, scene_id, pruning_scene: Optional[str]) -> Optional[list]:
"""当 pruning_scene 不是内置场景时,从 ontology_class 表加载类型名称列表。
Args:
db: 数据库会话
scene_id: 本体场景 UUID
pruning_scene: 语义剪枝场景名称
Returns:
class_name 字符串列表,或 None内置场景 / 无数据时)
"""
if not scene_id:
return None
# 内置场景走 SceneConfigRegistry不需要注入类型列表
if pruning_scene in _get_builtin_pruning_scenes():
return None
try:
from app.repositories.ontology_class_repository import OntologyClassRepository
repo = OntologyClassRepository(db)
classes = repo.get_classes_by_scene(scene_id)
names = [c.class_name for c in classes if c.class_name]
return names if names else None
except Exception as e:
logger.warning(f"Failed to load ontology classes for scene_id={scene_id}: {e}")
return None
class MemoryConfigService: class MemoryConfigService:
""" """
Centralized service for memory configuration loading and validation. Centralized service for memory configuration loading and validation.
@@ -359,6 +393,7 @@ class MemoryConfigService:
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
# Ontology scene association # Ontology scene association
scene_id=memory_config.scene_id, scene_id=memory_config.scene_id,
ontology_classes=_load_ontology_classes(self.db, memory_config.scene_id, memory_config.pruning_scene),
) )
elapsed_ms = (time.time() - start_time) * 1000 elapsed_ms = (time.time() - start_time) * 1000

View File

@@ -1,45 +1,42 @@
# 修改 memory_konwledges_server.py 文件 # 修改 memory_konwledges_server.py 文件
import asyncio
import os import os
import re
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from fastapi import HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.config import settings
from app.core.logging_config import get_api_logger
from app.core.rag.models.chunk import DocumentChunk from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success from app.core.response_utils import success
from app.db import get_db from app.db import get_db_context
from app.schemas import file_schema, document_schema
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from app.models.document_model import Document from app.models.document_model import Document
import uuid
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.core.config import settings
from app.models.user_model import User from app.models.user_model import User
from app.schemas import file_schema, document_schema
from app.schemas.file_schema import CustomTextFileCreate from app.schemas.file_schema import CustomTextFileCreate
from app.services import document_service, file_service, knowledge_service from app.services import document_service, file_service, knowledge_service
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.schemas.file_schema import CustomTextFileCreate
from app.db import get_db
# 创建一个简单的用户类用于测试 # 创建一个简单的用户类用于测试
api_logger = get_api_logger() api_logger = get_api_logger()
class ChunkCreate(BaseModel): class ChunkCreate(BaseModel):
content: str content: str
class SimpleUser: class SimpleUser:
def __init__(self, user_id: str): def __init__(self, user_id: str):
# 确保ID是UUID类型 # 确保ID是UUID类型
self.id = user_id self.id = user_id
self.username = user_id self.username = user_id
'''解析'''
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User): async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
""" """
解析指定文档 解析指定文档
@@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}") api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
raise raise
'''获取块ID'''
async def get_document_chunks( async def get_document_chunks(
kb_id: uuid.UUID, kb_id: uuid.UUID,
document_id: uuid.UUID, document_id: uuid.UUID,
@@ -198,7 +195,7 @@ async def get_document_chunks(
return success(data=result, msg="文档块列表查询成功") return success(data=result, msg="文档块列表查询成功")
'''查找文档ID'''
def find_document_id_by_kb_and_filename( def find_document_id_by_kb_and_filename(
db: Session, db: Session,
kb_id: str, kb_id: str,
@@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename(
except Exception as e: except Exception as e:
return None return None
'''获取知识库ID'''
def find_documents_by_kb_id( def find_documents_by_kb_id(
db: Session, db: Session,
kb_id: str, kb_id: str,
@@ -268,18 +265,14 @@ def find_documents_by_kb_id(
except Exception as e: except Exception as e:
return [] return []
''''上传文件'''
async def memory_konwledges_up( async def memory_konwledges_up(
kb_id: str, kb_id: str,
parent_id: str, parent_id: str,
create_data: file_schema.CustomTextFileCreate, create_data: file_schema.CustomTextFileCreate,
db: Session = Depends(get_db), db: Session,
current_user: SimpleUser = None, # 修改为SimpleUser current_user: SimpleUser,
): ):
# 如果没有提供current_user则创建一个默认的
if current_user is None:
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
content_bytes = create_data.content.encode('utf-8') content_bytes = create_data.content.encode('utf-8')
file_size = len(content_bytes) file_size = len(content_bytes)
print(f"file size: {file_size} byte") print(f"file size: {file_size} byte")
@@ -350,8 +343,6 @@ async def memory_konwledges_up(
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
'''添加新块'''
async def create_document_chunk( async def create_document_chunk(
kb_id: uuid.UUID, kb_id: uuid.UUID,
@@ -417,7 +408,7 @@ async def create_document_chunk(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"查询文档块失败: {error_msg}" detail=f"查询文档块失败: {error_msg}"
) )
sort_id = sort_id + 1 sort_id = sort_id + 1
# 5. 创建文档块 # 5. 创建文档块
@@ -450,6 +441,7 @@ async def create_document_chunk(
return success(data=chunk, msg="文档块创建成功") return success(data=chunk, msg="文档块创建成功")
async def write_rag(end_user_id, message, user_rag_memory_id): async def write_rag(end_user_id, message, user_rag_memory_id):
""" """
将消息写入 RAG 知识库 将消息写入 RAG 知识库
@@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
detail=f"知识库ID格式无效: {user_rag_memory_id}" detail=f"知识库ID格式无效: {user_rag_memory_id}"
) )
db_gen = get_db() with get_db_context() as db:
db = next(db_gen)
try:
create_data = CustomTextFileCreate(title=end_user_id, content=message) create_data = CustomTextFileCreate(title=end_user_id, content=message)
current_user = SimpleUser(user_rag_memory_id) current_user = SimpleUser(user_rag_memory_id)
# 检查文档是否已存在 # 检查文档是否已存在
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
print('======',document) print('======', document)
api_logger.info(f"查找文档结果: document_id={document}") api_logger.info(f"查找文档结果: document_id={document}")
if document is not None: if document is not None:
# 文档已存在,直接添加新块 # 文档已存在,直接添加新块
@@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
else: else:
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}") api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
return result return result
finally:
# 确保数据库会话被关闭
db.close()

View File

@@ -115,6 +115,17 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# --- Create --- # --- Create ---
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
# 业务层检查同一工作空间下是否已存在同名配置
if params.workspace_id and params.config_name:
from app.models.memory_config_model import MemoryConfig
existing = (
self.db.query(MemoryConfig)
.filter_by(workspace_id=params.workspace_id, config_name=params.config_name)
.first()
)
if existing:
raise ValueError(f"DUPLICATE_CONFIG_NAME:{params.config_name}")
# 如果workspace_id存在且模型字段未全部指定则自动获取 # 如果workspace_id存在且模型字段未全部指定则自动获取
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]): if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
configs = self._get_workspace_configs(params.workspace_id) configs = self._get_workspace_configs(params.workspace_id)
@@ -135,6 +146,10 @@ class DataConfigService: # 数据配置服务类PostgreSQL
if not params.emotion_model_id: if not params.emotion_model_id:
params.emotion_model_id = params.llm_id params.emotion_model_id = params.llm_id
# 根据关联的本体场景推导 pruning_scene语义剪枝场景与本体工程场景保持一致
if params.scene_id and not getattr(params, 'pruning_scene', None):
params.pruning_scene = self._resolve_pruning_scene_from_scene_id(params.scene_id)
config = MemoryConfigRepository.create(self.db, params) config = MemoryConfigRepository.create(self.db, params)
self.db.commit() self.db.commit()
return {"affected": 1, "config_id": config.config_id} return {"affected": 1, "config_id": config.config_id}
@@ -150,6 +165,23 @@ class DataConfigService: # 数据配置服务类PostgreSQL
finally: finally:
db_session.close() db_session.close()
def _resolve_pruning_scene_from_scene_id(self, scene_id) -> Optional[str]:
"""根据本体场景ID获取对应的 scene_name作为语义剪枝场景值
Args:
scene_id: 本体场景UUID
Returns:
scene_name 字符串,查询失败时返回 None
"""
try:
from app.models.ontology_scene import OntologyScene
scene = self.db.query(OntologyScene).filter_by(scene_id=scene_id).first()
return scene.scene_name if scene else None
except Exception as e:
logger.warning(f"_resolve_pruning_scene_from_scene_id failed for scene_id={scene_id}: {e}", exc_info=True)
return None
# --- Delete --- # --- Delete ---
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数按配置ID def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数按配置ID
success = MemoryConfigRepository.delete(self.db, key.config_id) success = MemoryConfigRepository.delete(self.db, key.config_id)
@@ -185,6 +217,19 @@ class DataConfigService: # 数据配置服务类PostgreSQL
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
results = MemoryConfigRepository.get_all(self.db, workspace_id) results = MemoryConfigRepository.get_all(self.db, workspace_id)
# 检查并修正 pruning_scene 与 scene_name 不一致的记录
needs_commit = False
for config, scene_name in results:
if scene_name and config.pruning_scene != scene_name:
logger.info(
f"修正 pruning_scene: config_id={config.config_id} "
f"'{config.pruning_scene}' -> '{scene_name}'"
)
config.pruning_scene = scene_name
needs_commit = True
if needs_commit:
self.db.commit()
# 将 ORM 对象转换为字典列表 # 将 ORM 对象转换为字典列表
data_list = [] data_list = []
for config, scene_name in results: for config, scene_name in results:
@@ -211,6 +256,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"apply_id": config.apply_id, "apply_id": config.apply_id,
"scene_id": str(config.scene_id) if config.scene_id else None, "scene_id": str(config.scene_id) if config.scene_id else None,
"scene_name": scene_name, # 新增:场景名称 "scene_name": scene_name, # 新增:场景名称
"is_system_default": config.is_default, # 是否为系统默认配置
"llm_id": config.llm_id, "llm_id": config.llm_id,
"embedding_id": config.embedding_id, "embedding_id": config.embedding_id,
"rerank_id": config.rerank_id, "rerank_id": config.rerank_id,
@@ -737,8 +783,37 @@ async def analytics_hot_memory_tags(
await connector.close() await connector.close()
async def analytics_recent_activity_stats() -> Dict[str, Any]: async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> Dict[str, Any]:
stats, _msg = get_recent_activity_stats() """获取最近记忆提取活动统计。
优先从 Redis 缓存读取(按 workspace_id缓存不存在时降级到日志文件解析。
Args:
workspace_id: 工作空间ID用于从 Redis 读取对应缓存
Returns:
包含 total、stats、latest_relative、source 的统计字典
"""
stats = None
source = "log"
# 优先从 Redis 读取
if workspace_id:
try:
from app.cache.memory.activity_stats_cache import ActivityStatsCache
cached = await ActivityStatsCache.get_activity_stats(workspace_id)
if cached:
stats = cached.get("stats", {})
source = "redis"
logger.info(f"[ANALYTICS] 从 Redis 读取活动统计: workspace_id={workspace_id}")
except Exception as e:
logger.warning(f"[ANALYTICS] 读取 Redis 活动统计失败,降级到日志: {e}")
# 降级:从日志文件解析
if stats is None:
stats, _msg = get_recent_activity_stats()
source = "log"
total = ( total = (
stats.get("chunk_count", 0) stats.get("chunk_count", 0)
+ stats.get("statements_count", 0) + stats.get("statements_count", 0)
@@ -746,26 +821,29 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]:
+ stats.get("triplet_relations_count", 0) + stats.get("triplet_relations_count", 0)
+ stats.get("temporal_count", 0) + stats.get("temporal_count", 0)
) )
# 精简:仅提供“最新一次活动多久前”
latest_relative = None
try:
info = stats.get("log_path", "")
idx = info.rfind("最新:")
if idx != -1:
latest_path = info[idx + 3 :].strip()
if latest_path and os.path.exists(latest_path):
import time
diff = max(0.0, time.time() - os.path.getmtime(latest_path))
m = int(diff // 60)
if m < 1:
latest_relative = "刚刚"
elif m < 60:
latest_relative = "一会前"
else:
latest_relative = "较早前"
except Exception:
pass
data = {"total": total, "stats": stats, "latest_relative": latest_relative} # 计算"最新一次活动多久前"(仅日志来源时有效)
latest_relative = None
if source == "log":
try:
info = stats.get("log_path", "")
idx = info.rfind("最新:")
if idx != -1:
latest_path = info[idx + 3:].strip()
if latest_path and os.path.exists(latest_path):
import time
diff = max(0.0, time.time() - os.path.getmtime(latest_path))
m = int(diff // 60)
if m < 1:
latest_relative = "刚刚"
elif m < 60:
latest_relative = "一会前"
else:
latest_relative = "较早前"
except Exception:
pass
data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source}
return data return data

View File

@@ -116,27 +116,15 @@ class ModelConfigService:
try: try:
start_time = time.time() start_time = time.time()
# dashscope 的 omni 模型需要使用 compatible-mode model_config = RedBearModelConfig(
if provider.lower() == ModelProvider.DASHSCOPE and is_omni: model_name=model_name,
if not api_base: provider=provider,
api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1" api_key=api_key,
model_config = RedBearModelConfig( base_url=api_base,
model_name=model_name, is_omni=is_omni,
provider=ModelProvider.OPENAI, temperature=0.7,
api_key=api_key, max_tokens=100
base_url=api_base, )
temperature=0.7,
max_tokens=100
)
else:
model_config = RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
temperature=0.7,
max_tokens=100
)
# 根据模型类型选择不同的验证方式 # 根据模型类型选择不同的验证方式
model_type_lower = model_type.lower() model_type_lower = model_type.lower()
@@ -492,6 +480,9 @@ class ModelApiKeyService:
model_config = ModelConfigRepository.get_by_id(db, model_config_id) model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config: if not model_config:
continue continue
data.is_omni = model_config.is_omni
data.capability = model_config.capability
# 从ModelBase获取model_name # 从ModelBase获取model_name
model_name = model_config.model_base.name if model_config.model_base else model_config.name model_name = model_config.model_base.name if model_config.model_base else model_config.name
@@ -550,8 +541,8 @@ class ModelApiKeyService:
provider=data.provider, provider=data.provider,
api_key=data.api_key, api_key=data.api_key,
api_base=data.api_base, api_base=data.api_base,
capability=data.capability if data.capability is not None else model_config.capability, capability=data.capability,
is_omni=data.is_omni if data.is_omni is not None else model_config.is_omni, is_omni=data.is_omni,
config=data.config, config=data.config,
is_active=data.is_active, is_active=data.is_active,
priority=data.priority priority=data.priority
@@ -574,6 +565,10 @@ class ModelApiKeyService:
model_config = ModelConfigRepository.get_by_id(db, model_config_id) model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config: if not model_config:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if api_key_data.is_omni is None:
api_key_data.is_omni = model_config.is_omni
if api_key_data.capability is None:
api_key_data.capability = model_config.capability
# 检查API Key是否已存在(包括软删除)需要考虑tenant_id # 检查API Key是否已存在(包括软删除)需要考虑tenant_id
existing_key = db.query(ModelApiKey).join( existing_key = db.query(ModelApiKey).join(
@@ -616,7 +611,7 @@ class ModelApiKeyService:
api_base=api_key_data.api_base, api_base=api_key_data.api_base,
model_type=model_config.type, model_type=model_config.type,
test_message="Hello", test_message="Hello",
is_omni=model_config.is_omni is_omni=api_key_data.is_omni
) )
if not validation_result["valid"]: if not validation_result["valid"]:
raise BusinessException( raise BusinessException(
@@ -785,6 +780,7 @@ class ModelBaseService:
"description": model_base.description, "description": model_base.description,
"capability": model_base.capability, "capability": model_base.capability,
"is_omni": model_base.is_omni, "is_omni": model_base.is_omni,
"is_active": False,
"is_composite": False "is_composite": False
} }
model_config = ModelConfigRepository.create(db, model_config_data) model_config = ModelConfigRepository.create(db, model_config_data)

View File

@@ -326,6 +326,25 @@ async def run_pilot_extraction(
logger.info("Pilot run completed: Skipping Neo4j save") logger.info("Pilot run completed: Skipping Neo4j save")
# 将提取统计写入 Redis按 workspace_id 存储
try:
from app.cache.memory.activity_stats_cache import ActivityStatsCache
stats_to_cache = {
"chunk_count": len(chunk_nodes) if chunk_nodes else 0,
"statements_count": len(statement_nodes) if statement_nodes else 0,
"triplet_entities_count": len(entity_nodes) if entity_nodes else 0,
"triplet_relations_count": len(entity_edges) if entity_edges else 0,
"temporal_count": 0, # temporal 数据在日志中此处暂置0
}
await ActivityStatsCache.set_activity_stats(
workspace_id=str(memory_config.workspace_id),
stats=stats_to_cache,
)
logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
except Exception as cache_err:
logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
except Exception as e: except Exception as e:
logger.error(f"Pilot run failed: {e}", exc_info=True) logger.error(f"Pilot run failed: {e}", exc_info=True)
raise raise

View File

@@ -8,6 +8,8 @@ from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.tools.mcp import MCPToolManager, SimpleMCPClient from app.core.tools.mcp import MCPToolManager, SimpleMCPClient
from app.repositories.tool_repository import ( from app.repositories.tool_repository import (
ToolRepository, BuiltinToolRepository, CustomToolRepository, ToolRepository, BuiltinToolRepository, CustomToolRepository,
@@ -79,6 +81,18 @@ class ToolService:
config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
return self._config_to_info(config) if config else None return self._config_to_info(config) if config else None
def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None):
"""检查工具名称是否重复"""
query = self.db.query(ToolConfig).filter(
ToolConfig.name == name,
ToolConfig.tool_type == tool_type,
ToolConfig.tenant_id == tenant_id
)
if exclude_id:
query = query.filter(ToolConfig.id != exclude_id)
if query.first():
raise BusinessException(f"工具名称 '{name}' 已存在", BizCode.DUPLICATE_NAME)
def create_tool( def create_tool(
self, self,
name: str, name: str,
@@ -92,6 +106,7 @@ class ToolService:
"""创建工具""" """创建工具"""
if tool_type == ToolType.BUILTIN: if tool_type == ToolType.BUILTIN:
raise ValueError("内置工具不允许创建") raise ValueError("内置工具不允许创建")
self._check_name_duplicate(name, tool_type, tenant_id)
try: try:
# 创建基础配置 # 创建基础配置
@@ -141,6 +156,7 @@ class ToolService:
raise ValueError("内置工具不允许修改名称、描述和图标") raise ValueError("内置工具不允许修改名称、描述和图标")
try: try:
if name: if name:
self._check_name_duplicate(name, config_obj.tool_type, tenant_id, exclude_id=config_obj.id)
config_obj.name = name config_obj.name = name
if description: if description:
config_obj.description = description config_obj.description = description
@@ -894,7 +910,11 @@ class ToolService:
config_data.update({ config_data.update({
"last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None, "last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None,
"health_status": mcp_config.health_status, "health_status": mcp_config.health_status,
"available_tools": available_tools_display "available_tools": available_tools_display,
"source_channel": mcp_config.source_channel,
"market_id": mcp_config.market_id,
"market_config_id": mcp_config.market_config_id,
"mcp_service_id": mcp_config.mcp_service_id
}) })
return ToolInfo( return ToolInfo(
@@ -949,7 +969,11 @@ class ToolService:
id=tool_config.id, id=tool_config.id,
server_url=config.get("server_url"), server_url=config.get("server_url"),
connection_config=config.get("connection_config", {}), connection_config=config.get("connection_config", {}),
available_tools=config.get("available_tools", []) available_tools=config.get("available_tools", []),
source_channel=config.get("source_channel", "self_hosted"),
market_id=config.get("market_id"),
market_config_id=config.get("market_config_id"),
mcp_service_id=config.get("mcp_service_id"),
) )
self.db.add(mcp_config) self.db.add(mcp_config)
@@ -1002,6 +1026,14 @@ class ToolService:
mcp_config.server_url = config.get("server_url") mcp_config.server_url = config.get("server_url")
mcp_config.connection_config = config.get("connection_config", {}) mcp_config.connection_config = config.get("connection_config", {})
mcp_config.available_tools = config.get("available_tools", []) mcp_config.available_tools = config.get("available_tools", [])
if config.get("source_channel") is not None:
mcp_config.source_channel = config.get("source_channel")
if config.get("market_id") is not None:
mcp_config.market_id = config.get("market_id")
if config.get("market_config_id") is not None:
mcp_config.market_config_id = config.get("market_config_id")
if config.get("mcp_service_id") is not None:
mcp_config.mcp_service_id = config.get("mcp_service_id")
@staticmethod @staticmethod
def _determine_initial_status(tool_info: Dict[str, Any]) -> str: def _determine_initial_status(tool_info: Dict[str, Any]) -> str:

View File

@@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.cypher_queries import Graph_Node_query from app.repositories.neo4j.cypher_queries import Graph_Node_query
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService from app.services.memory_base_service import MemoryBaseService
from app.services.memory_base_service import MemoryBaseService, MemoryTransService
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService from app.services.memory_short_service import ShortService
@@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
from app.core.language_utils import validate_language from app.core.language_utils import validate_language
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
# 验证语言参数 # 验证语言参数
@@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
if end_user_id: if end_user_id:
try: try:
# 获取数据库会话并查询用户信息 # 获取数据库会话并查询用户信息
db = next(get_db()) with get_db_context() as db:
try:
repo = EndUserRepository(db) repo = EndUserRepository(db)
end_user = repo.get_by_id(uuid.UUID(end_user_id)) end_user = repo.get_by_id(uuid.UUID(end_user_id))
if end_user and end_user.other_name: if end_user and end_user.other_name:
@@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}") logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
else: else:
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}") logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
finally:
db.close()
except Exception as e: except Exception as e:
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}") logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")

View File

@@ -56,7 +56,7 @@ class WorkflowImportService:
success=False, success=False,
temp_id=None, temp_id=None,
workflow_id=None, workflow_id=None,
errors=[InvalidConfiguration()] errors=[InvalidConfiguration()] + adapter.errors
) )
workflow_config = adapter.parse_workflow() workflow_config = adapter.parse_workflow()

View File

@@ -25,7 +25,7 @@ from app.repositories.workflow_repository import (
WorkflowExecutionRepository, WorkflowExecutionRepository,
WorkflowNodeExecutionRepository WorkflowNodeExecutionRepository
) )
from app.schemas import DraftRunRequest, FileInput from app.schemas import DraftRunRequest, FileInput, FileType
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
@@ -496,6 +496,7 @@ class WorkflowService:
"event": "start", "event": "start",
"data": { "data": {
"conversation_id": payload.get("conversation_id"), "conversation_id": payload.get("conversation_id"),
"message_id": payload.get("message_id")
} }
} }
case "workflow_end": case "workflow_end":
@@ -600,6 +601,7 @@ class WorkflowService:
try: try:
files = await self._handle_file_input(payload.files) files = await self._handle_file_input(payload.files)
input_data["files"] = files input_data["files"] = files
message_id = uuid.uuid4()
# 更新状态为运行中 # 更新状态为运行中
self.update_execution_status(execution.execution_id, "running") self.update_execution_status(execution.execution_id, "running")
@@ -624,24 +626,45 @@ class WorkflowService:
workspace_id=str(workspace_id), workspace_id=str(workspace_id),
user_id=payload.user_id user_id=payload.user_id
) )
# 更新执行结果 # 更新执行结果
if result.get("status") == "completed": if result.get("status") == "completed":
token_usage = result.get("token_usage", {}) or {} token_usage = result.get("token_usage", {}) or {}
final_messages = result.get("messages", [])[init_message_length:]
human_message = ""
assistant_message = ""
for message in final_messages:
if message["role"] == "user":
if isinstance(message["content"], str):
human_message += message["content"]
elif isinstance(message["content"], list):
for file in message["content"]:
if file.get("type") == FileType.IMAGE:
human_message += f"![image]({file.get('url', '')})"
else:
human_message += f"[{file.get('type')}]({file.get('url', '')})"
if message["role"] == "assistant":
assistant_message = message["content"]
self.conversation_service.add_message(
conversation_id=conversation_id_uuid,
role="user",
content=human_message,
meta_data=None
)
self.conversation_service.add_message(
message_id=message_id,
conversation_id=conversation_id_uuid,
role="assistant",
content=assistant_message,
meta_data={"usage": token_usage}
)
self.update_execution_status( self.update_execution_status(
execution.execution_id, execution.execution_id,
"completed", "completed",
output_data=result, output_data=result,
token_usage=token_usage.get("total_tokens", None) token_usage=token_usage.get("total_tokens", None)
) )
final_messages = result.get("messages", [])[init_message_length:]
for message in final_messages:
self.conversation_service.add_message(
conversation_id=conversation_id_uuid,
role=message["role"],
content=message["content"],
meta_data=None if message["role"] == "user" else {"usage": token_usage}
)
logger.info(f"Workflow Run Success, " logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
else: else:
@@ -650,6 +673,8 @@ class WorkflowService:
"failed", "failed",
error_message=result.get("error") error_message=result.get("error")
) )
logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id},"
f" error: {result.get('error')}")
# 返回增强的响应结构 # 返回增强的响应结构
return { return {
@@ -659,6 +684,7 @@ class WorkflowService:
# "messages": result.get("messages"), # "messages": result.get("messages"),
"output": result.get("output"), # 最终输出(字符串) "output": result.get("output"), # 最终输出(字符串)
"message": result.get("output"), # 最终输出(字符串) "message": result.get("output"), # 最终输出(字符串)
"message_id": str(message_id),
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) # "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID "conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"), "error_message": result.get("error"),
@@ -756,7 +782,7 @@ class WorkflowService:
input_data["conv_messages"] = last_state.get("messages") or [] input_data["conv_messages"] = last_state.get("messages") or []
break break
init_message_length = len(input_data.get("conv_messages", [])) init_message_length = len(input_data.get("conv_messages", []))
message_id = uuid.uuid4()
async for event in execute_workflow_stream( async for event in execute_workflow_stream(
workflow_config=workflow_config_dict, workflow_config=workflow_config_dict,
input_data=input_data, input_data=input_data,
@@ -765,24 +791,43 @@ class WorkflowService:
user_id=payload.user_id, user_id=payload.user_id,
): ):
if event.get("event") == "workflow_end": if event.get("event") == "workflow_end":
status = event.get("data", {}).get("status") status = event.get("data", {}).get("status")
token_usage = event.get("data", {}).get("token_usage", {}) or {} token_usage = event.get("data", {}).get("token_usage", {}) or {}
if status == "completed": if status == "completed":
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
human_message = ""
assistant_message = ""
for message in final_messages:
if message["role"] == "user":
if isinstance(message["content"], str):
human_message += message["content"]
elif isinstance(message["content"], list):
for file in message["content"]:
if file.get("type") == FileType.IMAGE:
human_message += f"![image]({file.get('url', '')})"
else:
human_message += f"[{file.get('type')}]({file.get('url', '')})"
if message["role"] == "assistant":
assistant_message = message["content"]
self.conversation_service.add_message(
conversation_id=conversation_id_uuid,
role="user",
content=human_message,
meta_data=None
)
self.conversation_service.add_message(
message_id=message_id,
conversation_id=conversation_id_uuid,
role="assistant",
content=assistant_message,
meta_data={"usage": token_usage}
)
self.update_execution_status( self.update_execution_status(
execution.execution_id, execution.execution_id,
"completed", "completed",
output_data=event.get("data"), output_data=event.get("data"),
token_usage=token_usage.get("total_tokens", None) token_usage=token_usage.get("total_tokens", None)
) )
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
for message in final_messages:
self.conversation_service.add_message(
conversation_id=conversation_id_uuid,
role=message["role"],
content=message["content"],
meta_data=None if message["role"] == "user" else {"usage": token_usage}
)
logger.info(f"Workflow Run Success, " logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
elif status == "failed": elif status == "failed":
@@ -793,6 +838,8 @@ class WorkflowService:
) )
else: else:
logger.error(f"unexpect workflow run status, status: {status}") logger.error(f"unexpect workflow run status, status: {status}")
elif event.get("event") == "workflow_start":
event["data"]["message_id"] = str(message_id)
event = self._emit(public, event) event = self._emit(public, event)
if event: if event:
yield event yield event

View File

@@ -2,11 +2,11 @@ import datetime
import hashlib import hashlib
import secrets import secrets
import uuid import uuid
from os import getenv
from typing import List, Optional from typing import List, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.config.default_ontology_initializer import DefaultOntologyInitializer
from app.core.config import settings from app.core.config import settings
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, PermissionDeniedException from app.core.exceptions import BusinessException, PermissionDeniedException
@@ -30,17 +30,15 @@ from app.schemas.workspace_schema import (
WorkspaceModelsUpdate, WorkspaceModelsUpdate,
WorkspaceUpdate, WorkspaceUpdate,
) )
from app.config.default_ontology_initializer import DefaultOntologyInitializer
# 获取业务逻辑专用日志器 # 获取业务逻辑专用日志器
business_logger = get_business_logger() business_logger = get_business_logger()
from dotenv import load_dotenv
load_dotenv()
def switch_workspace( def switch_workspace(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
user: User, user: User,
): ):
"""切换工作空间""" """切换工作空间"""
business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}") business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}")
@@ -60,31 +58,32 @@ def switch_workspace(
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR) raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
def delete_workspace_member( def delete_workspace_member(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
member_id: uuid.UUID, member_id: uuid.UUID,
user: User, user: User,
): ):
"""删除工作空间成员""" """删除工作空间成员"""
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}") business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
_check_workspace_admin_permission(db, workspace_id, user) _check_workspace_admin_permission(db, workspace_id, user)
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id) workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
if not workspace_member: if not workspace_member:
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND) raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND)
if workspace_member.workspace_id != workspace_id: if workspace_member.workspace_id != workspace_id:
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_NOT_FOUND) raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}",
BizCode.WORKSPACE_NOT_FOUND)
try: try:
workspace_member.is_active = False workspace_member.is_active = False
workspace_member.user.current_workspace_id = None workspace_member.user.current_workspace_id = None
db.commit() db.commit()
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}") business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
except Exception as e: except Exception as e:
db.rollback() db.rollback()
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}") business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR) raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
def get_user_workspaces(db: Session, user: User) -> List[Workspace]: def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
@@ -102,18 +101,19 @@ def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
""" """
business_logger.debug(f"获取用户工作空间列表: {user.username} (ID: {user.id})") business_logger.debug(f"获取用户工作空间列表: {user.username} (ID: {user.id})")
workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id) workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
# Ensure each neo4j workspace has a default memory config # Ensure each neo4j workspace has a default memory config
for workspace in workspaces: for workspace in workspaces:
if workspace.storage_type == 'neo4j': if workspace.storage_type == 'neo4j':
_ensure_default_memory_config(db, workspace) _ensure_default_memory_config(db, workspace)
_ensure_default_ontology_scenes(db, workspace)
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}") business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
return workspaces return workspaces
def _create_workspace_only( def _create_workspace_only(
db: Session, workspace: WorkspaceCreate, owner: User db: Session, workspace: WorkspaceCreate, owner: User
) -> Workspace: ) -> Workspace:
business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}") business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}")
@@ -129,6 +129,7 @@ def _create_workspace_only(
business_logger.error(f"创建工作空间失败: {workspace.name} - {str(e)}") business_logger.error(f"创建工作空间失败: {workspace.name} - {str(e)}")
raise raise
def create_workspace( def create_workspace(
db: Session, workspace: WorkspaceCreate, user: User, language: str = "zh" db: Session, workspace: WorkspaceCreate, user: User, language: str = "zh"
) -> Workspace: ) -> Workspace:
@@ -136,9 +137,14 @@ def create_workspace(
f"创建工作空间: {workspace.name}, 创建者: {user.username}, " f"创建工作空间: {workspace.name}, 创建者: {user.username}, "
f"storage_type: {workspace.storage_type}" f"storage_type: {workspace.storage_type}"
) )
llm=workspace.llm if workspace_repository.get_workspaces_by_name(db=db, name=workspace.name, tenant_id=user.tenant_id):
embedding=workspace.embedding raise BusinessException(
rerank=workspace.rerank message="同名工作空间已存在",
code=BizCode.RESOURCE_ALREADY_EXISTS
)
llm = workspace.llm
embedding = workspace.embedding
rerank = workspace.rerank
try: try:
# Create the workspace without adding any members # Create the workspace without adding any members
business_logger.debug(f"创建工作空间: {workspace.name}") business_logger.debug(f"创建工作空间: {workspace.name}")
@@ -151,33 +157,35 @@ def create_workspace(
# Initialize default ontology scenes for the workspace (先创建本体场景) # Initialize default ontology scenes for the workspace (先创建本体场景)
default_scene_id = None default_scene_id = None
default_scene_name = None
try: try:
initializer = DefaultOntologyInitializer(db) initializer = DefaultOntologyInitializer(db)
success, error_msg = initializer.initialize_default_scenes( success, error_msg = initializer.initialize_default_scenes(
db_workspace.id, language=language db_workspace.id, language=language
) )
if success: if success:
business_logger.info( business_logger.info(
f"为工作空间 {db_workspace.id} 创建默认本体场景成功 (language={language})" f"为工作空间 {db_workspace.id} 创建默认本体场景成功 (language={language})"
) )
# 获取默认场景ID优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景 # 获取默认场景ID优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景
from app.repositories.ontology_scene_repository import OntologySceneRepository from app.repositories.ontology_scene_repository import OntologySceneRepository
from app.config.default_ontology_config import ( from app.config.default_ontology_config import (
ONLINE_EDUCATION_SCENE, ONLINE_EDUCATION_SCENE,
EMOTIONAL_COMPANION_SCENE, EMOTIONAL_COMPANION_SCENE,
get_scene_name get_scene_name
) )
scene_repo = OntologySceneRepository(db) scene_repo = OntologySceneRepository(db)
# 优先尝试获取教育场景 # 优先尝试获取教育场景
education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language) education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language)
education_scene = scene_repo.get_by_name(education_scene_name, db_workspace.id) education_scene = scene_repo.get_by_name(education_scene_name, db_workspace.id)
if education_scene: if education_scene:
default_scene_id = education_scene.scene_id default_scene_id = education_scene.scene_id
default_scene_name = education_scene.scene_name
business_logger.info( business_logger.info(
f"获取到教育场景ID用于默认记忆配置: {default_scene_id} (scene_name={education_scene_name})" f"获取到教育场景ID用于默认记忆配置: {default_scene_id} (scene_name={education_scene_name})"
) )
@@ -185,9 +193,10 @@ def create_workspace(
# 如果教育场景不存在,尝试获取情感陪伴场景 # 如果教育场景不存在,尝试获取情感陪伴场景
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language) companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
companion_scene = scene_repo.get_by_name(companion_scene_name, db_workspace.id) companion_scene = scene_repo.get_by_name(companion_scene_name, db_workspace.id)
if companion_scene: if companion_scene:
default_scene_id = companion_scene.scene_id default_scene_id = companion_scene.scene_id
default_scene_name = companion_scene.scene_name
business_logger.info( business_logger.info(
f"教育场景不存在使用情感陪伴场景ID用于默认记忆配置: {default_scene_id} (scene_name={companion_scene_name})" f"教育场景不存在使用情感陪伴场景ID用于默认记忆配置: {default_scene_id} (scene_name={companion_scene_name})"
) )
@@ -218,6 +227,7 @@ def create_workspace(
embedding_id=embedding, embedding_id=embedding,
rerank_id=rerank, rerank_id=rerank,
scene_id=default_scene_id, # 传入默认场景ID优先教育场景其次情感陪伴场景 scene_id=default_scene_id, # 传入默认场景ID优先教育场景其次情感陪伴场景
pruning_scene_name=default_scene_name, # 传入场景名称作为语义剪枝场景值
) )
business_logger.info( business_logger.info(
f"为工作空间 {db_workspace.id} 创建默认记忆配置成功 (scene_id={default_scene_id})" f"为工作空间 {db_workspace.id} 创建默认记忆配置成功 (scene_id={default_scene_id})"
@@ -250,10 +260,10 @@ def create_workspace(
avatar='', avatar='',
type=KnowledgeType.General, type=KnowledgeType.General,
permission_id=PermissionType.Memory, permission_id=PermissionType.Memory,
embedding_id=uuid.UUID(getenv('KB_embedding_id')) if None else embedding, embedding_id=embedding,
reranker_id=uuid.UUID(getenv('KB_reranker_id')) if None else rerank, reranker_id=rerank,
llm_id=uuid.UUID(getenv('KB_llm_id')) if None else llm, llm_id=llm,
image2text_id=uuid.UUID(getenv('KB_llm_id')) if None else llm, image2text_id=llm,
parser_config={ parser_config={
"layout_recognize": "DeepDOC", "layout_recognize": "DeepDOC",
"chunk_token_num": 256, "chunk_token_num": 256,
@@ -288,7 +298,7 @@ def create_workspace(
business_logger.info( business_logger.info(
f"工作空间 {db_workspace.id} 及相关资源创建完成并已提交" f"工作空间 {db_workspace.id} 及相关资源创建完成并已提交"
) )
return db_workspace return db_workspace
except Exception as e: except Exception as e:
@@ -298,11 +308,11 @@ def create_workspace(
def update_workspace( def update_workspace(
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
) -> Workspace: ) -> Workspace:
business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}") business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}")
db_workspace = _check_workspace_admin_permission(db,workspace_id,user) db_workspace = _check_workspace_admin_permission(db, workspace_id, user)
try: try:
# 更新工作空间 # 更新工作空间
business_logger.debug(f"执行工作空间更新: {db_workspace.name} (ID: {workspace_id})") business_logger.debug(f"执行工作空间更新: {db_workspace.name} (ID: {workspace_id})")
@@ -322,7 +332,7 @@ def update_workspace(
def get_workspace_members( def get_workspace_members(
db: Session, workspace_id: uuid.UUID, user: User db: Session, workspace_id: uuid.UUID, user: User
) -> List[WorkspaceMember]: ) -> List[WorkspaceMember]:
"""获取某工作空间的成员列表(关系序列化由模型关系支持)""" """获取某工作空间的成员列表(关系序列化由模型关系支持)"""
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}") business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
@@ -366,7 +376,6 @@ def get_workspace_members(
return members return members
# ==================== 邀请相关服务方法 ==================== # ==================== 邀请相关服务方法 ====================
def _generate_invite_token() -> tuple[str, str]: def _generate_invite_token() -> tuple[str, str]:
@@ -459,13 +468,14 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
def create_workspace_invite( def create_workspace_invite(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
invite_data: WorkspaceInviteCreate, invite_data: WorkspaceInviteCreate,
user: User user: User
) -> WorkspaceInviteResponse: ) -> WorkspaceInviteResponse:
"""创建工作空间邀请""" """创建工作空间邀请"""
business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}") business_logger.info(
f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
try: try:
# 检查权限 # 检查权限
@@ -528,17 +538,18 @@ def create_workspace_invite(
except Exception as e: except Exception as e:
db.rollback() db.rollback()
business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}") business_logger.error(
f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
raise raise
def get_workspace_invites( def get_workspace_invites(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
user: User, user: User,
status: Optional[InviteStatus] = None, status: Optional[InviteStatus] = None,
limit: int = 50, limit: int = 50,
offset: int = 0 offset: int = 0
) -> List[WorkspaceInviteResponse]: ) -> List[WorkspaceInviteResponse]:
"""获取工作空间邀请列表""" """获取工作空间邀请列表"""
business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}") business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}")
@@ -599,9 +610,9 @@ def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
def accept_workspace_invite( def accept_workspace_invite(
db: Session, db: Session,
accept_request: InviteAcceptRequest, accept_request: InviteAcceptRequest,
user: User user: User
) -> dict: ) -> dict:
"""接受工作空间邀请""" """接受工作空间邀请"""
business_logger.info(f"接受工作空间邀请: 用户 {user.username}") business_logger.info(f"接受工作空间邀请: 用户 {user.username}")
@@ -689,7 +700,8 @@ def accept_workspace_invite(
# 获取工作空间信息 # 获取工作空间信息
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id) workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}") business_logger.info(
f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
return { return {
"message": "Successfully joined the workspace", "message": "Successfully joined the workspace",
@@ -704,13 +716,14 @@ def accept_workspace_invite(
def revoke_workspace_invite( def revoke_workspace_invite(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
invite_id: uuid.UUID, invite_id: uuid.UUID,
user: User user: User
) -> dict: ) -> dict:
"""撤销工作空间邀请""" """撤销工作空间邀请"""
business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}") business_logger.info(
f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
try: try:
# 检查权限 # 检查权限
@@ -739,13 +752,14 @@ def revoke_workspace_invite(
def update_workspace_member_roles( def update_workspace_member_roles(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
updates: List[WorkspaceMemberUpdate], updates: List[WorkspaceMemberUpdate],
user: User, user: User,
) -> List[WorkspaceMember]: ) -> List[WorkspaceMember]:
"""更新工作空间成员角色""" """更新工作空间成员角色"""
business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}") business_logger.info(
f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
# 检查管理员权限 # 检查管理员权限
_check_workspace_admin_permission(db, workspace_id, user) _check_workspace_admin_permission(db, workspace_id, user)
@@ -759,7 +773,8 @@ def update_workspace_member_roles(
for upd in updates: for upd in updates:
# 检查成员是否存在 # 检查成员是否存在
if upd.id not in member_map: if upd.id not in member_map:
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND) raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}",
BizCode.WORKSPACE_MEMBER_NOT_FOUND)
member = member_map[upd.id] member = member_map[upd.id]
@@ -911,10 +926,10 @@ def get_workspace_models_configs(
def update_workspace_models_configs( def update_workspace_models_configs(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
models_update: WorkspaceModelsUpdate, models_update: WorkspaceModelsUpdate,
user: User, user: User,
) -> Workspace: ) -> Workspace:
"""更新工作空间的模型配置llm, embedding, rerank """更新工作空间的模型配置llm, embedding, rerank
@@ -961,88 +976,9 @@ def update_workspace_models_configs(
raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR) raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR)
def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
"""Ensure a workspace has a default memory config, creating one if missing.
Also fills empty model fields for all configs in this workspace.
Args:
db: Database session
workspace: The workspace to check
"""
from app.models.memory_config_model import MemoryConfig
# Check if default config exists for this workspace
existing_default = db.query(MemoryConfig).filter(
MemoryConfig.workspace_id == workspace.id,
MemoryConfig.is_default == True
).first()
if not existing_default:
# No default config exists, create one
business_logger.info(
f"Workspace {workspace.id} missing default memory config, creating one"
)
# 尝试获取默认场景ID优先教育场景其次情感陪伴场景
default_scene_id = None
try:
from app.repositories.ontology_scene_repository import OntologySceneRepository
from app.config.default_ontology_config import (
ONLINE_EDUCATION_SCENE,
EMOTIONAL_COMPANION_SCENE,
get_scene_name
)
scene_repo = OntologySceneRepository(db)
# 尝试中文和英文场景名称
for language in ["zh", "en"]:
# 优先尝试教育场景
education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language)
education_scene = scene_repo.get_by_name(education_scene_name, workspace.id)
if education_scene:
default_scene_id = education_scene.scene_id
business_logger.info(
f"找到教育场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={education_scene_name}"
)
break
# 如果教育场景不存在,尝试情感陪伴场景
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
companion_scene = scene_repo.get_by_name(companion_scene_name, workspace.id)
if companion_scene:
default_scene_id = companion_scene.scene_id
business_logger.info(
f"教育场景不存在,找到情感陪伴场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={companion_scene_name}"
)
break
except Exception as scene_error:
business_logger.warning(
f"获取默认场景失败,将创建不关联场景的记忆配置: {str(scene_error)}"
)
try:
_create_default_memory_config(
db=db,
workspace_id=workspace.id,
workspace_name=workspace.name,
llm_id=uuid.UUID(workspace.llm) if workspace.llm else None,
embedding_id=uuid.UUID(workspace.embedding) if workspace.embedding else None,
rerank_id=uuid.UUID(workspace.rerank) if workspace.rerank else None,
scene_id=default_scene_id, # 传入默认场景ID优先教育场景其次情感陪伴场景
)
except Exception as e:
business_logger.error(
f"Failed to create default memory config for workspace {workspace.id}: {str(e)}"
)
# Fill empty model fields for ALL configs in this workspace
_fill_workspace_configs_model_defaults(db, workspace)
def _fill_workspace_configs_model_defaults( def _fill_workspace_configs_model_defaults(
db: Session, db: Session,
workspace: Workspace workspace: Workspace
) -> None: ) -> None:
"""Fill empty model fields for all memory configs in a workspace. """Fill empty model fields for all memory configs in a workspace.
@@ -1054,43 +990,43 @@ def _fill_workspace_configs_model_defaults(
workspace: The workspace containing default model settings workspace: The workspace containing default model settings
""" """
from app.models.memory_config_model import MemoryConfig from app.models.memory_config_model import MemoryConfig
# Get all configs for this workspace # Get all configs for this workspace
configs = db.query(MemoryConfig).filter( configs = db.query(MemoryConfig).filter(
MemoryConfig.workspace_id == workspace.id MemoryConfig.workspace_id == workspace.id
).all() ).all()
if not configs: if not configs:
return return
# Map of memory_config field -> workspace field # Map of memory_config field -> workspace field
model_field_mappings = [ model_field_mappings = [
("llm_id", "llm"), ("llm_id", "llm"),
("embedding_id", "embedding"), ("embedding_id", "embedding"),
("rerank_id", "rerank"), ("rerank_id", "rerank"),
("reflection_model_id", "llm"), # reflection uses LLM ("reflection_model_id", "llm"), # reflection uses LLM
("emotion_model_id", "llm"), # emotion uses LLM ("emotion_model_id", "llm"), # emotion uses LLM
] ]
configs_updated = 0 configs_updated = 0
for memory_config in configs: for memory_config in configs:
updated_fields = [] updated_fields = []
for config_field, workspace_field in model_field_mappings: for config_field, workspace_field in model_field_mappings:
config_value = getattr(memory_config, config_field, None) config_value = getattr(memory_config, config_field, None)
workspace_value = getattr(workspace, workspace_field, None) workspace_value = getattr(workspace, workspace_field, None)
if not config_value and workspace_value: if not config_value and workspace_value:
setattr(memory_config, config_field, workspace_value) setattr(memory_config, config_field, workspace_value)
updated_fields.append(config_field) updated_fields.append(config_field)
if updated_fields: if updated_fields:
configs_updated += 1 configs_updated += 1
business_logger.debug( business_logger.debug(
f"Updated memory config {memory_config.config_id} fields: {updated_fields}" f"Updated memory config {memory_config.config_id} fields: {updated_fields}"
) )
if configs_updated > 0: if configs_updated > 0:
try: try:
db.commit() db.commit()
@@ -1105,13 +1041,14 @@ def _fill_workspace_configs_model_defaults(
def _create_default_memory_config( def _create_default_memory_config(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
workspace_name: str, workspace_name: str,
llm_id: Optional[uuid.UUID] = None, llm_id: Optional[uuid.UUID] = None,
embedding_id: Optional[uuid.UUID] = None, embedding_id: Optional[uuid.UUID] = None,
rerank_id: Optional[uuid.UUID] = None, rerank_id: Optional[uuid.UUID] = None,
scene_id: Optional[uuid.UUID] = None, scene_id: Optional[uuid.UUID] = None,
pruning_scene_name: Optional[str] = None,
) -> None: ) -> None:
"""Create a default memory config for a newly created workspace. """Create a default memory config for a newly created workspace.
@@ -1123,11 +1060,12 @@ def _create_default_memory_config(
embedding_id: Optional embedding model ID embedding_id: Optional embedding model ID
rerank_id: Optional rerank model ID rerank_id: Optional rerank model ID
scene_id: Optional ontology scene ID (默认关联教育场景) scene_id: Optional ontology scene ID (默认关联教育场景)
pruning_scene_name: Optional pruning scene name取自 ontology_scene.scene_name
""" """
from app.models.memory_config_model import MemoryConfig from app.models.memory_config_model import MemoryConfig
config_id = uuid.uuid4() config_id = uuid.uuid4()
default_config = MemoryConfig( default_config = MemoryConfig(
config_id=config_id, config_id=config_id,
config_name=f"{workspace_name} 默认配置", config_name=f"{workspace_name} 默认配置",
@@ -1136,14 +1074,15 @@ def _create_default_memory_config(
llm_id=str(llm_id) if llm_id else None, llm_id=str(llm_id) if llm_id else None,
embedding_id=str(embedding_id) if embedding_id else None, embedding_id=str(embedding_id) if embedding_id else None,
rerank_id=str(rerank_id) if rerank_id else None, rerank_id=str(rerank_id) if rerank_id else None,
scene_id=scene_id, # 关联本体场景ID scene_id=scene_id, # 关联本体场景ID(默认为"在线教育"场景)
pruning_scene=pruning_scene_name, # 语义剪枝场景直接使用 scene_name
state=True, # Active by default state=True, # Active by default
is_default=True, # Mark as workspace default is_default=True, # Mark as workspace default
) )
db.add(default_config) db.add(default_config)
db.flush() # 使用 flush 而不是 commit让调用者统一提交 db.flush() # 使用 flush 而不是 commit让调用者统一提交
business_logger.info( business_logger.info(
"Created default memory config for workspace", "Created default memory config for workspace",
extra={ extra={
@@ -1153,3 +1092,130 @@ def _create_default_memory_config(
"scene_id": str(scene_id) if scene_id else None, "scene_id": str(scene_id) if scene_id else None,
} }
) )
# ==================== 检查配置相关服务 ====================
def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
"""Ensure a workspace has a default memory config, creating one if missing.
Also fills empty model fields for all configs in this workspace.
Args:
db: Database session
workspace: The workspace to check
"""
from app.models.memory_config_model import MemoryConfig
# Check if default config exists for this workspace
existing_default = db.query(MemoryConfig).filter(
MemoryConfig.workspace_id == workspace.id,
MemoryConfig.is_default == True
).first()
if not existing_default:
# No default config exists, create one
business_logger.info(
f"Workspace {workspace.id} missing default memory config, creating one"
)
# 尝试获取默认场景ID优先教育场景其次情感陪伴场景
default_scene_id = None
try:
from app.repositories.ontology_scene_repository import OntologySceneRepository
from app.config.default_ontology_config import (
ONLINE_EDUCATION_SCENE,
EMOTIONAL_COMPANION_SCENE,
get_scene_name
)
scene_repo = OntologySceneRepository(db)
# 尝试中文和英文场景名称
for language in ["zh", "en"]:
# 优先尝试教育场景
education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language)
education_scene = scene_repo.get_by_name(education_scene_name, workspace.id)
if education_scene:
default_scene_id = education_scene.scene_id
business_logger.info(
f"找到教育场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={education_scene_name}"
)
break
# 如果教育场景不存在,尝试情感陪伴场景
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
companion_scene = scene_repo.get_by_name(companion_scene_name, workspace.id)
if companion_scene:
default_scene_id = companion_scene.scene_id
business_logger.info(
f"教育场景不存在,找到情感陪伴场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={companion_scene_name}"
)
break
except Exception as scene_error:
business_logger.warning(
f"获取默认场景失败,将创建不关联场景的记忆配置: {str(scene_error)}"
)
try:
_create_default_memory_config(
db=db,
workspace_id=workspace.id,
workspace_name=workspace.name,
llm_id=uuid.UUID(workspace.llm) if workspace.llm else None,
embedding_id=uuid.UUID(workspace.embedding) if workspace.embedding else None,
rerank_id=uuid.UUID(workspace.rerank) if workspace.rerank else None,
scene_id=default_scene_id, # 传入默认场景ID优先教育场景其次情感陪伴场景
)
except Exception as e:
business_logger.error(
f"Failed to create default memory config for workspace {workspace.id}: {str(e)}"
)
# Fill empty model fields for ALL configs in this workspace
_fill_workspace_configs_model_defaults(db, workspace)
def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
"""Ensure a workspace has default ontology scenes, creating them if missing.
Checks whether any is_system_default scene exists for the workspace.
If not, runs the DefaultOntologyInitializer to create them.
Args:
db: Database session
workspace: The workspace to check
"""
from app.models.ontology_scene import OntologyScene
# 幂等检查:是否已存在系统默认场景
existing = db.query(OntologyScene).filter(
OntologyScene.workspace_id == workspace.id,
OntologyScene.is_system_default.is_(True)
).first()
if existing:
return
business_logger.info(
f"Workspace {workspace.id} missing default ontology scenes, creating them"
)
try:
initializer = DefaultOntologyInitializer(db)
success, error_msg = initializer.initialize_default_scenes(
workspace.id, language="zh"
)
if success:
db.commit()
business_logger.info(
f"为工作空间 {workspace.id} 补建默认本体场景成功"
)
else:
business_logger.warning(
f"为工作空间 {workspace.id} 补建默认本体场景失败: {error_msg}"
)
except Exception as e:
db.rollback()
business_logger.error(
f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}"
)

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import json import json
import logging
import os import os
import re import re
import shutil import shutil
@@ -14,6 +15,62 @@ from uuid import UUID
import redis import redis
import requests import requests
from redis.exceptions import RedisError
logger = logging.getLogger(__name__)
# 模块级同步 Redis 连接池,供 Celery 任务共享使用
# 连接 CELERY_BACKEND DB与 write_message:last_done 时间戳写入保持一致
# 使用连接池而非单例客户端,提供更好的并发性能和自动重连
_sync_redis_pool: redis.ConnectionPool = None
def _get_or_create_redis_pool() -> redis.ConnectionPool:
"""获取或创建 Redis 连接池(懒初始化)"""
global _sync_redis_pool
if _sync_redis_pool is None:
try:
_sync_redis_pool = redis.ConnectionPool(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB_CELERY_BACKEND,
password=settings.REDIS_PASSWORD,
decode_responses=True,
max_connections=10,
socket_connect_timeout=5,
socket_timeout=5,
retry_on_timeout=True,
health_check_interval=30,
)
logger.info("Redis connection pool created for Celery tasks")
except Exception as e:
logger.error(f"Failed to create Redis connection pool: {e}", exc_info=True)
return None
return _sync_redis_pool
def get_sync_redis_client() -> Optional[redis.StrictRedis]:
"""获取同步 Redis 客户端(使用连接池)
使用连接池提供的客户端,支持自动重连和健康检查。
如果 Redis 不可用,返回 None调用方应优雅降级。
Returns:
redis.StrictRedis: Redis 客户端实例,如果连接失败则返回 None
"""
try:
pool = _get_or_create_redis_pool()
if pool is None:
return None
client = redis.StrictRedis(connection_pool=pool)
# 验证连接可用性
client.ping()
return client
except RedisError as e:
logger.error(f"Redis connection failed: {e}", exc_info=True)
return None
except Exception as e:
logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True)
return None
# Import a unified Celery instance # Import a unified Celery instance
from app.celery_app import celery_app from app.celery_app import celery_app
@@ -1090,6 +1147,22 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
logger.info( logger.info(
f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
try:
_r = get_sync_redis_client()
if _r is not None:
from datetime import timedelta as _td
from datetime import timezone as _tz
_CST = _tz(_td(hours=8))
_now_cst = datetime.now(_CST).replace(tzinfo=None).isoformat()
_r.set(
f"write_message:last_done:{end_user_id}",
_now_cst,
ex=86400 * 30,
)
except Exception as _e:
logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}")
return { return {
"status": "SUCCESS", "status": "SUCCESS",
"result": result, "result": result,
@@ -2149,12 +2222,16 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
start_time = time.time() start_time = time.time()
async def _run() -> Dict[str, Any]: async def _run() -> Dict[str, Any]:
from sqlalchemy import func, select
from app.core.logging_config import get_logger from app.core.logging_config import get_logger
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
from sqlalchemy import select, func from app.repositories.implicit_emotions_storage_repository import (
from app.services.implicit_memory_service import ImplicitMemoryService ImplicitEmotionsStorageRepository,
TimeFilterUnavailableError,
)
from app.services.emotion_analytics_service import EmotionAnalyticsService from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.services.implicit_memory_service import ImplicitMemoryService
logger = get_logger(__name__) logger = get_logger(__name__)
logger.info("开始执行隐性记忆和情绪数据更新定时任务") logger.info("开始执行隐性记忆和情绪数据更新定时任务")
@@ -2167,18 +2244,27 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
with get_db_context() as db: with get_db_context() as db:
try: try:
# 获取所有已存储数据的用户ID分批次处理
repo = ImplicitEmotionsStorageRepository(db) repo = ImplicitEmotionsStorageRepository(db)
# 先统计总数用于日志 # 先统计总数用于日志
from sqlalchemy import func from sqlalchemy import func
total_users = db.execute( total_users = db.execute(
select(func.count()).select_from(ImplicitEmotionsStorage) select(func.count()).select_from(ImplicitEmotionsStorage)
).scalar() or 0 ).scalar() or 0
logger.info(f"找到 {total_users} 个需要更新的用户") logger.info(f"表中存量用户总数: {total_users},开始时间轴筛选")
# 遍历每个用户并更新数据分批次避免一次性加载所有ID # 构建 Redis 同步客户端,用于时间轴筛选
for end_user_id in repo.get_all_user_ids(batch_size=100): _redis_client = get_sync_redis_client()
# 只处理 last_done > updated_at 的用户(有新记忆写入的用户)
# Redis 不可用时回退到全量处理
try:
refresh_iter = repo.get_users_needing_refresh(_redis_client, batch_size=100)
except TimeFilterUnavailableError as e:
logger.warning(f"时间轴筛选不可用,回退到全量刷新: {e}")
refresh_iter = repo.get_all_user_ids(batch_size=100)
for end_user_id in refresh_iter:
logger.info(f"开始处理用户: {end_user_id}") logger.info(f"开始处理用户: {end_user_id}")
user_start_time = time.time() user_start_time = time.time()
@@ -2264,10 +2350,10 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
user_results.append(error_info) user_results.append(error_info)
logger.error(f"处理用户 {end_user_id} 时出错: {str(e)}") logger.error(f"处理用户 {end_user_id} 时出错: {str(e)}")
# ---- 处理增量用户(当天新增、尚未初始化的用户)---- # ---- 当天新增用户兜底初始化 ----
new_users_initialized = 0 new_users_initialized = 0
new_users_failed = 0 new_users_failed = 0
logger.info("开始处理当天新增的增量用户初始化") logger.info("开始处理当天新增用户的兜底初始化")
for end_user_id in repo.get_new_user_ids_today(batch_size=100): for end_user_id in repo.get_new_user_ids_today(batch_size=100):
logger.info(f"开始初始化新用户: {end_user_id}") logger.info(f"开始初始化新用户: {end_user_id}")
@@ -2281,35 +2367,27 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id) implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id) profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
await implicit_service.save_profile_cache( await implicit_service.save_profile_cache(
end_user_id=end_user_id, end_user_id=end_user_id, profile_data=profile_data, db=db
profile_data=profile_data,
db=db
) )
implicit_success = True implicit_success = True
logger.info(f"成功初始化新用户 {end_user_id} 的隐性记忆画像") logger.info(f"成功初始化新用户 {end_user_id} 的隐性记忆画像")
except Exception as e: except Exception as e:
error_msg = f"隐性记忆初始化失败: {str(e)}" errors.append(f"隐性记忆初始化失败: {str(e)}")
errors.append(error_msg) logger.error(f"新用户 {end_user_id} 隐性记忆初始化失败: {e}")
logger.error(f"新用户 {end_user_id} {error_msg}")
try: try:
emotion_service = EmotionAnalyticsService() emotion_service = EmotionAnalyticsService()
suggestions_data = await emotion_service.generate_emotion_suggestions( suggestions_data = await emotion_service.generate_emotion_suggestions(
end_user_id=end_user_id, end_user_id=end_user_id, db=db, language="zh"
db=db,
language="zh"
) )
await emotion_service.save_suggestions_cache( await emotion_service.save_suggestions_cache(
end_user_id=end_user_id, end_user_id=end_user_id, suggestions_data=suggestions_data, db=db
suggestions_data=suggestions_data,
db=db
) )
emotion_success = True emotion_success = True
logger.info(f"成功初始化新用户 {end_user_id} 的情绪建议") logger.info(f"成功初始化新用户 {end_user_id} 的情绪建议")
except Exception as e: except Exception as e:
error_msg = f"情绪建议初始化失败: {str(e)}" errors.append(f"情绪建议初始化失败: {str(e)}")
errors.append(error_msg) logger.error(f"新用户 {end_user_id} 情绪建议初始化失败: {e}")
logger.error(f"新用户 {end_user_id} {error_msg}")
if implicit_success or emotion_success: if implicit_success or emotion_success:
new_users_initialized += 1 new_users_initialized += 1
@@ -2319,7 +2397,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
user_elapsed = time.time() - user_start_time user_elapsed = time.time() - user_start_time
user_results.append({ user_results.append({
"end_user_id": end_user_id, "end_user_id": end_user_id,
"type": "init", "type": "new_user_init",
"implicit_success": implicit_success, "implicit_success": implicit_success,
"emotion_success": emotion_success, "emotion_success": emotion_success,
"errors": errors, "errors": errors,
@@ -2331,7 +2409,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
user_elapsed = time.time() - user_start_time user_elapsed = time.time() - user_start_time
user_results.append({ user_results.append({
"end_user_id": end_user_id, "end_user_id": end_user_id,
"type": "init", "type": "new_user_init",
"implicit_success": False, "implicit_success": False,
"emotion_success": False, "emotion_success": False,
"errors": [str(e)], "errors": [str(e)],
@@ -2339,27 +2417,24 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
}) })
logger.error(f"初始化新用户 {end_user_id} 时出错: {str(e)}") logger.error(f"初始化新用户 {end_user_id} 时出错: {str(e)}")
logger.info( logger.info(f"当天新增用户兜底初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}")
f"增量用户初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}" # ---- 新增用户兜底初始化结束 ----
)
# ---- 增量用户处理结束 ----
# 记录总体统计信息
logger.info( logger.info(
f"隐性记忆和情绪数据更新定时任务完成: " f"隐性记忆和情绪数据更新定时任务完成: "
f"存量用户总数={total_users}, " f"存量用户总数={total_users}, "
f"隐性记忆成功={successful_implicit}, " f"隐性记忆成功={successful_implicit}, "
f"情绪建议成功={successful_emotion}, " f"情绪建议成功={successful_emotion}, "
f"存量失败={failed}, " f"存量失败={failed}, "
f"增量初始化成功={new_users_initialized}, " f"新增用户初始化成功={new_users_initialized}, "
f"增量初始化失败={new_users_failed}" f"新增用户初始化失败={new_users_failed}"
) )
return { return {
"status": "SUCCESS", "status": "SUCCESS",
"message": ( "message": (
f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;" f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;"
f"量新用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败" f"当天新增用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败"
), ),
"total_users": total_users, "total_users": total_users,
"successful_implicit": successful_implicit, "successful_implicit": successful_implicit,
@@ -2367,7 +2442,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
"failed": failed, "failed": failed,
"new_users_initialized": new_users_initialized, "new_users_initialized": new_users_initialized,
"new_users_failed": new_users_failed, "new_users_failed": new_users_failed,
"user_results": user_results[:50] # 只保留前50个用户的详细结果 "user_results": user_results[:50]
} }
except Exception as e: except Exception as e:
@@ -2416,3 +2491,232 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"task_id": self.request.id "task_id": self.request.id
} }
# =============================================================================
@celery_app.task(
name="app.tasks.init_implicit_emotions_for_users",
bind=True,
ignore_result=True,
max_retries=0,
acks_late=False,
time_limit=3600,
soft_time_limit=3300,
# 触发型任务标识,区别于 periodic_tasks 队列中的定时任务
triggered=True,
)
def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
"""事件触发任务:对指定用户列表做存在性检查,无记录则执行首次初始化。
由 /dashboard/end_users 接口触发,已有数据的用户直接跳过。
存量用户的数据刷新由定时任务 update_implicit_emotions_storage 负责。
Args:
end_user_ids: 需要检查的用户ID列表
Returns:
包含任务执行结果的字典
"""
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.repositories.implicit_emotions_storage_repository import (
ImplicitEmotionsStorageRepository,
)
from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.services.implicit_memory_service import ImplicitMemoryService
logger = get_logger(__name__)
logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")
initialized = 0
failed = 0
skipped = 0
with get_db_context() as db:
repo = ImplicitEmotionsStorageRepository(db)
for end_user_id in end_user_ids:
existing = repo.get_by_end_user_id(end_user_id)
if existing is not None:
skipped += 1
continue
logger.info(f"用户 {end_user_id} 无记录,开始初始化")
implicit_ok = False
emotion_ok = False
try:
try:
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
await implicit_service.save_profile_cache(
end_user_id=end_user_id, profile_data=profile_data, db=db
)
implicit_ok = True
except Exception as e:
logger.error(f"用户 {end_user_id} 隐性记忆初始化失败: {e}")
try:
emotion_service = EmotionAnalyticsService()
suggestions_data = await emotion_service.generate_emotion_suggestions(
end_user_id=end_user_id, db=db, language="zh"
)
await emotion_service.save_suggestions_cache(
end_user_id=end_user_id, suggestions_data=suggestions_data, db=db
)
emotion_ok = True
except Exception as e:
logger.error(f"用户 {end_user_id} 情绪建议初始化失败: {e}")
if implicit_ok or emotion_ok:
initialized += 1
else:
failed += 1
except Exception as e:
failed += 1
logger.error(f"用户 {end_user_id} 初始化异常: {e}")
logger.info(f"按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
return {
"status": "SUCCESS",
"initialized": initialized,
"skipped": skipped,
"failed": failed,
}
try:
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time
result["task_id"] = self.request.id
return result
except Exception as e:
return {
"status": "FAILURE",
"error": str(e),
"elapsed_time": time.time() - start_time,
"task_id": self.request.id,
}
# =============================================================================
@celery_app.task(
name="app.tasks.init_interest_distribution_for_users",
bind=True,
ignore_result=True,
max_retries=0,
acks_late=False,
time_limit=3600,
soft_time_limit=3300,
)
def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
"""事件触发任务:检查指定用户列表的兴趣分布缓存,无缓存则生成并写入 Redis。
由 /dashboard/end_users 接口触发,已有缓存的用户直接跳过。
默认生成中文zh兴趣分布数据。
Args:
end_user_ids: 需要检查的用户ID列表
Returns:
包含任务执行结果的字典
"""
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
from app.services.memory_agent_service import MemoryAgentService
logger = get_logger(__name__)
logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")
initialized = 0
failed = 0
skipped = 0
language = "zh"
service = MemoryAgentService()
with get_db_context() as db:
for end_user_id in end_user_ids:
# 存在性检查:缓存有数据则跳过
cached = await InterestMemoryCache.get_interest_distribution(
end_user_id=end_user_id,
language=language,
)
if cached is not None:
skipped += 1
continue
logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成")
try:
result = await service.get_interest_distribution_by_user(
end_user_id=end_user_id,
limit=5,
language=language,
)
await InterestMemoryCache.set_interest_distribution(
end_user_id=end_user_id,
language=language,
data=result,
expire=INTEREST_CACHE_EXPIRE,
)
initialized += 1
logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功")
except Exception as e:
failed += 1
logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}")
logger.info(f"兴趣分布按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
return {
"status": "SUCCESS",
"initialized": initialized,
"skipped": skipped,
"failed": failed,
}
try:
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time
result["task_id"] = self.request.id
return result
except Exception as e:
return {
"status": "FAILURE",
"error": str(e),
"elapsed_time": time.time() - start_time,
"task_id": self.request.id,
}

View File

@@ -1,4 +1,36 @@
{ {
"v0.2.6": {
"introduction": {
"codeName": "听剑",
"releaseDate": "2026-3-6",
"upgradePosition": "🐻 多模态交互全面升级,记忆剪枝与工作流迁移双线并进,锋芒初露,兼收并蓄",
"coreUpgrades": [
"1. 工作流与应用框架<br>* 工作流导入适配Dify支持 Dify 工作流定义无缝迁移<br>* 字段字数限制与校验规则:可配置字符限制与产品级校验<br>* 应用复制Agent、工作流、集群一键复制完整应用配置<br>* 对话变量(调试+分享):支持有状态多轮交互<br>* Chat 接口流式输出 message_id流式响应包含消息追踪标识",
"2. 多模态与交互 💬<br>* 音频输入与输出:应用支持音频模态<br>* 文件类型输入支持:扩展支持语音、文件、视频上传",
"3. 模型与智能 🧠<br>* 模型视觉与 Omni 区分:精确区分视觉与 Omni 模型能力<br>* 教育记忆与陪伴玩具场景预设:垂直领域本体配置开箱即用<br>* 本体配置默认标识:支持基线配置标记<br>* 记忆配置默认标识:自动应用默认记忆设置",
"4. 记忆智能 🔬<br>* 记忆剪枝模块:智能裁剪冗余低价值记忆<br>* RAG 快速检索集成记忆:深度思考与正常回复双模式检索",
"5. 稳健性与缺陷修复 🔧<br>* 模型管理:修复自定义模型 API Key 批量配置错误<br>* 知识库管理:修复非源文档下载原始内容接口错误,更新分享停用提示文案<br>* 用户记忆:优化档案提取准确性(姓名、职业、兴趣分布)<br>* 长期记忆:修复情景记忆卡片重复和用户归属错误<br>* 工作空间首页修复知识库数量、应用数量、总记忆容量、API 调用次数、知识库类型分布等数据不一致问题<br>* 基础设施:修正 Celery 环境变量配置,修复数据库连接池 idle-in-transaction 泄漏",
"<br>",
"v0.2.6 标志着 MemoryBear 在多模态交互、跨平台工作流迁移和智能记忆管理方面的重要突破。下一版本将聚焦 A2A 协议支持实现多智能体协作、多模态记忆能力扩展至语音与视觉领域,以及应用导入导出功能支持跨环境便携部署。",
"MemoryBear让记忆有熊力 🐻✨"
]
},
"introduction_en": {
"codeName": "TingJian",
"releaseDate": "2026-3-6",
"upgradePosition": "🐻 Full multimodal interaction upgrade with memory pruning and workflow migration — sharpened edge, broader reach",
"coreUpgrades": [
"1. Workflow & Application Framework<br>* Workflow Import Adaptation (Dify): Seamless Dify workflow migration<br>* Field Character Limits & Validation: Configurable limits with product-defined rules<br>* Application Cloning (Agent, Workflow, Cluster): One-click full config duplication<br>* Conversation Variables (Debug + Share): Stateful multi-turn interactions<br>* Streaming message_id in Chat API: Message tracking in streaming responses",
"2. Multimodal & Interaction 💬<br>* Audio Input & Output: Audio modality support for applications<br>* File Type Input Support: Voice, file, and video upload support",
"3. Model & Intelligence 🧠<br>* Model Vision & Omni Differentiation: Precise capability routing<br>* Education Memory & Companion Toy Presets: Domain-specific ontology configs<br>* Ontology Default Identifier: Baseline configuration flagging<br>* Memory Configuration Default Identifier: Auto-apply default settings",
"4. Memory Intelligence 🔬<br>* Memory Pruning Module: Intelligent trimming of redundant memories<br>* RAG Quick Retrieval with Memory: Deep think and normal reply dual-mode retrieval",
"5. Robustness & Bug Fixes 🔧<br>* Model Management: Fixed custom model API key batch configuration error<br>* Knowledge Base: Fixed download original content API error for non-source documents, updated share disable prompt text<br>* User Memory: Improved profile extraction accuracy (name, occupation, interests)<br>* Long-Term Memory: Fixed duplicate episodic memory cards and wrong user attribution<br>* Dashboard: Fixed data inconsistencies in knowledge count, app count, memory capacity, API calls, and knowledge type distribution<br>* Infrastructure: Corrected Celery environment variables, fixed database connection pool idle-in-transaction leak",
"<br>",
"v0.2.6 marks a significant milestone for MemoryBear in multimodal interaction, cross-platform workflow migration, and intelligent memory management. The next release will focus on A2A protocol support for multi-agent collaboration, multimodal memory extending extraction to voice and visual domains, and application import/export for portable cross-environment deployment.",
"MemoryBear, Memory with Bear Power 🐻✨"
]
}
},
"v0.2.5": { "v0.2.5": {
"introduction": { "introduction": {
"codeName": "行云", "codeName": "行云",

View File

@@ -49,7 +49,7 @@ services:
networks: networks:
- celery - celery
# Periodic worker - Scheduled/beat tasks (prefork, low concurrency) # Periodic worker - Scheduled/beat tasks + API-triggered tasks (prefork, low concurrency)
worker-periodic: worker-periodic:
image: redbear-mem-open:latest image: redbear-mem-open:latest
container_name: worker-periodic container_name: worker-periodic

View File

@@ -29,10 +29,10 @@ REDIS_DB=
REDIS_PASSWORD=password REDIS_PASSWORD=password
#celery #celery
BROKER_URL= # NOTE: 不要使用 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND
RESULT_BACKEND= # 这些名称会被 Celery CLI 劫持,详见 docs/celery-env-bug-report.md
CELERY_BROKER= REDIS_DB_CELERY_BROKER=
CELERY_BACKEND= REDIS_DB_CELERY_BACKEND=
# Memory Cache Regeneration Configuration # Memory Cache Regeneration Configuration
# Interval in hours for regenerating memory insight and user summary cache # Interval in hours for regenerating memory insight and user summary cache

View File

@@ -0,0 +1,36 @@
"""202603061644
Revision ID: 1ac07dc7366f
Revises: 6a4641cf192b
Create Date: 2026-03-06 16:51:10.152305
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '1ac07dc7366f'
down_revision: Union[str, None] = '6a4641cf192b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('mcp_tool_configs', sa.Column('source_channel', sa.String(length=50), server_default=sa.text("'self_hosted'"), nullable=False, comment='来源渠道'))
op.add_column('mcp_tool_configs', sa.Column('market_id', sa.UUID(), nullable=True, comment='渠道市场id'))
op.add_column('mcp_tool_configs', sa.Column('market_config_id', sa.UUID(), nullable=True, comment='渠道市场配置id'))
op.add_column('mcp_tool_configs', sa.Column('mcp_service_id', sa.String(length=255), nullable=True, comment='mcp服务id'))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('mcp_tool_configs', 'mcp_service_id')
op.drop_column('mcp_tool_configs', 'market_config_id')
op.drop_column('mcp_tool_configs', 'market_id')
op.drop_column('mcp_tool_configs', 'source_channel')
# ### end Alembic commands ###

View File

@@ -1,5 +1,5 @@
import { request } from '@/utils/request' import { request } from '@/utils/request'
import type { Query, CustomToolItem, ExecuteData, MCPToolItem, InnerToolItem } from '@/views/ToolManagement/types' import type { Query, MarketQuery, CustomToolItem, ExecuteData, MCPToolItem, InnerToolItem } from '@/views/ToolManagement/types'
// 工具列表 // 工具列表
export const getTools = (data: Query) => { export const getTools = (data: Query) => {
@@ -33,4 +33,44 @@ export const getToolDetail = (tool_id: string) => {
} }
export const getToolMethods = (tool_id: string) => { export const getToolMethods = (tool_id: string) => {
return request.get(`/tools/${tool_id}/methods`) return request.get(`/tools/${tool_id}/methods`)
}
// MCP市场列表
export const getMarketTools = (data: Query) => {
return request.get('/mcp_markets/mcp_markets', data)
}
// 市场配置创建
export const createMarketConfig = (values: {
mcp_market_id: string;
token: string;
status: number;
}) => {
return request.post('/mcp_market_configs/mcp_market_config', values)
}
// 市场配置更新
export const updateMarketConfig = (values: {
mcp_market_config_id: string;
token: string;
status: number;
}) => {
return request.put(`/mcp_market_configs/${values.mcp_market_config_id}`, values)
}
// 市场根据id获取配置
export const getMarketConfig = (mcp_market_id: string) => {
return request.get(`/mcp_market_configs/mcp_market_id/${mcp_market_id}`)
}
// 市场MCP列表
export const getMarketMCPs = (data: MarketQuery) => {
return request.get('/mcp_market_configs/mcp_servers', data)
}
// 根据配置ID serverId 获取MCP服务详情
export const getMarketMCPDetail = (data:{
mcp_market_config_id: string;
server_id: string;
}) => {
return request.get(`/mcp_market_configs/mcp_server`,data)
}
// 市场已激活MCP列表
export const getMarketMCPsActivated = (data: MarketQuery) => {
return request.get('/mcp_market_configs/operational_mcp_servers', data)
} }

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2025-12-10 16:46:14 * @Date: 2025-12-10 16:46:14
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-04 18:42:49 * @Last Modified time: 2026-03-06 13:36:20
*/ */
import { type FC, useEffect, useMemo } from 'react' import { type FC, useEffect, useMemo } from 'react'
import { Flex, Input, Form } from 'antd' import { Flex, Input, Form } from 'antd'
@@ -50,13 +50,17 @@ const ChatInput: FC<ChatInputProps> = ({
const handleDelete = (file: any) => { const handleDelete = (file: any) => {
fileChange?.(fileList?.filter(item => item.uid !== file.uid) || []) fileChange?.(fileList?.filter(item => {
return item.thumbUrl && file.thumbUrl ? item.thumbUrl !== file.thumbUrl
: item.url && file.url ? item.url !== file.url
: item.uid !== file.uid
}) || [])
} }
// Convert file object to preview URL // Convert file object to preview URL
const previewFileList = useMemo(() => { const previewFileList = useMemo(() => {
return fileList?.map(file => ({ return fileList?.map(file => ({
...file, ...file,
url: file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : file.thumbUrl) url: file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
})) || [] })) || []
}, [fileList]) }, [fileList])
@@ -72,7 +76,7 @@ const ChatInput: FC<ChatInputProps> = ({
{previewFileList.map((file) => { {previewFileList.map((file) => {
if (file.type.includes('image')) { if (file.type.includes('image')) {
return ( return (
<div key={file.uid} className="rb:inline-block rb:group rb:relative rb:rounded-lg"> <div key={file.url || file.uid} className="rb:inline-block rb:group rb:relative rb:rounded-lg">
<img src={file.url} alt={file.name} className="rb:size-12! rb:rounded-lg rb:object-cover rb:cursor-pointer" /> <img src={file.url} alt={file.name} className="rb:size-12! rb:rounded-lg rb:object-cover rb:cursor-pointer" />
<div <div
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]" className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
@@ -83,7 +87,7 @@ const ChatInput: FC<ChatInputProps> = ({
} }
if (file.type.includes('video')) { if (file.type.includes('video')) {
return ( return (
<div key={file.uid} className="rb:w-45 rb:h-16 rb:inline-block rb:group rb:relative rb:rounded-lg"> <div key={file.url || file.uid} className="rb:w-45 rb:h-16 rb:inline-block rb:group rb:relative rb:rounded-lg">
<video src={file.url} controls className="rb:w-45 rb:h-16 rb:rounded-lg rb:object-cover rb:cursor-pointer" /> <video src={file.url} controls className="rb:w-45 rb:h-16 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
<div <div
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]" className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
@@ -94,7 +98,7 @@ const ChatInput: FC<ChatInputProps> = ({
} }
if (file.type.includes('audio')) { if (file.type.includes('audio')) {
return ( return (
<div key={file.uid} className="rb:w-45 rb:h-16 rb:inline-flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5 rb:gap-2"> <div key={file.url || file.uid} className="rb:w-45 rb:h-16 rb:inline-flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5 rb:gap-2">
<audio src={file.url} controls className="rb:w-45 rb:h-16" /> <audio src={file.url} controls className="rb:w-45 rb:h-16" />
<div <div
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]" className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
@@ -104,7 +108,7 @@ const ChatInput: FC<ChatInputProps> = ({
) )
} }
return ( return (
<div key={file.uid} className="rb:w-45 rb:text-[12px] rb:gap-2.5 rb:flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5"> <div key={file.url || file.uid} className="rb:w-45 rb:text-[12px] rb:gap-2.5 rb:flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5">
{(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div {(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/word.svg')]" className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/word.svg')]"
></div>} ></div>}

View File

@@ -440,7 +440,6 @@ export const en = {
logoutApiCannotRefreshToken: 'Logout API cannot refresh token', logoutApiCannotRefreshToken: 'Logout API cannot refresh token',
publicApiCannotRefreshToken: 'Public API cannot refresh token', publicApiCannotRefreshToken: 'Public API cannot refresh token',
refreshTokenNotExist: 'Refresh token does not exist', refreshTokenNotExist: 'Refresh token does not exist',
SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: 'This is a system preset scene and cannot be deleted',
reset: 'Reset', reset: 'Reset',
refresh: 'Refresh', refresh: 'Refresh',
return: 'Return', return: 'Return',
@@ -1362,6 +1361,7 @@ export const en = {
complex: 'Compatibility Analysis', complex: 'Compatibility Analysis',
sureInfo: 'Information Confirmation', sureInfo: 'Information Confirmation',
completed: 'Import Completed', completed: 'Import Completed',
baseInfo: 'Basic Information',
workflowName: 'Workflow Name', workflowName: 'Workflow Name',
fileName: 'File Name', fileName: 'File Name',
fileSize: 'File Size', fileSize: 'File Size',
@@ -1573,7 +1573,7 @@ export const en = {
intelligentSemanticPruningFunction: 'Intelligent Semantic Pruning Function', intelligentSemanticPruningFunction: 'Intelligent Semantic Pruning Function',
intelligentSemanticPruningFunctionDesc: 'Whether to activate intelligent semantic pruning (true/false).', intelligentSemanticPruningFunctionDesc: 'Whether to activate intelligent semantic pruning (true/false).',
intelligentSemanticPruningScene: 'Intelligent Semantic Pruning Scene', intelligentSemanticPruningScene: 'Intelligent Semantic Pruning Scene',
intelligentSemanticPruningSceneDesc: 'Select intelligent semantic pruning scene (education, online_service, outbound).', intelligentSemanticPruningSceneDesc: 'Semantic pruning scenarios are consistent with ontology engineering scenarios',
intelligentSemanticPruningThreshold: 'Intelligent Semantic Pruning Threshold', intelligentSemanticPruningThreshold: 'Intelligent Semantic Pruning Threshold',
intelligentSemanticPruningThresholdDesc: 'Set intelligent semantic pruning threshold (0-0.9).', intelligentSemanticPruningThresholdDesc: 'Set intelligent semantic pruning threshold (0-0.9).',
reflectionEngine: 'Self-Reflexion Engine', reflectionEngine: 'Self-Reflexion Engine',
@@ -1807,6 +1807,25 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
error_desc: 'API is configured but connection error', error_desc: 'API is configured but connection error',
testConnectionSuccess: 'Test Connection Successful', testConnectionSuccess: 'Test Connection Successful',
refreshSuccess: 'Refresh Successful',
refreshFailed: 'Refresh Failed',
// Market related
marketSelectTitle: 'Select an MCP Market',
marketSelectDesc: 'Choose a market source from the left, configure the connection to browse MCP services',
marketRefreshSuccess: 'List refreshed',
marketActivated: 'Activated',
marketInDatabase: 'In Database',
marketAdd: 'Add',
marketRefresh: 'Refresh',
marketConfig: 'Configure',
marketConfigConnection: 'Configure Connection',
marketNoServices: 'No MCP Services Available',
marketNotConnected: 'Not Connected to This Market',
marketNoServicesDesc: 'This market currently has no available services',
marketNotConnectedDesc: 'Click the "Configure" button in the upper right corner to set connection information',
marketSearchPlaceholder: 'Search services...',
marketVisit: 'Visit Market',
serviceEndpoint: 'Service Endpoint URL', serviceEndpoint: 'Service Endpoint URL',
serviceEndpointPlaceholder: 'URL of the service endpoint', serviceEndpointPlaceholder: 'URL of the service endpoint',
serviceEndpointExtra: 'Complete access address of the MCP service', serviceEndpointExtra: 'Complete access address of the MCP service',
@@ -1960,6 +1979,19 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
viewDetail: 'View Details', viewDetail: 'View Details',
textLink: 'Test Connection', textLink: 'Test Connection',
noResult: 'Processing results will be displayed here', noResult: 'Processing results will be displayed here',
marketConfig: 'Configure {{name}}',
marketSaveAndConnect: 'Save & Connect',
marketUrl: 'Market URL',
marketUrlPlaceholder: 'Market URL',
marketCopy: 'Copy',
marketApiKeyOptional: 'Optional',
marketApiKeyExtra: 'Some markets require an API Key to access the full service list',
marketApiKeyPlaceholder: 'Enter API Key to access more services',
marketConnectionStatus: 'Connection Status',
marketConnected: '● Connected',
marketDisconnected: '○ Disconnected',
marketConnecting: 'Connecting to {{name}}...',
serverUrlInvalid: 'Must start with http:// or https://, and cannot have leading or trailing spaces', serverUrlInvalid: 'Must start with http:// or https://, and cannot have leading or trailing spaces',
requestHeaderKeyInvalid: 'Only English letters, numbers, hyphens (-), and underscores (_) are allowed, and cannot start or end with a hyphen or underscore', requestHeaderKeyInvalid: 'Only English letters, numbers, hyphens (-), and underscores (_) are allowed, and cannot start or end with a hyphen or underscore',
}, },
@@ -2008,6 +2040,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
self_optimization: 'Self Optimization', self_optimization: 'Self Optimization',
process_evolution: 'Process Evolution', process_evolution: 'Process Evolution',
unknown: 'Unknown Node', unknown: 'Unknown Node',
notes: 'Sticky Note',
clickToConfigure: 'Click to configure node parameters', clickToConfigure: 'Click to configure node parameters',
nodeProperties: 'Node Properties', nodeProperties: 'Node Properties',
@@ -2195,6 +2228,12 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
output_variables: 'Output Variables', output_variables: 'Output Variables',
refreshTip: 'Sync function signature to code', refreshTip: 'Sync function signature to code',
}, },
notes: {
showAuth: 'Show Author',
enterLink: 'Enter Link URL',
placeholder: 'Enter note...',
removeLink: 'Remove Link',
},
name: 'Key', name: 'Key',
type: 'Type', type: 'Type',
value: 'Value', value: 'Value',
@@ -2617,6 +2656,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
updated_at: 'Updated At', updated_at: 'Updated At',
entityTypes: 'Entity Types', entityTypes: 'Entity Types',
classSearchPlaceholder: 'Search types',
addClass: 'Add Type', addClass: 'Add Type',
class_name: 'Type Name', class_name: 'Type Name',
class_description: 'Type Definition', class_description: 'Type Definition',

View File

@@ -96,7 +96,7 @@ export const zh = {
createMemorySummary: '创建记忆摘要', createMemorySummary: '创建记忆摘要',
memoryManagement: '记忆管理', memoryManagement: '记忆管理',
spaceManagement: '空间管理', spaceManagement: '空间管理',
memoryExtractionEngine: '记忆取引擎', memoryExtractionEngine: '记忆取引擎',
forgettingEngine: '遗忘引擎', forgettingEngine: '遗忘引擎',
apiKeyManagement: 'API KEY管理', apiKeyManagement: 'API KEY管理',
knowledgePrivate: '详情', knowledgePrivate: '详情',
@@ -1020,7 +1020,6 @@ export const zh = {
logoutApiCannotRefreshToken: '退出登录接口不能刷新token', logoutApiCannotRefreshToken: '退出登录接口不能刷新token',
publicApiCannotRefreshToken: '公共接口不能刷新token', publicApiCannotRefreshToken: '公共接口不能刷新token',
refreshTokenNotExist: '刷新token不存在', refreshTokenNotExist: '刷新token不存在',
SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: '该场景为系统预设场景,不允许删除',
reset: '重置', reset: '重置',
refresh: '刷新', refresh: '刷新',
return: '返回', return: '返回',
@@ -1284,7 +1283,7 @@ export const zh = {
createConfiguration: '创建配置', createConfiguration: '创建配置',
editConfiguration: '编辑配置', editConfiguration: '编辑配置',
desc: '描述', desc: '描述',
memoryExtractionEngine: '记忆取引擎', memoryExtractionEngine: '记忆取引擎',
forgottenEngine: '遗忘引擎', forgottenEngine: '遗忘引擎',
active: '活跃', active: '活跃',
inactive: '不活跃', inactive: '不活跃',
@@ -1572,7 +1571,7 @@ export const zh = {
intelligentSemanticPruningFunction: '智能语义修剪功能', intelligentSemanticPruningFunction: '智能语义修剪功能',
intelligentSemanticPruningFunctionDesc: '是否激活智能语义修剪true/false。', intelligentSemanticPruningFunctionDesc: '是否激活智能语义修剪true/false。',
intelligentSemanticPruningScene: '智能语义修剪场景', intelligentSemanticPruningScene: '智能语义修剪场景',
intelligentSemanticPruningSceneDesc: '选择智能语义修剪场景education、online_service、outbound', intelligentSemanticPruningSceneDesc: '语义剪枝场景与本体工程场景一致',
intelligentSemanticPruningThreshold: '智能语义修剪阈值', intelligentSemanticPruningThreshold: '智能语义修剪阈值',
intelligentSemanticPruningThresholdDesc: '设置智能语义修剪阈值0-0.9)。', intelligentSemanticPruningThresholdDesc: '设置智能语义修剪阈值0-0.9)。',
reflectionEngine: '自我反思引擎', reflectionEngine: '自我反思引擎',
@@ -1804,6 +1803,25 @@ export const zh = {
error_desc: 'API 已配置但链接异常', error_desc: 'API 已配置但链接异常',
testConnectionSuccess: '测试连接成功', testConnectionSuccess: '测试连接成功',
refreshSuccess: '刷新成功',
refreshFailed: '刷新失败',
// Market 相关
marketSelectTitle: '选择一个 MCP 市场',
marketSelectDesc: '从左侧选择一个市场源,配置连接后即可浏览该市场的 MCP 服务',
marketRefreshSuccess: '列表已刷新',
marketActivated: '已激活',
marketInDatabase: '已入库',
marketAdd: '添加',
marketRefresh: '刷新',
marketConfig: '配置',
marketConfigConnection: '配置连接',
marketNoServices: '暂无可用的 MCP 服务',
marketNotConnected: '尚未连接此市场',
marketNoServicesDesc: '该市场暂时没有可用的服务',
marketNotConnectedDesc: '点击右上角"配置"按钮设置连接信息',
marketSearchPlaceholder: '搜索服务...',
marketVisit: '前往市场',
serviceEndpoint: '服务端点 URL', serviceEndpoint: '服务端点 URL',
serviceEndpointPlaceholder: '服务端点的 URL', serviceEndpointPlaceholder: '服务端点的 URL',
serviceEndpointExtra: 'MCP服务的完整访问地址', serviceEndpointExtra: 'MCP服务的完整访问地址',
@@ -1957,6 +1975,19 @@ export const zh = {
viewDetail: '查看详情', viewDetail: '查看详情',
textLink: '测试连接', textLink: '测试连接',
noResult: '处理结果将显示在这里', noResult: '处理结果将显示在这里',
marketConfig: '配置 {{name}}',
marketSaveAndConnect: '保存并连接',
marketUrl: '市场地址',
marketUrlPlaceholder: '市场地址',
marketCopy: '复制',
marketApiKeyOptional: '可选',
marketApiKeyExtra: '部分市场需要 API Key 才能获取完整的服务列表',
marketApiKeyPlaceholder: '输入 API Key 以获取更多服务',
marketConnectionStatus: '连接状态',
marketConnected: '● 已连接',
marketDisconnected: '○ 未连接',
marketConnecting: '正在连接 {{name}}...',
serverUrlInvalid: '必须以 http:// 或 https:// 开头,且不能有前后空格', serverUrlInvalid: '必须以 http:// 或 https:// 开头,且不能有前后空格',
requestHeaderKeyInvalid: '只支持英文、数字、连字符(-)、下划线(_),不能以连字符或下划线开头结尾', requestHeaderKeyInvalid: '只支持英文、数字、连字符(-)、下划线(_),不能以连字符或下划线开头结尾',
}, },
@@ -2005,6 +2036,7 @@ export const zh = {
self_optimization: '自我优化', self_optimization: '自我优化',
process_evolution: '流程演化', process_evolution: '流程演化',
unknown: '未知节点', unknown: '未知节点',
notes: '便签',
clickToConfigure: '点击配置节点参数', clickToConfigure: '点击配置节点参数',
nodeProperties: '节点属性', nodeProperties: '节点属性',
@@ -2195,6 +2227,12 @@ export const zh = {
unknown: { unknown: {
replaceNodeType: '替换节点' replaceNodeType: '替换节点'
}, },
notes: {
showAuth: '显示作者',
enterLink: '输入链接 URL',
placeholder: '输入注释...',
removeLink: '取消链接',
},
name: '键', name: '键',
type: '类型', type: '类型',
value: '值', value: '值',
@@ -2618,6 +2656,7 @@ export const zh = {
updated_at: '更新时间', updated_at: '更新时间',
entityTypes: '实体类型', entityTypes: '实体类型',
classSearchPlaceholder: '搜索类型',
addClass: '添加类型', addClass: '添加类型',
class_name: '类型名称', class_name: '类型名称',
class_description: '类型定义', class_description: '类型定义',

View File

@@ -1,8 +1,8 @@
/* /*
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-02 16:35:15 * @Date: 2026-02-02 16:35:15
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-02 16:35:15 * @Last Modified time: 2026-03-06 10:39:00
*/ */
/** /**
* HTTP Request Utility Module * HTTP Request Utility Module
@@ -183,7 +183,7 @@ service.interceptors.response.use(
msg = msg || i18n.t('common.serverError'); msg = msg || i18n.t('common.serverError');
break; break;
default: default:
if (msg === 'SYSTEM_DEFAULT_SCENE_CANNOT_DELETE') { if (['SYSTEM_DEFAULT_SCENE_CANNOT_DELETE', 'SYSTEM_DEFAULT_CLASS_CANNOT_DELETE', 'SYSTEM_DEFAULT_SCENE_CANNOT_UPDATE'].includes(msg)) {
msg = i18n.t(`common.${msg}`) msg = i18n.t(`common.${msg}`)
} else if (!msg && Array.isArray(error.response?.data?.detail)) { } else if (!msg && Array.isArray(error.response?.data?.detail)) {
msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';') msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';')
@@ -356,12 +356,11 @@ export const request = {
* Get parent domain for cookie setting * Get parent domain for cookie setting
* @returns Parent domain or IP address * @returns Parent domain or IP address
*/ */
const isIp = (hostname: string) => /^\d+\.\d+\.\d+\.\d+$/.test(hostname)
const getParentDomain = () => { const getParentDomain = () => {
const hostname = window.location.hostname const hostname = window.location.hostname
// Check if it's an IP address if (isIp(hostname)) return hostname
if (/^\d+\.\d+\.\d+\.\d+$/.test(hostname)) {
return hostname
}
const parts = hostname.split('.') const parts = hostname.split('.')
return parts.length > 2 ? `.${parts.slice(-2).join('.')}` : hostname return parts.length > 2 ? `.${parts.slice(-2).join('.')}` : hostname
} }
@@ -371,7 +370,10 @@ const getParentDomain = () => {
*/ */
export const cookieUtils = { export const cookieUtils = {
set: (name: string, value: string, domain = getParentDomain()) => { set: (name: string, value: string, domain = getParentDomain()) => {
document.cookie = `${name}=${value}; domain=${domain}; path=/; secure; samesite=strict` const ip = isIp(window.location.hostname)
const domainPart = ip ? '' : `; domain=${domain}`
const securePart = window.location.protocol === 'https:' ? '; secure' : ''
document.cookie = `${name}=${value}${domainPart}; path=/${securePart}; samesite=strict`
}, },
get: (name: string) => { get: (name: string) => {
const value = `; ${document.cookie}` const value = `; ${document.cookie}`

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 16:27:39 * @Date: 2026-02-03 16:27:39
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-04 18:51:20 * @Last Modified time: 2026-03-05 17:03:46
*/ */
/** /**
* Chat debugging component for application testing * Chat debugging component for application testing
@@ -171,6 +171,29 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
.then(() => { .then(() => {
const message = msg const message = msg
if (!message?.trim()) return if (!message?.trim()) return
// Validate required variables before sending
let isCanSend = true
const params: Record<string, any> = {}
if (chatVariables && chatVariables.length > 0) {
const needRequired: string[] = []
chatVariables.forEach(vo => {
params[vo.name] = vo.value
if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
isCanSend = false
needRequired.push(vo.name)
}
})
if (needRequired.length) {
messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
}
}
if (!isCanSend) {
setLoading(false)
setCompareLoading(false)
return
}
addUserMessage(message, fileList) addUserMessage(message, fileList)
setMessage(message) setMessage(message)
@@ -198,29 +221,6 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
}; };
setTimeout(() => { setTimeout(() => {
// Validate required variables before sending
let isCanSend = true
const params: Record<string, any> = {}
if (chatVariables && chatVariables.length > 0) {
const needRequired: string[] = []
chatVariables.forEach(vo => {
params[vo.name] = vo.value
if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
isCanSend = false
needRequired.push(vo.name)
}
})
if (needRequired.length) {
messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
}
}
if (!isCanSend) {
setLoading(false)
setCompareLoading(false)
return
}
runCompare(data.app_id, { runCompare(data.app_id, {
message, message,
files: fileList.map(file => { files: fileList.map(file => {

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-28 14:08:14 * @Date: 2026-02-28 14:08:14
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-02 17:39:49 * @Last Modified time: 2026-03-06 12:05:46
*/ */
/** /**
* UploadWorkflowModal Component * UploadWorkflowModal Component
@@ -101,6 +101,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
formData.append('platform', values.platform); formData.append('platform', values.platform);
formData.append('file', values.file[0]); formData.append('file', values.file[0]);
setLoading(true)
// Call import workflow API // Call import workflow API
importWorkflow(formData) importWorkflow(formData)
.then(res => { .then(res => {
@@ -114,21 +115,24 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
} else { } else {
setCurrent(2); setCurrent(2);
// Pre-fill form with file information // Pre-fill form with file information
const fileNameSplit = values.file[0].name.split('.')
form.setFieldsValue({ form.setFieldsValue({
name: values.file[0].name.split('.')[0], name: fileNameSplit.slice(0, fileNameSplit.length - 1).join('.'),
platform: values.platform, platform: values.platform,
fileName: values.file[0].name, fileName: values.file[0].name,
fileSize: values.file[0].size, fileSize: values.file[0].size,
}); });
} }
}); })
.finally(() => setLoading(false));
break; break;
case 1: // Step 2: Error/warning display case 1: // Step 2: Error/warning display
if (firstFormData) { if (firstFormData) {
const { file, platform } = firstFormData; const { file, platform } = firstFormData;
const fileNameSplit = firstFormData.file[0].name.split('.')
// Pre-fill form with file information // Pre-fill form with file information
form.setFieldsValue({ form.setFieldsValue({
name: file[0].name.split('.')[0], name: fileNameSplit.slice(0, fileNameSplit.length - 1).join('.'),
platform: platform, platform: platform,
fileName: file[0].name, fileName: file[0].name,
fileSize: file[0].size, fileSize: file[0].size,
@@ -138,6 +142,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
break; break;
case 2: // Step 3: Confirm information case 2: // Step 3: Confirm information
if (data) { if (data) {
setLoading(true);
// Complete import workflow // Complete import workflow
completeImportWorkflow({ completeImportWorkflow({
temp_id: data.temp_id, temp_id: data.temp_id,
@@ -148,7 +153,8 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
const response = res as { id: string }; const response = res as { id: string };
setCurrent(3); setCurrent(3);
setAppId(response.id); setAppId(response.id);
}); })
.finally(() => setLoading(false));
} }
break; break;
default: default:
@@ -175,7 +181,9 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
} }
// Reset form if not going back to error/warning step // Reset form if not going back to error/warning step
if (newStep !== 1) { if (newStep === 0) {
form.setFieldsValue(firstFormData || {})
} else if (newStep !== 1) {
form.resetFields(); form.resetFields();
} }
setCurrent(newStep); setCurrent(newStep);
@@ -186,14 +194,16 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
* @param {string} type - Navigation type ('detail' or 'list') * @param {string} type - Navigation type ('detail' or 'list')
*/ */
const handleJump = (type: string) => { const handleJump = (type: string) => {
switch(type) {
case 'detail':
// Open application detail page in new tab
window.open(`/#/application/config/${appId}`, '_blank');
break;
}
refresh();
handleClose(); handleClose();
refresh();
setTimeout(() => {
switch (type) {
case 'detail':
// Open application detail page in new tab
window.open(`/#/application/config/${appId}`, '_blank');
break;
}
}, 100)
}; };
/** /**
@@ -235,7 +245,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
</Button> </Button>
]; ];
} }
}, [current]); }, [current, loading]);
return ( return (
<RbModal <RbModal
@@ -350,7 +360,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
title={t('application.importSuccess')} title={t('application.importSuccess')}
subTitle={t('application.importSuccessDesc')} subTitle={t('application.importSuccessDesc')}
extra={[ extra={[
<Button key="back" onClick={() => handleJump('list')}> <Button key="back" onClick={() => handleJump('list')}>
{t('application.gotoList')} {t('application.gotoList')}
</Button>, </Button>,
<Button <Button

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-06 21:09:42 * @Date: 2026-02-06 21:09:42
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-05 15:09:22 * @Last Modified time: 2026-03-06 12:20:43
*/ */
/** /**
* File Upload Component * File Upload Component
@@ -208,6 +208,7 @@ const UploadFiles = forwardRef<UploadFilesRef, UploadFilesProps>(({
newFileList.map(file => { newFileList.map(file => {
const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document' const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document'
file.type = type file.type = type
file.thumbUrl = file.thumbUrl || URL.createObjectURL(file.originFileObj as Blob)
}) })
setFileList(newFileList); setFileList(newFileList);
if (onChange) { if (onChange) {

View File

@@ -82,6 +82,7 @@ const CreateDataset = () => {
const [form] = Form.useForm<ContentFormData>(); const [form] = Form.useForm<ContentFormData>();
const [data, setData] = useState<KnowledgeBaseDocumentData[]>([]); const [data, setData] = useState<KnowledgeBaseDocumentData[]>([]);
const [rechunkFileIds, setRechunkFileIds] = useState<string[]>(initialFileIds); const [rechunkFileIds, setRechunkFileIds] = useState<string[]>(initialFileIds);
const [textFormValid, setTextFormValid] = useState<boolean>(false);
const [pollingLoading, setPollingLoading] = useState<boolean>(false); const [pollingLoading, setPollingLoading] = useState<boolean>(false);
const pollingTimerRef = useRef<ReturnType<typeof setInterval> | null>(null); const pollingTimerRef = useRef<ReturnType<typeof setInterval> | null>(null);
@@ -624,7 +625,16 @@ const CreateDataset = () => {
)} )}
{source && source === 'text' && ( {source && source === 'text' && (
<div className='rb:flex rb:w-full rb:flex-col rb:mt-10 rb:px-40'> <div className='rb:flex rb:w-full rb:flex-col rb:mt-10 rb:px-40'>
<Form form={form} layout="vertical"> <Form
form={form}
layout="vertical"
onValuesChange={() => {
// 检查表单字段是否都已填写
const values = form.getFieldsValue();
const isValid = !!(values.title?.trim() && values.content?.trim());
setTextFormValid(isValid);
}}
>
<Form.Item <Form.Item
name="title" name="title"
label={t('knowledgeBase.title')} label={t('knowledgeBase.title')}
@@ -845,7 +855,11 @@ const CreateDataset = () => {
<Button <Button
type='primary' type='primary'
onClick={current === 2 ? handleStartUpload : handleNext} onClick={current === 2 ? handleStartUpload : handleNext}
disabled={pollingLoading || (current === 0 && rechunkFileIds.length === 0)} disabled={
pollingLoading ||
(current === 0 && source === 'local' && rechunkFileIds.length === 0) ||
(current === 0 && source === 'text' && !textFormValid)
}
> >
{current === 2 ? t('knowledgeBase.startUploading') || 'Start Upload' : t('common.next') || 'Next'} {current === 2 ? t('knowledgeBase.startUploading') || 'Start Upload' : t('common.next') || 'Next'}
</Button> </Button>

View File

@@ -672,9 +672,17 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
{currentType !== 'Folder' && dynamicTypeList.map((tp) => { {currentType !== 'Folder' && dynamicTypeList.map((tp) => {
const fieldKey = typeToFieldKey(tp); const fieldKey = typeToFieldKey(tp);
// When tp is 'llm', merge llm and chat options // When tp is 'llm', merge llm and chat options
const options = tp.toLowerCase() === 'llm' || tp.toLowerCase() === 'image2text' let options = tp.toLowerCase() === 'llm' || tp.toLowerCase() === 'image2text'
? [...(modelOptionsByType['llm'] || []), ...(modelOptionsByType['chat'] || [])] ? [...(modelOptionsByType['llm'] || []), ...(modelOptionsByType['chat'] || [])]
: modelOptionsByType[tp] || []; : modelOptionsByType[tp] || [];
// When tp is 'image2text', filter to only include models with 'vision' capability
if (tp.toLowerCase() === 'image2text') {
options = options.filter((opt: any) => {
const model = models?.items?.find((m: any) => m.id === opt.value);
return model?.capability?.includes('vision');
});
}
return ( return (
<Form.Item <Form.Item
key={tp} key={tp}

View File

@@ -4,7 +4,7 @@
* @Author: yujiangping * @Author: yujiangping
* @Date: 2025-11-10 18:52:55 * @Date: 2025-11-10 18:52:55
* @LastEditors: yujiangping * @LastEditors: yujiangping
* @LastEditTime: 2026-03-03 14:46:08 * @LastEditTime: 2026-03-09 16:39:07
*/ */
import { forwardRef, useImperativeHandle, useState, useRef } from 'react'; import { forwardRef, useImperativeHandle, useState, useRef } from 'react';
import { Switch } from 'antd'; import { Switch } from 'antd';
@@ -58,16 +58,21 @@ const ShareModal = forwardRef<ShareModalRef,ShareModalRefProps>(({ handleShare:
} }
const handleShare = async() => { const handleShare = async() => {
const workspaceIds = spaceList setLoading(true);
.map(item => item.target_kb?.workspace_id) try {
.filter(Boolean) const workspaceIds = spaceList
.join(','); .map(item => item.target_kb?.workspace_id)
.filter(Boolean)
console.log('Workspace IDs:', workspaceIds); .join(',');
shareSpaceModalRef?.current?.handleOpen(kbId,knowledgeBase,workspaceIds);
console.log('Workspace IDs:', workspaceIds);
// Close modal after sharing shareSpaceModalRef?.current?.handleOpen(kbId,knowledgeBase,workspaceIds);
handleClose();
// Close modal after sharing
handleClose();
} finally {
setLoading(false);
}
} }
const handleChange = (checked: boolean, item: any) => { const handleChange = (checked: boolean, item: any) => {
// Toggle shared knowledge base status // Toggle shared knowledge base status

View File

@@ -4,7 +4,7 @@
* @Author: yujiangping * @Author: yujiangping
* @Date: 2025-11-10 18:52:55 * @Date: 2025-11-10 18:52:55
* @LastEditors: yujiangping * @LastEditors: yujiangping
* @LastEditTime: 2025-12-03 18:44:58 * @LastEditTime: 2026-03-09 16:34:51
*/ */
import { forwardRef, useImperativeHandle, useState } from 'react'; import { forwardRef, useImperativeHandle, useState } from 'react';
import { Switch } from 'antd'; import { Switch } from 'antd';
@@ -50,34 +50,38 @@ const ShareModal = forwardRef<ShareModalRef,ShareModalRefProps>(({ handleShare:
setSpaceList(filteredItems as SpaceItem[]); setSpaceList(filteredItems as SpaceItem[]);
} }
const handleShare = async() => { const handleShare = async() => {
// Get all data with checked = true // Get all data with checked = true
const checkedItems = spaceList.filter(item => item.is_active); const checkedItems = spaceList.filter(item => item.is_active);
debugger
// Get currently selected item (corresponding to curIndex) // Get currently selected item (corresponding to curIndex)
const selectedItem = curIndex !== -1 ? spaceList[curIndex] : null; const selectedItem = curIndex !== -1 ? spaceList[curIndex] : null;
if(!selectedItem){ if(!selectedItem){
messageApi.error(t('knowledgeBase.selectSpace')); messageApi.error(t('knowledgeBase.selectSpace'));
return; return;
} }
const payload = {
source_kb_id: kbId ?? '',
target_workspace_id: selectedItem?.id ?? '',
}
const respose = await shareKnowledgeBase(payload)
if(respose){
messageApi.success(t('knowledgeBase.shareSuccess'));
}else{
messageApi.error(t('knowledgeBase.shareFailed'));
}
// Call parent component's callback function with selected data
onShare?.({
checkedItems,
selectedItem
});
// Close modal after sharing setLoading(true);
handleClose(); try {
const payload = {
source_kb_id: kbId ?? '',
target_workspace_id: selectedItem?.id ?? '',
}
const respose = await shareKnowledgeBase(payload)
if(respose){
messageApi.success(t('knowledgeBase.shareSuccess'));
}else{
messageApi.error(t('knowledgeBase.shareFailed'));
}
// Call parent component's callback function with selected data
onShare?.({
checkedItems,
selectedItem
});
// Close modal after sharing
handleClose();
} finally {
setLoading(false);
}
} }
const handleClick = (index: number, checked: boolean) => { const handleClick = (index: number, checked: boolean) => {
if (!checked) return; if (!checked) return;

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 17:30:06 * @Date: 2026-02-03 17:30:06
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-04 10:09:45 * @Last Modified time: 2026-03-06 13:49:00
*/ */
/** /**
* Memory Extraction Engine Configuration Constants * Memory Extraction Engine Configuration Constants
@@ -140,13 +140,8 @@ export const configList: ConfigVo[] = [
{ {
label: 'intelligentSemanticPruningScene', label: 'intelligentSemanticPruningScene',
variableName: 'pruning_scene', variableName: 'pruning_scene',
control: 'select', control: 'text',
type: 'enum', type: 'enum',
options: [
{ label: 'education', value: 'education' },
{ label: 'online_service', value: 'online_service' },
{ label: 'outbound', value: 'outbound' },
],
meaning: 'intelligentSemanticPruningSceneDesc', meaning: 'intelligentSemanticPruningSceneDesc',
}, },
// Intelligent semantic pruning阈值 // Intelligent semantic pruning阈值

View File

@@ -1,8 +1,8 @@
/* /*
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 17:30:02 * @Date: 2026-02-03 17:30:02
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 17:30:02 * @Last Modified time: 2026-03-06 13:50:05
*/ */
/** /**
* Memory Extraction Engine Configuration Page * Memory Extraction Engine Configuration Page
@@ -13,7 +13,7 @@
import { type FC, useState, useEffect } from 'react' import { type FC, useState, useEffect } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useParams } from 'react-router-dom' import { useParams } from 'react-router-dom'
import { Row, Col, Space, Select, InputNumber, Slider, App, Form } from 'antd' import { Row, Col, Space, Select, InputNumber, Slider, App, Form, Input } from 'antd'
import clsx from 'clsx' import clsx from 'clsx'
import Card from './components/Card' import Card from './components/Card'
@@ -35,15 +35,15 @@ const keys = [
/** /**
* Configuration description component * Configuration description component
*/ */
const ConfigDesc: FC<{ config: Variable, className?: string }> = ({config, className}) => { const ConfigDesc: FC<{ config: Variable, className?: string; onlyMeaning?: boolean; }> = ({ config, className, onlyMeaning = false}) => {
const { t } = useTranslation(); const { t } = useTranslation();
return ( return (
<div className={className}> <div className={className}>
<Space size={8} className={clsx("rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 ")}> {!onlyMeaning && <Space size={8} className={clsx("rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 ")}>
{config.variableName && <span className="rb:font-regular">{t('memoryExtractionEngine.variableName')}: {config.variableName}</span>} {config.variableName && <span className="rb:font-regular">{t('memoryExtractionEngine.variableName')}: {config.variableName}</span>}
{config.control && <span className="rb:font-regular">{t('memoryExtractionEngine.control')}: {t(`memoryExtractionEngine.${config.control}`)}</span>} {config.control && <span className="rb:font-regular">{t('memoryExtractionEngine.control')}: {t(`memoryExtractionEngine.${config.control}`)}</span>}
{config.type && <span className="rb:font-regular">{t('memoryExtractionEngine.type')}: {config.type}</span>} {config.type && <span className="rb:font-regular">{t('memoryExtractionEngine.type')}: {config.type}</span>}
</Space> </Space>}
{config.meaning && <div className={clsx("rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 ")}>{t('memoryExtractionEngine.Meaning')}: {t(`memoryExtractionEngine.${config.meaning}`)}</div>} {config.meaning && <div className={clsx("rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 ")}>{t('memoryExtractionEngine.Meaning')}: {t(`memoryExtractionEngine.${config.meaning}`)}</div>}
</div> </div>
) )
@@ -253,6 +253,21 @@ const MemoryExtractionEngine: FC = () => {
</div> </div>
</> </>
} }
{config.control === 'text' &&
<>
<div className="rb:text-[14px] rb:font-medium rb:leading-5 rb:mt-6 rb:mb-2">
-{t(`memoryExtractionEngine.${config.label}`)}
</div>
<div className="rb:pl-2">
<Form.Item
name={config.variableName}
>
<Input placeholder={t('common.pleaseEnter')} disabled />
</Form.Item>
<ConfigDesc config={config} onlyMeaning={true} className="rb:-mt-4!" />
</div>
</>
}
</div> </div>
))} ))}
</div> </div>

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 17:33:15 * @Date: 2026-02-03 17:33:15
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-05 16:28:58 * @Last Modified time: 2026-03-06 13:53:53
*/ */
/** /**
* Memory Management Page * Memory Management Page
@@ -154,10 +154,10 @@ const MemoryManagement: React.FC = () => {
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]" className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]"
onClick={() => handleEdit(item)} onClick={() => handleEdit(item)}
></div> ></div>
<div {!item.is_system_default && <div
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]" className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
onClick={() => handleDelete(item)} onClick={() => handleDelete(item)}
></div> ></div>}
</Space> </Space>
</div> </div>
</RbCard> </RbCard>

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 16:49:45 * @Date: 2026-02-03 16:49:45
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-04 11:50:47 * @Last Modified time: 2026-03-06 12:26:12
*/ */
/** /**
* Model List Detail Drawer * Model List Detail Drawer
@@ -144,7 +144,7 @@ const ModelListDetail = forwardRef<ModelListDetailRef, ModelListDetailProps>(({
{item.name[0]} {item.name[0]}
</div> </div>
} }
extra={<Switch defaultChecked={item.is_active} disabled={loading} onChange={() => handleChange(item)} />} extra={<Switch checked={item.is_active} disabled={loading} onChange={() => handleChange(item)} />}
bodyClassName="rb:relative rb:pb-[64px]! rb:h-[calc(100%-64px)]!" bodyClassName="rb:relative rb:pb-[64px]! rb:h-[calc(100%-64px)]!"
> >
<Tooltip title={item.description}> <Tooltip title={item.description}>
@@ -153,7 +153,7 @@ const ModelListDetail = forwardRef<ModelListDetailRef, ModelListDetailProps>(({
<div className="rb:absolute rb:bottom-4 rb:left-6 rb:right-6"> <div className="rb:absolute rb:bottom-4 rb:left-6 rb:right-6">
<Row gutter={12}> <Row gutter={12}>
<Col span={12}> <Col span={12}>
<Button block onClick={() => handleEdit(item)}>{t('modelNew.modelConfiguration')}</Button> {!item.model_id && <Button block onClick={() => handleEdit(item)}>{t('modelNew.modelConfiguration')}</Button>}
</Col> </Col>
<Col span={12}> <Col span={12}>
<Button type="primary" ghost block onClick={() => handleKeyConfig(item)}>{t('modelNew.keyConfig')}</Button> <Button type="primary" ghost block onClick={() => handleKeyConfig(item)}>{t('modelNew.keyConfig')}</Button>

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 16:50:18 * @Date: 2026-02-03 16:50:18
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-04 11:39:20 * @Last Modified time: 2026-03-06 12:26:11
*/ */
/** /**
* Type definitions for Model Management * Type definitions for Model Management
@@ -121,6 +121,7 @@ export interface ModelApiKey {
* Model list item data structure * Model list item data structure
*/ */
export interface ModelListItem { export interface ModelListItem {
model_id?: string;
/** Model name */ /** Model name */
model_name?: string; model_name?: string;
/** Associated model config IDs */ /** Associated model config IDs */

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 14:10:24 * @Date: 2026-02-03 14:10:24
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-09 18:02:13 * @Last Modified time: 2026-03-06 11:25:59
*/ */
import { type FC, type ReactNode } from 'react'; import { type FC, type ReactNode } from 'react';
import { useNavigate } from 'react-router-dom'; import { useNavigate } from 'react-router-dom';
@@ -17,7 +17,7 @@ const { Header } = Layout;
*/ */
interface ConfigHeaderProps { interface ConfigHeaderProps {
/** Page title/name */ /** Page title/name */
name?: string; name?: string | ReactNode;
/** Subtitle content displayed below the title */ /** Subtitle content displayed below the title */
subTitle?: ReactNode | string; subTitle?: ReactNode | string;
/** Extra content displayed on the right side */ /** Extra content displayed on the right side */

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 14:10:15 * @Date: 2026-02-03 14:10:15
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-05 16:28:53 * @Last Modified time: 2026-03-06 10:56:44
*/ */
import { type FC, useState, useRef, type MouseEvent } from 'react'; import { type FC, useState, useRef, type MouseEvent } from 'react';
import { useNavigate } from 'react-router-dom'; import { useNavigate } from 'react-router-dom';
@@ -181,8 +181,8 @@ const Ontology: FC = () => {
)} )}
</Flex> </Flex>
<div className="rb:mt-4 rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167] rb:flex rb:items-center rb:justify-end"> <div className="rb:mt-4 rb:h-5 rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167] rb:flex rb:items-center rb:justify-end">
<Space size={16}> {!item.is_system_default && <Space size={16}>
<div <div
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]" className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]"
onClick={(e) => handleEdit(item, e)} onClick={(e) => handleEdit(item, e)}
@@ -191,7 +191,7 @@ const Ontology: FC = () => {
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]" className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
onClick={(e) => handleDelete(item, e)} onClick={(e) => handleDelete(item, e)}
></div> ></div>
</Space> </Space>}
</div> </div>
</RbCard> </RbCard>
)} )}

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 14:10:20 * @Date: 2026-02-03 14:10:20
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-09 17:56:35 * @Last Modified time: 2026-03-06 11:26:49
*/ */
import { type FC, useEffect, useState, useRef } from 'react' import { type FC, useEffect, useState, useRef } from 'react'
import { useParams } from 'react-router-dom'; import { useParams } from 'react-router-dom';
@@ -17,6 +17,7 @@ import OntologyClassModal from '../components/OntologyClassModal'
import SearchInput from '@/components/SearchInput'; import SearchInput from '@/components/SearchInput';
import OntologyClassExtractModal from '../components/OntologyClassExtractModal' import OntologyClassExtractModal from '../components/OntologyClassExtractModal'
import BodyWrapper from '@/components/Empty/BodyWrapper' import BodyWrapper from '@/components/Empty/BodyWrapper'
import Tag from '@/components/Tag'
/** /**
* Ontology detail page component * Ontology detail page component
@@ -99,19 +100,22 @@ const Detail: FC = () => {
return ( return (
<> <>
<PageHeader <PageHeader
name={data.scene_name} name={<Space>
{data.scene_name}
{data.is_system_default ? <Tag color="warning">{t('common.default')}</Tag> : undefined}
</Space>}
subTitle={<Tooltip title={data.scene_description}><div className="rb:h-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{data.scene_description}</div></Tooltip>} subTitle={<Tooltip title={data.scene_description}><div className="rb:h-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{data.scene_description}</div></Tooltip>}
extra={<Space> extra={data.is_system_default ? undefined : (<Space>
<Button type="primary" ghost className="rb:h-6! rb:px-2! rb:leading-5.5!" onClick={handleAdd}>+ {t('ontology.addClass')}</Button> <Button type="primary" ghost className="rb:h-6! rb:px-2! rb:leading-5.5!" onClick={handleAdd}>+ {t('ontology.addClass')}</Button>
<Button className="rb:h-6! rb:px-2! rb:leading-5.5!" type="primary" onClick={handleExtract}>+ {t('ontology.extract')}</Button> <Button className="rb:h-6! rb:px-2! rb:leading-5.5!" type="primary" onClick={handleExtract}>+ {t('ontology.extract')}</Button>
</Space>} </Space>)}
/> />
<div className="rb:h-[calc(100vh-64px)] rb:overflow-y-auto rb:py-3 rb:px-4"> <div className="rb:h-[calc(100vh-64px)] rb:overflow-y-auto rb:py-3 rb:px-4">
<Row gutter={16} className="rb:mb-4"> <Row gutter={16} className="rb:mb-4">
<Col span={6} offset={18}> <Col span={6} offset={18}>
<SearchInput <SearchInput
placeholder={t('ontology.searchPlaceholder')} placeholder={t('ontology.classSearchPlaceholder')}
onSearch={(value) => setQuery({ class_name: value })} onSearch={(value) => setQuery({ class_name: value })}
className="rb:w-full!" className="rb:w-full!"
/> />
@@ -123,10 +127,10 @@ const Detail: FC = () => {
<Col key={item.class_id} span={6}> <Col key={item.class_id} span={6}>
<RbCard <RbCard
title={item.class_name} title={item.class_name}
extra={<div extra={data.is_system_default ? undefined : (<div
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]" className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
onClick={() => handleDelete(item)} onClick={() => handleDelete(item)}
></div>} ></div>)}
className="rb:bg-transparent!" className="rb:bg-transparent!"
> >
<Tooltip title={item.class_description}> <Tooltip title={item.class_description}>

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 14:10:10 * @Date: 2026-02-03 14:10:10
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-05 16:18:56 * @Last Modified time: 2026-03-06 10:55:23
*/ */
/** /**
* Query parameters for ontology list pagination and filtering * Query parameters for ontology list pagination and filtering
@@ -94,6 +94,7 @@ export interface OntologyClassData {
scene_description: string; scene_description: string;
/** Array of class items */ /** Array of class items */
items: OntologyClassItem[]; items: OntologyClassItem[];
is_system_default: boolean;
} }
/** /**

View File

@@ -1,131 +1,296 @@
import React, { useState, useRef, type ReactNode } from 'react'; import React, { useState, useRef, useEffect, useCallback, type ReactNode } from 'react';
import { Input, Button, Spin, App } from 'antd'; import { Input, Button, App, Card, Space, Skeleton, Tag } from 'antd';
import { SearchOutlined, SettingOutlined, GlobalOutlined, SyncOutlined } from '@ant-design/icons'; import { SearchOutlined, SettingOutlined, GlobalOutlined, SyncOutlined } from '@ant-design/icons';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import InfiniteScroll from 'react-infinite-scroll-component';
import MarketConfigModal, { type MarketConfigModalRef } from './components/MarketConfigModal'; import MarketConfigModal, { type MarketConfigModalRef } from './components/MarketConfigModal';
import McpServiceModal from './components/McpServiceModal';
import type { McpServiceModalRef } from './types';
import pageEmptyIcon from '@/assets/images/empty/pageEmpty.png'
import Empty from '@/components/Empty/index'
import { getMarketTools, getMarketConfig, getMarketMCPs, getMarketMCPDetail, getMarketMCPsActivated, getTools } from '@/api/tools';
import BodyWrapper from '@/components/Empty/BodyWrapper';
interface MarketSource { interface MarketSource {
id: string; id: string;
name: string; name: string;
category: string; category: string;
icon: string; logo_url: string;
url: string; url: string;
desc: string; description: string;
apiKey: string; api_key?: string;
connected: boolean; connected: boolean;
mcpCount: number; mcp_count: number;
created_at?: number;
created_by?: string;
} }
interface MarketMcp { interface MarketMcp {
id: string; id: string;
name: string; name: string;
provider: string; chinese_name?: string;
type: string; description: string;
desc: string; logo_url: string;
downloads?: string; publisher: string;
stars?: string; categories?: string[];
icon: string; tags?: string[];
configTemplate: any; view_count?: number;
activated?: boolean;
inDatabase?: boolean;
locales?: {
[lang: string]: {
name: string;
description: string;
};
};
} }
interface MarketCategory { interface MarketCategory {
id: string; id: string;
name: string; name: string;
icon: string; }
interface MarketApiResponse {
items: MarketSource[];
} }
const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () => { const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () => {
const { t } = useTranslation(); const { t, i18n } = useTranslation();
const { message } = App.useApp(); const { message } = App.useApp();
const getLocaleField = (mcp: MarketMcp, field: 'name' | 'description') => {
const lang = i18n.language?.startsWith('zh') ? 'zh' : 'en';
return mcp.locales?.[lang]?.[field] || mcp[field] || '';
};
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const [selectedSource, setSelectedSource] = useState<string | null>(null); const [selectedSource, setSelectedSource] = useState<string | null>(null);
const marketConfigModalRef = useRef<MarketConfigModalRef>(null); const marketConfigModalRef = useRef<MarketConfigModalRef>(null);
const [marketSources, setMarketSources] = useState<MarketSource[]>([ const mcpServiceModalRef = useRef<McpServiceModalRef>(null);
{ id: 'smithery', name: 'Smithery', category: 'official', icon: '🔧', url: 'https://mcp.smithery.ai', desc: '官方 MCP 服务市场,提供丰富的 MCP 服务', apiKey: '', connected: false, mcpCount: 2847 }, const [marketSources, setMarketSources] = useState<MarketSource[]>([]);
{ id: 'mcpmarket', name: 'MCP Market', category: 'official', icon: '🏪', url: 'https://mcpmarket.com', desc: '综合性 MCP 市场平台', apiKey: '', connected: false, mcpCount: 1523 }, const [categories, setCategories] = useState<MarketCategory[]>([]);
{ id: 'glama', name: 'Glama.ai MCP', category: 'official', icon: '✨', url: 'https://glama.ai/mcp', desc: 'Glama AI 提供的 MCP 服务集合', apiKey: '', connected: false, mcpCount: 892 }, const [mcpCache, setMcpCache] = useState<Record<string, MarketMcp[]>>({});
{ id: 'github-mcp', name: 'modelcontextprotocol/servers', category: 'official', icon: '🐙', url: 'https://github.com/modelcontextprotocol/servers', desc: 'GitHub 官方 MCP 服务器仓库', apiKey: '', connected: true, mcpCount: 156 }, const [mcpTotal, setMcpTotal] = useState(0);
{ id: 'aliyun-bailian', name: '阿里云百炼 MCP', category: 'china-cloud', icon: '☁️', url: 'https://bailian.console.aliyun.com/mcp', desc: '阿里云百炼平台 MCP 市场', apiKey: '', connected: false, mcpCount: 423 },
{ id: 'modelscope', name: '魔搭社区 MCP', category: 'china-cloud', icon: '🎭', url: 'https://modelscope.cn/mcp', desc: '阿里达摩院魔搭社区 MCP 市场', apiKey: '', connected: false, mcpCount: 312 },
]);
const [categories] = useState<MarketCategory[]>([
{ id: 'official', name: '官方/综合', icon: '🌐' },
{ id: 'china-cloud', name: '国内云', icon: '☁️' },
{ id: 'community', name: '社区/垂直', icon: '👥' }
]);
const [mcpCache, setMcpCache] = useState<Record<string, MarketMcp[]>>({
'github-mcp': [
{ id: 'gh-1', name: 'Fetch', provider: 'modelcontextprotocol', type: 'Hosted', desc: '使用浏览器模拟大型语言模型检索和处理网页内容', downloads: '203.7m', stars: '308.2k', icon: '🌐', configTemplate: {} },
{ id: 'gh-2', name: 'Filesystem', provider: 'modelcontextprotocol', type: 'Local', desc: '安全的文件系统操作,支持读写文件和目录管理', downloads: '156.2m', stars: '245.1k', icon: '📁', configTemplate: {} },
{ id: 'gh-3', name: 'GitHub', provider: 'modelcontextprotocol', type: 'Hosted', desc: 'GitHub API 集成支持仓库、Issue、PR 等操作', downloads: '89.4m', stars: '178.3k', icon: '🐙', configTemplate: {} },
]
});
const [searchKeyword, setSearchKeyword] = useState(''); const [searchKeyword, setSearchKeyword] = useState('');
const [configIdMap, setConfigIdMap] = useState<Record<string, string>>({});
const [hasMore, setHasMore] = useState(false);
const [activatedMcps, setActivatedMcps] = useState<string[]>([]);
const [currentPage, setCurrentPage] = useState(1);
const pageSize = 20;
const handleSelectSource = (sourceId: string) => { // 获取市场数据
setSelectedSource(sourceId); useEffect(() => {
}; const fetchMarketData = async () => {
setLoading(true);
const handleRefresh = (sourceId: string) => { try {
setLoading(true); const response = await getMarketTools({}) as MarketApiResponse;
setTimeout(() => { if (response?.items && Array.isArray(response.items)) {
// 模拟刷新数据 setMarketSources(response.items);
const source = marketSources.find(s => s.id === sourceId);
if (source) { // 根据 category 字段分组
message.success(`${source.name} 列表已刷新`); const categoryMap = new Map<string, MarketCategory>();
response.items.forEach(item => {
if (item.category && !categoryMap.has(item.category)) {
categoryMap.set(item.category, {
id: item.category,
name: item.category
});
}
});
setCategories(Array.from(categoryMap.values()));
}
} catch (error) {
console.error('获取市场数据失败:', error);
message.error('获取市场数据失败');
} finally {
setLoading(false);
} }
};
fetchMarketData();
}, [message]);
const fetchMcpList = async (sourceId: string, page = 1, append = false) => {
setLoading(true);
try {
let configId = configIdMap[sourceId];
// 如果没有缓存 configId先获取配置
if (!configId) {
const config: any = await getMarketConfig(sourceId);
if (config?.id) {
configId = config.id;
setConfigIdMap(prev => ({ ...prev, [sourceId]: configId }));
} else {
return;
}
}
// 第一次加载时获取已激活列表
let activatedIds: string[] = activatedMcps;
if (page === 1 && !append) {
const activatedRes: any = await getMarketMCPsActivated({ mcp_market_config_id: configId });
if (activatedRes && Array.isArray(activatedRes)) {
activatedIds = activatedRes.map((item: any) => item.id);
setActivatedMcps(activatedIds);
}
}
// 获取全量工具列表,用于标记已入库的 MCP
const allTools: any = await getTools({ tool_type: 'mcp' });
const toolsList = Array.isArray(allTools) ? allTools : [];
const res: any = await getMarketMCPs({ mcp_market_config_id: configId, page, pagesize: pageSize });
if (res?.items && Array.isArray(res.items)) {
// 标记已激活和已入库的 MCP
const mcpsWithActivated = res.items.map((item: MarketMcp) => {
// 检查是否已入库market_id = sourceId, market_config_id = configId, mcp_service_id = item.id
const isInDatabase = toolsList.some((tool: any) =>
tool.config_data?.market_id === sourceId &&
tool.config_data?.market_config_id === configId &&
tool.config_data?.mcp_service_id === item.id
);
return {
...item,
activated: activatedIds.includes(item.id),
inDatabase: isInDatabase
};
});
setMcpCache(prev => ({
...prev,
[sourceId]: append ? [...(prev[sourceId] || []), ...mcpsWithActivated] : mcpsWithActivated
}));
}
if (res?.page) {
setMcpTotal(res.page.total || 0);
setHasMore(!!res.page.has_next);
setCurrentPage(res.page.page || page);
}
} catch (error) {
console.error('获取 MCP 列表失败:', error);
} finally {
setLoading(false); setLoading(false);
}, 600); }
}; };
const handleOpenConfig = (sourceId: string) => { const loadMore = useCallback(() => {
if (!selectedSource || loading) return;
fetchMcpList(selectedSource, currentPage + 1, true);
}, [selectedSource, currentPage, loading]);
const handleSelectSource = async (sourceId: string) => {
setSelectedSource(sourceId);
setSearchKeyword('');
setCurrentPage(1);
setHasMore(false);
setMcpTotal(0);
// 如果缓存中已有数据,直接使用
if (mcpCache[sourceId]) return;
await fetchMcpList(sourceId, 1);
};
const handleRefresh = async (sourceId: string) => {
// 清除缓存,重新从第一页加载
setMcpCache(prev => {
const next = { ...prev };
delete next[sourceId];
return next;
});
setCurrentPage(1);
await fetchMcpList(sourceId, 1);
const source = marketSources.find(s => s.id === sourceId); const source = marketSources.find(s => s.id === sourceId);
if (source) { if (source) {
message.success(`${source.name} ${t('tool.marketRefreshSuccess')}`);
}
};
const handleOpenConfig = async (sourceId: string) => {
const source = marketSources.find(s => s.id === sourceId);
if (!source) return;
try {
const config: any = await getMarketConfig(sourceId);
marketConfigModalRef.current?.handleOpen({
...source,
connected: config?.status === 1,
token: config?.token || '',
configId: config?.id || '',
});
} catch {
marketConfigModalRef.current?.handleOpen(source); marketConfigModalRef.current?.handleOpen(source);
} }
}; };
const handleConnect = (sourceId: string, apiKey: string) => { const handleOpenMcpServiceModal = async (mcp: MarketMcp) => {
// 更新市场源状态 if (!selectedSource || !configIdMap[selectedSource]) return;
try {
const detail: any = await getMarketMCPDetail({
mcp_market_config_id: configIdMap[selectedSource],
server_id: mcp.id,
});
const source = marketSources.find(s => s.id === selectedSource);
const toolItem = {
name: detail.name,
description: detail.description,
source_channel: source?.name || '',
market_id: selectedSource,
market_config_id: configIdMap[selectedSource],
mcp_service_id: mcp.id,
config_data: {
server_url: detail.servers?.[0]?.url || '',
connection_config: {
auth_type: 'none',
timeout: 30,
headers: {},
},
},
};
mcpServiceModalRef.current?.handleOpen(toolItem as any);
} catch (error) {
console.error('获取 MCP 服务详情失败:', error);
}
};
const handleConnect = async (sourceId: string, configId: string) => {
// 更新市场源状态,缓存 configId
setMarketSources(prev => prev.map(source => { setMarketSources(prev => prev.map(source => {
if (source.id === sourceId) { if (source.id === sourceId) {
return { return { ...source, connected: true };
...source,
apiKey,
connected: true
};
} }
return source; return source;
})); }));
setConfigIdMap(prev => ({ ...prev, [sourceId]: configId }));
// 模拟获取MCP列表 // 用 configId 获取第一页 MCP 列表
setTimeout(() => { try {
const source = marketSources.find(s => s.id === sourceId); const res: any = await getMarketMCPs({ mcp_market_config_id: configId, page: 1, pagesize: pageSize });
if (source && !mcpCache[sourceId]) { if (res?.items && Array.isArray(res.items)) {
// 生成模拟数据 setMcpCache(prev => ({ ...prev, [sourceId]: res.items }));
const mockData: MarketMcp[] = [
{ id: `${sourceId}-1`, name: `${source.name} 服务 1`, provider: source.name, type: 'Hosted', desc: `来自 ${source.name} 的 MCP 服务`, downloads: '10.2m', stars: '23.4k', icon: '🔧', configTemplate: {} },
{ id: `${sourceId}-2`, name: `${source.name} 服务 2`, provider: source.name, type: 'Local', desc: `来自 ${source.name} 的本地 MCP 服务`, downloads: '8.5m', stars: '18.7k', icon: '⚙️', configTemplate: {} }
];
setMcpCache(prev => ({
...prev,
[sourceId]: mockData
}));
} }
message.success(`已连接 ${source?.name}`); if (res?.page) {
}, 800); setMcpTotal(res.page.total || 0);
setHasMore(!!res.page.has_next);
setCurrentPage(1);
}
} catch (error) {
console.error('获取 MCP 列表失败:', error);
}
}; };
const renderSourceDetail = () => { const renderSourceDetail = () => {
if (!selectedSource) { if (!selectedSource) {
return ( return (
<div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:h-full rb:text-center"> <div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:h-full rb:text-center">
<div className="rb:text-6xl rb:mb-4">🏪</div> <Empty
<h3 className="rb:text-lg rb:font-semibold rb:text-gray-900 rb:mb-2"> MCP </h3> url={pageEmptyIcon}
<p className="rb:text-sm rb:text-gray-600 rb:max-w-md"> MCP </p> title={t('tool.marketSelectTitle')}
subTitle={t('tool.marketSelectDesc')}
size={200}
className="rb:h-full"
/>
</div> </div>
); );
} }
@@ -134,170 +299,218 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
if (!source) return null; if (!source) return null;
const mcpList = mcpCache[selectedSource] || []; const mcpList = mcpCache[selectedSource] || [];
const filteredList = mcpList.filter(mcp => const filteredList = mcpList.filter(mcp => {
mcp.name.toLowerCase().includes(searchKeyword.toLowerCase()) || const name = getLocaleField(mcp, 'name');
mcp.desc.toLowerCase().includes(searchKeyword.toLowerCase()) const desc = getLocaleField(mcp, 'description');
); return name.toLowerCase().includes(searchKeyword.toLowerCase()) ||
desc.toLowerCase().includes(searchKeyword.toLowerCase());
});
return ( return (
<> <>
<div className="rb:flex rb:justify-between rb:items-start rb:pb-6 rb:border-b rb:border-gray-200 rb:mb-6"> <div className="rb:flex rb:justify-between rb:items-center rb:pb-0">
<div className="rb:flex rb:gap-4"> <div className="rb:flex rb:items-center rb:gap-4">
<div className="rb:text-5xl rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-xl rb:flex-shrink-0"> <div className="rb:w-10 rb:h-10 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-xl rb:flex-shrink-0 rb:overflow-hidden">
{source.icon} {source.logo_url ? (
<img
src={source.logo_url}
alt={source.name}
className="rb:w-full rb:h-full rb:object-cover"
referrerPolicy="no-referrer"
onError={(e) => {
e.currentTarget.style.display = 'none';
const parent = e.currentTarget.parentElement;
if (parent) {
parent.innerHTML = '🏪';
parent.style.fontSize = '48px';
}
}}
/>
) : (
<span className="rb:text-5xl">🏪</span>
)}
</div> </div>
<div className="rb:flex-1"> <div className="rb:flex rb:items-center rb:flex-1">
<h2 className="rb:text-xl rb:font-semibold rb:text-gray-900 rb:mb-2">{source.name}</h2> <h2 className="rb:text-xl rb:font-semibold rb:text-gray-900 rb:mb-2 rb:mr-2">{source.name}</h2>
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{source.desc}</p> MCP <span className="rb:text-gray-600 rb:font-normal">({mcpTotal})</span>
{/* <p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{source.description}</p> */}
</div> </div>
</div> </div>
<div className="rb:flex rb:gap-3">
<Button icon={<SettingOutlined />} onClick={() => handleOpenConfig(selectedSource)}>
</Button>
<Button type="primary" icon={<GlobalOutlined />} onClick={() => window.open(source.url, '_blank')}>
</Button>
</div>
</div>
<div className="rb:mt-6"> <div className="rb:flex rb:gap-3">
<div className="rb:flex rb:justify-between rb:items-center rb:mb-5">
<h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:m-0">
MCP <span className="rb:text-gray-600 rb:font-normal">({mcpList.length})</span>
</h3>
<div className="rb:flex rb:gap-3 rb:items-center"> <div className="rb:flex rb:gap-3 rb:items-center">
{source.connected && ( {source.connected && (
<Button size="small" icon={<SyncOutlined />} onClick={() => handleRefresh(selectedSource)}> <Button size="small" icon={<SyncOutlined />} onClick={() => handleRefresh(selectedSource)}>
{t('tool.marketRefresh')}
</Button> </Button>
)} )}
{mcpList.length > 0 && ( {mcpList.length > 0 && (
<Input <Input
prefix={<SearchOutlined />} prefix={<SearchOutlined />}
placeholder="搜索服务..." placeholder={t('tool.marketSearchPlaceholder')}
value={searchKeyword} value={searchKeyword}
onChange={(e) => setSearchKeyword(e.target.value)} onChange={(e) => setSearchKeyword(e.target.value)}
style={{ width: 200 }} style={{ width: 200 }}
/> />
)} )}
</div> </div>
<Button icon={<SettingOutlined />} onClick={() => handleOpenConfig(selectedSource)}>
{t('tool.marketConfig')}
</Button>
<Button type="primary" icon={<GlobalOutlined />} onClick={() => window.open(source.url, '_blank')}>
{t('tool.marketVisit')}
</Button>
</div> </div>
</div>
{mcpList.length > 0 ? ( <div className="rb:mt-6">
<Spin spinning={loading}> <BodyWrapper loading={loading} empty={mcpList.length === 0}>
<div className="rb:grid rb:grid-cols-1 md:rb:grid-cols-2 lg:rb:grid-cols-3 rb:gap-4"> <div id="mcpScrollableDiv" className="rb:overflow-y-auto rb:h-[calc(100vh-260px)]">
{filteredList.map(mcp => ( <InfiniteScroll
dataLength={filteredList.length}
next={loadMore}
hasMore={hasMore}
loader={<Skeleton active paragraph={{ rows: 2 }} className="rb:mt-4" />}
scrollableTarget="mcpScrollableDiv"
>
<div className="rb:grid rb:grid-cols-3 rb:gap-4">
{filteredList.map(mcp => (
<div <div
key={mcp.id} key={mcp.id}
className="rb:bg-white rb:border rb:border-gray-200 rb:rounded-lg rb:p-4 rb:transition-all rb:duration-200 hover:rb:shadow-lg hover:rb:border-gray-300" className="rb:bg-white rb:border rb:border-gray-200 rb:rounded-lg rb:p-4 rb:pb-2 rb:transition-all rb:duration-200 hover:rb:shadow-lg hover:rb:border-gray-300"
> >
<div className="rb:flex rb:justify-between rb:items-center rb:mb-3"> <div className="rb:flex rb:justify-between rb:items-center rb:mb-3">
<div className="rb:text-3xl rb:w-12 rb:h-12 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-lg"> <div className="rb:w-12 rb:h-12 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-lg rb:overflow-hidden">
{mcp.icon} {mcp.logo_url ? (
<img
src={mcp.logo_url}
alt={getLocaleField(mcp, 'name')}
className="rb:w-full rb:h-full rb:object-cover"
referrerPolicy="no-referrer"
onError={(e) => {
e.currentTarget.style.display = 'none';
const parent = e.currentTarget.parentElement;
if (parent) {
parent.innerHTML = '🔧';
parent.style.fontSize = '24px';
}
}}
/>
) : (
<span className="rb:text-3xl">🔧</span>
)}
</div> </div>
<span className={`rb:px-2 rb:py-1 rb:rounded rb:text-xs rb:font-medium ${ {mcp.categories?.[0] && (
mcp.type === 'Hosted' <span className="rb:px-2 rb:py-1 rb:rounded rb:text-xs rb:font-medium rb:bg-blue-50 rb:text-blue-700">
? 'rb:bg-blue-50 rb:text-blue-700' {mcp.categories[0]}
: 'rb:bg-gray-100 rb:text-gray-600' </span>
}`}> )}
{mcp.type}
</span>
</div> </div>
<h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-1">{mcp.name}</h3> <h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-1">{getLocaleField(mcp, 'name')}</h3>
{mcp.provider && ( {mcp.publisher && (
<div className="rb:mb-2"> <div className="rb:mb-2">
<span className="rb:text-xs rb:text-gray-500">@ {mcp.provider}</span> <span className="rb:text-xs rb:text-gray-500">{mcp.publisher.startsWith('@') ? mcp.publisher : `@${mcp.publisher}`}</span>
</div> </div>
)} )}
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed rb:mb-3 rb:min-h-[42px]">{mcp.desc}</p> <p className="rb:text-sm rb:text-gray-600 rb:line-clamp-2 rb:mb-3 rb:min-h-10">{getLocaleField(mcp, 'description')}</p>
<div className="rb:flex rb:gap-4 rb:mb-3 rb:pt-3 rb:border-t rb:border-gray-100"> <div className="rb:flex rb:gap-4 rb:mb-3 rb:pt-3 rb:border-t rb:border-gray-100">
{mcp.downloads && ( {mcp.view_count != null && (
<span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500"> <span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
<GlobalOutlined /> {mcp.downloads} <GlobalOutlined /> {mcp.view_count.toLocaleString()}
</span>
)}
{mcp.stars && (
<span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
{mcp.stars}
</span> </span>
)} )}
</div> </div>
<div className="rb:flex rb:justify-end"> <div className={`rb:flex rb:items-center ${mcp.activated || mcp.inDatabase ? 'rb:justify-between' : 'rb:justify-end'}`}>
<Button type="primary" size="small"> <div className="rb:flex rb:gap-2">
+ {mcp.activated && <Tag color="success">{t('tool.marketActivated')}</Tag>}
{mcp.inDatabase && <Tag color="blue">{t('tool.marketInDatabase')}</Tag>}
</div>
<Button type="primary" size="small" onClick={() => handleOpenMcpServiceModal(mcp)}>
+ {t('tool.marketAdd')}
</Button> </Button>
</div> </div>
</div> </div>
))} ))}
</div> </div>
</Spin> </InfiniteScroll>
) : (
<div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:py-16 rb:text-center">
<div className="rb:text-6xl rb:mb-4">{source.connected ? '📭' : '🔌'}</div>
<h4 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-2">
{source.connected ? '暂无可用的 MCP 服务' : '尚未连接此市场'}
</h4>
<p className="rb:text-sm rb:text-gray-600 rb:mb-4">
{source.connected ? '该市场暂时没有可用的服务' : '点击右上角"配置"按钮设置连接信息'}
</p>
{!source.connected && (
<Button type="primary" onClick={() => handleOpenConfig(selectedSource)}>
</Button>
)}
</div> </div>
)} </BodyWrapper>
</div> </div>
</> </>
); );
}; };
return ( return (
<div className="rb:flex rb:gap-4 rb:h-[calc(100vh-178px)]"> <div className="rb:flex rb:gap-4 rb:h-[calc(100vh-138px)]">
{/* 左侧市场源列表 */} {/* 左侧市场源列表 */}
<div className="rb:w-70 rb:bg-white rb:rounded-lg rb:border rb:border-gray-200 rb:overflow-y-auto rb:flex-shrink-0"> <div className="rb:w-80 rb:h-full rb:overflow-y-auto">
<div className="rb:p-4 rb:border-b rb:border-gray-200"> <Space size={12} direction="vertical" className="rb:w-full">
<span className="rb:text-base rb:font-semibold rb:text-gray-900">MCP </span> {categories.map(cat => (
</div> <Card
{categories.map(cat => ( key={cat.id}
<div key={cat.id} className="rb:py-3 rb:border-b rb:border-gray-100 last:rb:border-b-0"> type="inner"
<div className="rb:flex rb:items-center rb:gap-2 rb:px-4 rb:py-2 rb:text-xs rb:font-medium rb:text-gray-500 rb:uppercase"> title={
<span className="rb:text-sm">{cat.icon}</span> <div className="rb:flex rb:items-center rb:gap-2">
<span>{cat.name}</span> <span>{cat.name}</span>
</div> </div>
<div className="rb:px-2 rb:py-1"> }
{marketSources classNames={{
.filter(s => s.category === cat.id) body: "rb:p-[10px]!",
.map(source => ( header: "rb:bg-[#F6F8FC]!"
<div }}
key={source.id} >
className={`rb:flex rb:items-center rb:gap-2 rb:px-3 rb:py-2.5 rb:rounded-md rb:cursor-pointer rb:transition-all rb:relative ${ <Space size={8} direction="vertical" className="rb:w-full">
selectedSource === source.id {marketSources
? 'rb:bg-blue-50 rb:text-blue-600' .filter(s => s.category === cat.id)
: 'hover:rb:bg-gray-50' .map(source => (
}`} <div
onClick={() => handleSelectSource(source.id)} key={source.id}
> className={`rb:bg-white rb:rounded-lg rb:p-2 rb:border rb:cursor-pointer rb:flex rb:items-center rb:gap-2 rb:transition-all ${
<span className="rb:text-lg rb:flex-shrink-0">{source.icon}</span> selectedSource === source.id
<span className="rb:flex-1 rb:text-sm rb:font-medium rb:overflow-hidden rb:text-ellipsis rb:whitespace-nowrap"> ? 'rb:border-[#155EEF] rb:shadow-[0px_2px_4px_0px_rgba(33,35,50,0.15)]'
{source.name} : 'rb:border-[#DFE4ED] rb:hover:border-[#155EEF] rb:hover:shadow-[0px_2px_4px_0px_rgba(33,35,50,0.15)]'
</span> }`}
<span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full"> onClick={() => handleSelectSource(source.id)}
{source.mcpCount} >
</span> <div className="rb:w-5 rb:h-5 rb:flex-shrink-0 rb:flex rb:items-center rb:justify-center rb:overflow-hidden rb:rounded rb:bg-gray-100">
{source.connected && ( {source.logo_url ? (
<span className="rb:text-green-500 rb:text-[8px] rb:ml-1"></span> <img
)} src={source.logo_url}
</div> alt={source.name}
))} className="rb:w-full rb:h-full rb:object-cover"
</div> referrerPolicy="no-referrer"
</div> onError={(e) => {
))} e.currentTarget.style.display = 'none';
const parent = e.currentTarget.parentElement;
if (parent) {
parent.innerHTML = '🏪';
parent.style.fontSize = '16px';
}
}}
/>
) : (
<span className="rb:text-base">🏪</span>
)}
</div>
<span className="rb:flex-1 rb:font-medium rb:text-[12px] rb:overflow-hidden rb:text-ellipsis rb:whitespace-nowrap">
{source.name}
</span>
<span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full rb:flex-shrink-0">
{source.mcp_count}
</span>
{source.connected && (
<span className="rb:text-green-500 rb:text-[8px] rb:flex-shrink-0"></span>
)}
</div>
))}
</Space>
</Card>
))}
</Space>
</div> </div>
{/* 右侧内容区 */} {/* 右侧内容区 */}
<div className="rb:flex-1 rb:bg-white rb:rounded-lg rb:border rb:border-gray-200 rb:overflow-hidden"> <div className="rb:flex-1 rb:border-l rb:border-gray-200 rb:overflow-hidden">
<div className="rb:h-full rb:overflow-y-auto rb:p-6"> <div className="rb:h-full rb:overflow-y-auto rb:p-6">
{renderSourceDetail()} {renderSourceDetail()}
</div> </div>
@@ -308,6 +521,10 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
ref={marketConfigModalRef} ref={marketConfigModalRef}
onConnect={handleConnect} onConnect={handleConnect}
/> />
<McpServiceModal
ref={mcpServiceModalRef}
refresh={() => {}}
/>
</div> </div>
); );
}; };

View File

@@ -61,7 +61,6 @@ const Mcp: React.FC<{ getStatusTag: (status: string) => ReactNode }> = ({ getSta
getData() getData()
}) })
}; };
// 删除服务 // 删除服务
const handleDeleteService = (item: ToolItem) => { const handleDeleteService = (item: ToolItem) => {
if (!item.id) { if (!item.id) {

View File

@@ -2,6 +2,7 @@ import { forwardRef, useImperativeHandle, useState } from 'react';
import { Form, Input, Button, App, Space } from 'antd'; import { Form, Input, Button, App, Space } from 'antd';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { CopyOutlined, EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons'; import { CopyOutlined, EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
import { createMarketConfig,updateMarketConfig } from '@/api/tools';
import RbModal from '@/components/RbModal'; import RbModal from '@/components/RbModal';
const FormItem = Form.Item; const FormItem = Form.Item;
@@ -9,15 +10,16 @@ const FormItem = Form.Item;
interface MarketSource { interface MarketSource {
id: string; id: string;
name: string; name: string;
icon: string; logo_url: string;
url: string; url: string;
desc: string; description: string;
apiKey: string; token?: string;
connected: boolean; connected: boolean;
configId?: string;
} }
interface MarketConfigModalProps { interface MarketConfigModalProps {
onConnect: (sourceId: string, apiKey: string) => void; onConnect: (sourceId: string, configId: string) => void;
} }
export interface MarketConfigModalRef { export interface MarketConfigModalRef {
@@ -47,8 +49,7 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
const handleOpen = (source: MarketSource) => { const handleOpen = (source: MarketSource) => {
setCurrentSource(source); setCurrentSource(source);
form.setFieldsValue({ form.setFieldsValue({
url: source.url, token: source.token || '',
apiKey: source.apiKey,
}); });
setVisible(true); setVisible(true);
}; };
@@ -56,18 +57,36 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
const handleSave = () => { const handleSave = () => {
form form
.validateFields() .validateFields()
.then((values) => { .then(async (values) => {
if (!currentSource) return; if (!currentSource) return;
setLoading(true); setLoading(true);
try {
// 模拟连接延迟 let res: any;
setTimeout(() => { if (currentSource.configId) {
onConnect(currentSource.id, values.apiKey || ''); // 更新配置
message.success(`正在连接 ${currentSource.name}...`); res = await updateMarketConfig({
setLoading(false); mcp_market_config_id: currentSource.configId,
token: values.token || '',
status: 1,
});
message.success(t('tool.marketConfigUpdated', { name: currentSource.name }));
} else {
// 创建配置
res = await createMarketConfig({
mcp_market_id: currentSource.id || '',
token: values.token || '',
status: 1,
});
message.success(t('tool.marketConnecting', { name: currentSource.name }));
}
onConnect(currentSource.id, res.id || currentSource.configId);
handleClose(); handleClose();
}, 500); } catch (error) {
console.error('保存配置失败:', error);
} finally {
setLoading(false);
}
}) })
.catch((err) => { .catch((err) => {
console.log('表单验证失败:', err); console.log('表单验证失败:', err);
@@ -91,10 +110,10 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
return ( return (
<RbModal <RbModal
title={`配置 ${currentSource.name}`} title={t('tool.marketConfig', { name: currentSource.name })}
open={visible} open={visible}
onCancel={handleClose} onCancel={handleClose}
okText="保存并连接" okText={t('tool.marketSaveAndConnect')}
onOk={handleSave} onOk={handleSave}
confirmLoading={loading} confirmLoading={loading}
width={600} width={600}
@@ -102,12 +121,28 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
<div> <div>
{/* 市场源信息头部 */} {/* 市场源信息头部 */}
<div className="rb:flex rb:gap-4 rb:mb-6 rb:p-4 rb:bg-gray-50 rb:rounded-lg"> <div className="rb:flex rb:gap-4 rb:mb-6 rb:p-4 rb:bg-gray-50 rb:rounded-lg">
<div className="rb:text-4xl rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-white rb:rounded-lg rb:flex-shrink-0"> <div className="rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-white rb:rounded-lg rb:flex-shrink-0 rb:overflow-hidden">
{currentSource.icon} {currentSource.logo_url ? (
<img
src={currentSource.logo_url}
alt={currentSource.name}
className="rb:w-full rb:h-full rb:object-cover"
onError={(e) => {
e.currentTarget.style.display = 'none';
const parent = e.currentTarget.parentElement;
if (parent) {
parent.innerHTML = '🏪';
parent.style.fontSize = '32px';
}
}}
/>
) : (
<span className="rb:text-4xl">🏪</span>
)}
</div> </div>
<div className="rb:flex-1"> <div className="rb:flex-1">
<h3 className="rb:text-base rb:font-semibold rb:mb-1 rb:text-gray-900">{currentSource.name}</h3> <h3 className="rb:text-base rb:font-semibold rb:mb-1 rb:text-gray-900">{currentSource.name}</h3>
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{currentSource.desc}</p> <p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{currentSource.description}</p>
</div> </div>
</div> </div>
@@ -115,39 +150,34 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
form={form} form={form}
layout="vertical" layout="vertical"
> >
{/* 市场地址 */} <FormItem label={t('tool.marketUrl')}>
<FormItem
name="url"
label="市场地址"
>
<Space.Compact style={{ width: '100%' }}> <Space.Compact style={{ width: '100%' }}>
<Input <Input
readOnly readOnly
placeholder="市场地址" value={currentSource.url}
/> />
<Button <Button
icon={<CopyOutlined />} icon={<CopyOutlined />}
onClick={handleCopyUrl} onClick={handleCopyUrl}
> >
{t('tool.marketCopy')}
</Button> </Button>
</Space.Compact> </Space.Compact>
</FormItem> </FormItem>
{/* API Key */}
<FormItem <FormItem
name="apiKey" name="token"
label={ label={
<span> <span>
API Key <span className="rb:text-gray-400 rb:font-normal">()</span> API Key <span className="rb:text-gray-400 rb:font-normal">({t('tool.marketApiKeyOptional')})</span>
</span> </span>
} }
extra="部分市场需要 API Key 才能获取完整的服务列表" extra={<span style={{ display: 'inline-block', marginTop: 8 }}>{t('tool.marketApiKeyExtra')}</span>}
> >
<Space.Compact style={{ width: '100%' }}> <Space.Compact style={{ width: '100%' }}>
<Input <Input
type={showApiKey ? 'text' : 'password'} type={showApiKey ? 'text' : 'password'}
placeholder="输入 API Key 以获取更多服务" placeholder={t('tool.marketApiKeyPlaceholder')}
autoComplete="off" autoComplete="off"
/> />
<Button <Button
@@ -157,11 +187,10 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
</Space.Compact> </Space.Compact>
</FormItem> </FormItem>
{/* 连接状态 */}
<div className="rb:flex rb:items-center rb:gap-2 rb:p-3 rb:bg-gray-50 rb:rounded rb:text-sm"> <div className="rb:flex rb:items-center rb:gap-2 rb:p-3 rb:bg-gray-50 rb:rounded rb:text-sm">
<span className="rb:text-gray-600"></span> <span className="rb:text-gray-600">{t('tool.marketConnectionStatus')}</span>
<span className={`rb:font-medium ${currentSource.connected ? 'rb:text-green-600' : 'rb:text-gray-400'}`}> <span className={`rb:font-medium ${currentSource.connected ? 'rb:text-green-600' : 'rb:text-gray-400'}`}>
{currentSource.connected ? '● 已连接' : '○ 未连接'} {currentSource.connected ? t('tool.marketConnected') : t('tool.marketDisconnected')}
</span> </span>
</div> </div>
</Form> </Form>

Some files were not shown because too many files have changed in this diff Show More