Merge branch 'feature/memory_zy' of github.com:SuanmoSuanyangTechnology/MemoryBear into feature/memory_zy
This commit is contained in:
@@ -226,8 +226,8 @@ REDIS_PORT=6379
|
|||||||
REDIS_DB=1
|
REDIS_DB=1
|
||||||
|
|
||||||
# Celery (Using Redis as broker)
|
# Celery (Using Redis as broker)
|
||||||
BROKER_URL=redis://127.0.0.1:6379/0
|
REDIS_DB_CELERY_BROKER=1
|
||||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
REDIS_DB_CELERY_BACKEND=2
|
||||||
|
|
||||||
# JWT Secret Key (Formation method: openssl rand -hex 32)
|
# JWT Secret Key (Formation method: openssl rand -hex 32)
|
||||||
SECRET_KEY=your-secret-key-here
|
SECRET_KEY=your-secret-key-here
|
||||||
|
|||||||
@@ -201,8 +201,8 @@ REDIS_PORT=6379
|
|||||||
REDIS_DB=1
|
REDIS_DB=1
|
||||||
|
|
||||||
# Celery (使用Redis作为broker)
|
# Celery (使用Redis作为broker)
|
||||||
BROKER_URL=redis://127.0.0.1:6379/0
|
REDIS_DB_CELERY_BROKER=1
|
||||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
REDIS_DB_CELERY_BACKEND=2
|
||||||
|
|
||||||
# JWT密钥 (生成方式: openssl rand -hex 32)
|
# JWT密钥 (生成方式: openssl rand -hex 32)
|
||||||
SECRET_KEY=your-secret-key-here
|
SECRET_KEY=your-secret-key-here
|
||||||
|
|||||||
2
api/app/cache/memory/__init__.py
vendored
2
api/app/cache/memory/__init__.py
vendored
@@ -4,7 +4,9 @@ Memory 缓存模块
|
|||||||
提供记忆系统相关的缓存功能
|
提供记忆系统相关的缓存功能
|
||||||
"""
|
"""
|
||||||
from .interest_memory import InterestMemoryCache
|
from .interest_memory import InterestMemoryCache
|
||||||
|
from .activity_stats_cache import ActivityStatsCache
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"InterestMemoryCache",
|
"InterestMemoryCache",
|
||||||
|
"ActivityStatsCache",
|
||||||
]
|
]
|
||||||
|
|||||||
124
api/app/cache/memory/activity_stats_cache.py
vendored
Normal file
124
api/app/cache/memory/activity_stats_cache.py
vendored
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""
|
||||||
|
Recent Activity Stats Cache
|
||||||
|
|
||||||
|
记忆提取活动统计缓存模块
|
||||||
|
用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放
|
||||||
|
查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from app.aioRedis import aio_redis
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 缓存过期时间:24小时
|
||||||
|
ACTIVITY_STATS_CACHE_EXPIRE = 86400
|
||||||
|
|
||||||
|
|
||||||
|
class ActivityStatsCache:
|
||||||
|
"""记忆提取活动统计缓存类"""
|
||||||
|
|
||||||
|
PREFIX = "cache:memory:activity_stats"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_key(cls, workspace_id: str) -> str:
|
||||||
|
"""生成 Redis key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: 工作空间ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整的 Redis key
|
||||||
|
"""
|
||||||
|
return f"{cls.PREFIX}:by_workspace:{workspace_id}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def set_activity_stats(
|
||||||
|
cls,
|
||||||
|
workspace_id: str,
|
||||||
|
stats: Dict[str, Any],
|
||||||
|
expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
|
||||||
|
) -> bool:
|
||||||
|
"""设置记忆提取活动统计缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: 工作空间ID
|
||||||
|
stats: 统计数据,格式:
|
||||||
|
{
|
||||||
|
"chunk_count": int,
|
||||||
|
"statements_count": int,
|
||||||
|
"triplet_entities_count": int,
|
||||||
|
"triplet_relations_count": int,
|
||||||
|
"temporal_count": int,
|
||||||
|
}
|
||||||
|
expire: 过期时间(秒),默认24小时
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否设置成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key(workspace_id)
|
||||||
|
payload = {
|
||||||
|
"stats": stats,
|
||||||
|
"generated_at": datetime.now().isoformat(),
|
||||||
|
"workspace_id": workspace_id,
|
||||||
|
"cached": True,
|
||||||
|
}
|
||||||
|
value = json.dumps(payload, ensure_ascii=False)
|
||||||
|
await aio_redis.set(key, value, ex=expire)
|
||||||
|
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"设置活动统计缓存失败: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_activity_stats(
|
||||||
|
cls,
|
||||||
|
workspace_id: str,
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""获取记忆提取活动统计缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: 工作空间ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
统计数据字典,缓存不存在或已过期返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key(workspace_id)
|
||||||
|
value = await aio_redis.get(key)
|
||||||
|
if value:
|
||||||
|
payload = json.loads(value)
|
||||||
|
logger.info(f"命中活动统计缓存: {key}")
|
||||||
|
return payload
|
||||||
|
logger.info(f"活动统计缓存不存在或已过期: {key}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取活动统计缓存失败: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_activity_stats(
|
||||||
|
cls,
|
||||||
|
workspace_id: str,
|
||||||
|
) -> bool:
|
||||||
|
"""删除记忆提取活动统计缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: 工作空间ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否删除成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key(workspace_id)
|
||||||
|
result = await aio_redis.delete(key)
|
||||||
|
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
||||||
|
return result > 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除活动统计缓存失败: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
@@ -1,27 +1,54 @@
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from celery.schedules import crontab
|
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from celery import Celery
|
from celery import Celery
|
||||||
from celery.schedules import crontab
|
from celery.schedules import crontab
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
# macOS fork() safety - must be set before any Celery initialization
|
# macOS fork() safety - must be set before any Celery initialization
|
||||||
if platform.system() == 'Darwin':
|
if platform.system() == 'Darwin':
|
||||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||||
|
|
||||||
# 创建 Celery 应用实例
|
# 创建 Celery 应用实例
|
||||||
# broker: 任务队列(使用 Redis DB 0)
|
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定)
|
||||||
# backend: 结果存储(使用 Redis DB 10)
|
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定)
|
||||||
|
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||||
|
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||||
|
|
||||||
|
# Build canonical broker/backend URLs and force them into os.environ so that
|
||||||
|
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
|
||||||
|
# cannot be overridden by stray env vars.
|
||||||
|
# See: https://github.com/celery/celery/issues/4284
|
||||||
|
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||||
|
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||||
|
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||||
|
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||||
|
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
|
||||||
|
# integration and accidentally override our canonical URLs.
|
||||||
|
os.environ.pop("BROKER_URL", None)
|
||||||
|
os.environ.pop("RESULT_BACKEND", None)
|
||||||
|
os.environ.pop("CELERY_BROKER", None)
|
||||||
|
os.environ.pop("CELERY_BACKEND", None)
|
||||||
|
|
||||||
celery_app = Celery(
|
celery_app = Celery(
|
||||||
"redbear_tasks",
|
"redbear_tasks",
|
||||||
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
|
broker=_broker_url,
|
||||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
backend=_backend_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Celery app initialized",
|
||||||
|
extra={
|
||||||
|
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||||
|
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||||
|
},
|
||||||
|
)
|
||||||
# Default queue for unrouted tasks
|
# Default queue for unrouted tasks
|
||||||
celery_app.conf.task_default_queue = 'memory_tasks'
|
celery_app.conf.task_default_queue = 'memory_tasks'
|
||||||
|
|
||||||
@@ -86,6 +113,8 @@ celery_app.conf.update(
|
|||||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||||
|
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
||||||
|
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
import io
|
||||||
from typing import Optional, Annotated
|
from typing import Optional, Annotated
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
@@ -25,6 +27,7 @@ from app.services.app_service import AppService
|
|||||||
from app.services.app_statistics_service import AppStatisticsService
|
from app.services.app_statistics_service import AppStatisticsService
|
||||||
from app.services.workflow_import_service import WorkflowImportService
|
from app.services.workflow_import_service import WorkflowImportService
|
||||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||||
|
from app.services.app_dsl_service import AppDslService
|
||||||
|
|
||||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -1010,3 +1013,57 @@ def get_workspace_api_statistics(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return success(data=result)
|
return success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
async def export_app(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
release_id: Optional[uuid.UUID] = None
|
||||||
|
):
|
||||||
|
"""导出 agent / multi_agent / workflow 应用配置为 YAML 文件流。
|
||||||
|
release_id: 指定发布版本id,不传则导出当前草稿配置。
|
||||||
|
"""
|
||||||
|
yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)
|
||||||
|
encoded = quote(filename, safe=".")
|
||||||
|
yaml_bytes = yaml_str.encode("utf-8")
|
||||||
|
file_stream = io.BytesIO(yaml_bytes)
|
||||||
|
file_stream.seek(0)
|
||||||
|
return StreamingResponse(
|
||||||
|
file_stream,
|
||||||
|
media_type="application/octet-stream; charset=utf-8",
|
||||||
|
headers={"Content-Disposition": f"attachment; filename={encoded}",
|
||||||
|
"Content-Length": str(len(yaml_bytes))}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/import", summary="从 YAML 文件导入应用")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
async def import_app(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
||||||
|
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
||||||
|
"""
|
||||||
|
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||||
|
return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
|
raw = (await file.read()).decode("utf-8")
|
||||||
|
dsl = yaml.safe_load(raw)
|
||||||
|
if not dsl or "app" not in dsl:
|
||||||
|
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
|
new_app, warnings = AppDslService(db).import_dsl(
|
||||||
|
dsl=dsl,
|
||||||
|
workspace_id=current_user.current_workspace_id,
|
||||||
|
tenant_id=current_user.tenant_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
)
|
||||||
|
return success(
|
||||||
|
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
|
||||||
|
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,28 +1,29 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.rag.llm.cv_model import QWenCV
|
from app.core.rag.llm.cv_model import QWenCV
|
||||||
from app.core.response_utils import fail, success
|
from app.core.response_utils import fail, success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||||
from app.models import ModelApiKey
|
from app.models import ModelApiKey
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.repositories import knowledge_repository
|
||||||
from app.core.memory.agent.utils.redis_tool import store
|
|
||||||
from app.repositories import knowledge_repository, WorkspaceRepository
|
|
||||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import task_service, workspace_service
|
from app.services import task_service, workspace_service
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
from dotenv import load_dotenv
|
|
||||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from starlette.responses import StreamingResponse
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -37,7 +38,7 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/health/status", response_model=ApiResponse)
|
@router.get("/health/status", response_model=ApiResponse)
|
||||||
async def get_health_status(
|
async def get_health_status(
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get latest health status written by Celery periodic task
|
Get latest health status written by Celery periodic task
|
||||||
@@ -55,8 +56,9 @@ async def get_health_status(
|
|||||||
|
|
||||||
@router.get("/download_log")
|
@router.get("/download_log")
|
||||||
async def download_log(
|
async def download_log(
|
||||||
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
|
log_type: str = Query("file", regex="^(file|transmission)$",
|
||||||
current_user: User = Depends(get_current_user)
|
description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Download or stream agent service log file
|
Download or stream agent service log file
|
||||||
@@ -75,16 +77,16 @@ async def download_log(
|
|||||||
- transmission mode: StreamingResponse with SSE
|
- transmission mode: StreamingResponse with SSE
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Log download requested with log_type={log_type}")
|
api_logger.info(f"Log download requested with log_type={log_type}")
|
||||||
|
|
||||||
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
|
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
|
||||||
if log_type not in ["file", "transmission"]:
|
if log_type not in ["file", "transmission"]:
|
||||||
api_logger.warning(f"Invalid log_type parameter: {log_type}")
|
api_logger.warning(f"Invalid log_type parameter: {log_type}")
|
||||||
return fail(
|
return fail(
|
||||||
BizCode.BAD_REQUEST,
|
BizCode.BAD_REQUEST,
|
||||||
"无效的log_type参数",
|
"无效的log_type参数",
|
||||||
"log_type必须是'file'或'transmission'"
|
"log_type必须是'file'或'transmission'"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Route to appropriate mode
|
# Route to appropriate mode
|
||||||
if log_type == "file":
|
if log_type == "file":
|
||||||
# File mode: Return complete log file content
|
# File mode: Return complete log file content
|
||||||
@@ -119,10 +121,10 @@ async def download_log(
|
|||||||
@router.post("/writer_service", response_model=ApiResponse)
|
@router.post("/writer_service", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def write_server(
|
async def write_server(
|
||||||
user_input: Write_UserInput,
|
user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Write service endpoint - processes write operations synchronously
|
Write service endpoint - processes write operations synchronously
|
||||||
@@ -136,11 +138,11 @@ async def write_server(
|
|||||||
"""
|
"""
|
||||||
# 使用集中化的语言校验
|
# 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
config_id = user_input.config_id
|
config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||||
|
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
storage_type = workspace_service.get_workspace_storage_type(
|
||||||
db=db,
|
db=db,
|
||||||
@@ -149,7 +151,7 @@ async def write_server(
|
|||||||
)
|
)
|
||||||
if storage_type is None: storage_type = 'neo4j'
|
if storage_type is None: storage_type = 'neo4j'
|
||||||
user_rag_memory_id = ''
|
user_rag_memory_id = ''
|
||||||
|
|
||||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||||
if storage_type == 'rag':
|
if storage_type == 'rag':
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
@@ -161,13 +163,15 @@ async def write_server(
|
|||||||
if knowledge:
|
if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
user_rag_memory_id = str(knowledge.id)
|
||||||
else:
|
else:
|
||||||
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
api_logger.warning(
|
||||||
|
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
else:
|
else:
|
||||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
|
|
||||||
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
api_logger.info(
|
||||||
|
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||||
try:
|
try:
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
result = await memory_agent_service.write_memory(
|
result = await memory_agent_service.write_memory(
|
||||||
@@ -175,7 +179,7 @@ async def write_server(
|
|||||||
messages_list,
|
messages_list,
|
||||||
config_id,
|
config_id,
|
||||||
db,
|
db,
|
||||||
storage_type,
|
storage_type,
|
||||||
user_rag_memory_id,
|
user_rag_memory_id,
|
||||||
language
|
language
|
||||||
)
|
)
|
||||||
@@ -195,10 +199,10 @@ async def write_server(
|
|||||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
@router.post("/writer_service_async", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def write_server_async(
|
async def write_server_async(
|
||||||
user_input: Write_UserInput,
|
user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Async write service endpoint - enqueues write processing to Celery
|
Async write service endpoint - enqueues write processing to Celery
|
||||||
@@ -213,10 +217,11 @@ async def write_server_async(
|
|||||||
"""
|
"""
|
||||||
# 使用集中化的语言校验
|
# 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
config_id = user_input.config_id
|
config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
api_logger.info(
|
||||||
|
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||||
|
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
storage_type = workspace_service.get_workspace_storage_type(
|
||||||
@@ -244,7 +249,7 @@ async def write_server_async(
|
|||||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||||
)
|
)
|
||||||
api_logger.info(f"Write task queued: {task.id}")
|
api_logger.info(f"Write task queued: {task.id}")
|
||||||
|
|
||||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
api_logger.error(f"Async write operation failed: {str(e)}")
|
||||||
@@ -254,9 +259,9 @@ async def write_server_async(
|
|||||||
@router.post("/read_service", response_model=ApiResponse)
|
@router.post("/read_service", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def read_server(
|
async def read_server(
|
||||||
user_input: UserInput,
|
user_input: UserInput,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Read service endpoint - processes read operations synchronously
|
Read service endpoint - processes read operations synchronously
|
||||||
@@ -291,8 +296,9 @@ async def read_server(
|
|||||||
)
|
)
|
||||||
if knowledge:
|
if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
user_rag_memory_id = str(knowledge.id)
|
||||||
|
|
||||||
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
api_logger.info(
|
||||||
|
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||||
try:
|
try:
|
||||||
result = await memory_agent_service.read_memory(
|
result = await memory_agent_service.read_memory(
|
||||||
user_input.end_user_id,
|
user_input.end_user_id,
|
||||||
@@ -306,7 +312,8 @@ async def read_server(
|
|||||||
)
|
)
|
||||||
if str(user_input.search_switch) == "2":
|
if str(user_input.search_switch) == "2":
|
||||||
retrieve_info = result['answer']
|
retrieve_info = result['answer']
|
||||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
|
||||||
|
user_input.end_user_id)
|
||||||
query = user_input.message
|
query = user_input.message
|
||||||
|
|
||||||
# 调用 memory_agent_service 的方法生成最终答案
|
# 调用 memory_agent_service 的方法生成最终答案
|
||||||
@@ -319,7 +326,7 @@ async def read_server(
|
|||||||
db=db
|
db=db
|
||||||
)
|
)
|
||||||
if "信息不足,无法回答" in result['answer']:
|
if "信息不足,无法回答" in result['answer']:
|
||||||
result['answer']=retrieve_info
|
result['answer'] = retrieve_info
|
||||||
return success(data=result, msg="回复对话消息成功")
|
return success(data=result, msg="回复对话消息成功")
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
@@ -335,9 +342,10 @@ async def read_server(
|
|||||||
@router.post("/file", response_model=ApiResponse)
|
@router.post("/file", response_model=ApiResponse)
|
||||||
async def file_update(
|
async def file_update(
|
||||||
files: List[UploadFile] = File(..., description="要上传的文件"),
|
files: List[UploadFile] = File(..., description="要上传的文件"),
|
||||||
model_id:str = Form(..., description="模型ID"),
|
model_id: str = Form(..., description="模型ID"),
|
||||||
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
文件上传接口 - 支持图片识别
|
文件上传接口 - 支持图片识别
|
||||||
@@ -350,9 +358,6 @@ async def file_update(
|
|||||||
Returns:
|
Returns:
|
||||||
文件处理结果
|
文件处理结果
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db_gen = get_db() # get_db 通常是一个生成器
|
|
||||||
db = next(db_gen)
|
|
||||||
api_logger.info(f"File upload requested, file count: {len(files)}")
|
api_logger.info(f"File upload requested, file count: {len(files)}")
|
||||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||||
apiConfig: ModelApiKey = config.api_keys[0]
|
apiConfig: ModelApiKey = config.api_keys[0]
|
||||||
@@ -361,7 +366,7 @@ async def file_update(
|
|||||||
for file in files:
|
for file in files:
|
||||||
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
|
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
|
||||||
content = await file.read()
|
content = await file.read()
|
||||||
|
|
||||||
if file.content_type and file.content_type.startswith("image/"):
|
if file.content_type and file.content_type.startswith("image/"):
|
||||||
vision_model = QWenCV(
|
vision_model = QWenCV(
|
||||||
key=apiConfig.api_key,
|
key=apiConfig.api_key,
|
||||||
@@ -375,12 +380,12 @@ async def file_update(
|
|||||||
else:
|
else:
|
||||||
api_logger.warning(f"Unsupported file type: {file.content_type}")
|
api_logger.warning(f"Unsupported file type: {file.content_type}")
|
||||||
file_content.append(f"[不支持的文件类型: {file.content_type}]")
|
file_content.append(f"[不支持的文件类型: {file.content_type}]")
|
||||||
|
|
||||||
result_text = ';'.join(file_content)
|
result_text = ';'.join(file_content)
|
||||||
api_logger.info(f"File processing completed, result length: {len(result_text)}")
|
api_logger.info(f"File processing completed, result length: {len(result_text)}")
|
||||||
|
|
||||||
return success(data=result_text, msg="转换文本成功")
|
return success(data=result_text, msg="转换文本成功")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
|
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
|
||||||
@@ -430,8 +435,8 @@ async def read_server_async(
|
|||||||
|
|
||||||
@router.get("/read_result/", response_model=ApiResponse)
|
@router.get("/read_result/", response_model=ApiResponse)
|
||||||
async def get_read_task_result(
|
async def get_read_task_result(
|
||||||
task_id: str,
|
task_id: str,
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get the status and result of an async read task
|
Get the status and result of an async read task
|
||||||
@@ -452,7 +457,7 @@ async def get_read_task_result(
|
|||||||
try:
|
try:
|
||||||
result = task_service.get_task_memory_read_result(task_id)
|
result = task_service.get_task_memory_read_result(task_id)
|
||||||
status = result.get("status")
|
status = result.get("status")
|
||||||
|
|
||||||
if status == "SUCCESS":
|
if status == "SUCCESS":
|
||||||
# 任务成功完成
|
# 任务成功完成
|
||||||
task_result = result.get("result", {})
|
task_result = result.get("result", {})
|
||||||
@@ -470,7 +475,7 @@ async def get_read_task_result(
|
|||||||
else:
|
else:
|
||||||
# 旧格式:直接返回结果
|
# 旧格式:直接返回结果
|
||||||
return success(data=task_result, msg="查询任务已完成")
|
return success(data=task_result, msg="查询任务已完成")
|
||||||
|
|
||||||
elif status == "FAILURE":
|
elif status == "FAILURE":
|
||||||
# 任务失败
|
# 任务失败
|
||||||
error_info = result.get("result", "Unknown error")
|
error_info = result.get("result", "Unknown error")
|
||||||
@@ -479,7 +484,7 @@ async def get_read_task_result(
|
|||||||
else:
|
else:
|
||||||
error_msg = str(error_info)
|
error_msg = str(error_info)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
|
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
|
||||||
|
|
||||||
elif status in ["PENDING", "STARTED"]:
|
elif status in ["PENDING", "STARTED"]:
|
||||||
# 任务进行中
|
# 任务进行中
|
||||||
return success(
|
return success(
|
||||||
@@ -499,7 +504,7 @@ async def get_read_task_result(
|
|||||||
},
|
},
|
||||||
msg=f"任务状态: {status}"
|
msg=f"任务状态: {status}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
|
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
||||||
@@ -507,8 +512,8 @@ async def get_read_task_result(
|
|||||||
|
|
||||||
@router.get("/write_result/", response_model=ApiResponse)
|
@router.get("/write_result/", response_model=ApiResponse)
|
||||||
async def get_write_task_result(
|
async def get_write_task_result(
|
||||||
task_id: str,
|
task_id: str,
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get the status and result of an async write task
|
Get the status and result of an async write task
|
||||||
@@ -529,7 +534,7 @@ async def get_write_task_result(
|
|||||||
try:
|
try:
|
||||||
result = task_service.get_task_memory_write_result(task_id)
|
result = task_service.get_task_memory_write_result(task_id)
|
||||||
status = result.get("status")
|
status = result.get("status")
|
||||||
|
|
||||||
if status == "SUCCESS":
|
if status == "SUCCESS":
|
||||||
# 任务成功完成
|
# 任务成功完成
|
||||||
task_result = result.get("result", {})
|
task_result = result.get("result", {})
|
||||||
@@ -547,7 +552,7 @@ async def get_write_task_result(
|
|||||||
else:
|
else:
|
||||||
# 旧格式:直接返回结果
|
# 旧格式:直接返回结果
|
||||||
return success(data=task_result, msg="写入任务已完成")
|
return success(data=task_result, msg="写入任务已完成")
|
||||||
|
|
||||||
elif status == "FAILURE":
|
elif status == "FAILURE":
|
||||||
# 任务失败
|
# 任务失败
|
||||||
error_info = result.get("result", "Unknown error")
|
error_info = result.get("result", "Unknown error")
|
||||||
@@ -556,7 +561,7 @@ async def get_write_task_result(
|
|||||||
else:
|
else:
|
||||||
error_msg = str(error_info)
|
error_msg = str(error_info)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
|
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
|
||||||
|
|
||||||
elif status in ["PENDING", "STARTED"]:
|
elif status in ["PENDING", "STARTED"]:
|
||||||
# 任务进行中
|
# 任务进行中
|
||||||
return success(
|
return success(
|
||||||
@@ -576,7 +581,7 @@ async def get_write_task_result(
|
|||||||
},
|
},
|
||||||
msg=f"任务状态: {status}"
|
msg=f"任务状态: {status}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
|
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
||||||
@@ -584,9 +589,9 @@ async def get_write_task_result(
|
|||||||
|
|
||||||
@router.post("/status_type", response_model=ApiResponse)
|
@router.post("/status_type", response_model=ApiResponse)
|
||||||
async def status_type(
|
async def status_type(
|
||||||
user_input: Write_UserInput,
|
user_input: Write_UserInput,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Determine the type of user message (read or write)
|
Determine the type of user message (read or write)
|
||||||
@@ -629,9 +634,10 @@ async def status_type(
|
|||||||
|
|
||||||
@router.get("/stats/types", response_model=ApiResponse)
|
@router.get("/stats/types", response_model=ApiResponse)
|
||||||
async def get_knowledge_type_stats_api(
|
async def get_knowledge_type_stats_api(
|
||||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
||||||
@@ -640,14 +646,9 @@ async def get_knowledge_type_stats_api(
|
|||||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
- 如果用户没有当前工作空间,对应的统计返回 0
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
api_logger.info(
|
||||||
|
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
# 获取数据库会话
|
|
||||||
db_gen = get_db()
|
|
||||||
db = next(db_gen)
|
|
||||||
|
|
||||||
# 调用service层函数
|
# 调用service层函数
|
||||||
result = await memory_agent_service.get_knowledge_type_stats(
|
result = await memory_agent_service.get_knowledge_type_stats(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -655,7 +656,7 @@ async def get_knowledge_type_stats_api(
|
|||||||
current_workspace_id=current_user.current_workspace_id,
|
current_workspace_id=current_user.current_workspace_id,
|
||||||
db=db
|
db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
return success(data=result, msg="获取知识库类型统计成功")
|
return success(data=result, msg="获取知识库类型统计成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Knowledge type stats failed: {str(e)}")
|
api_logger.error(f"Knowledge type stats failed: {str(e)}")
|
||||||
@@ -664,11 +665,11 @@ async def get_knowledge_type_stats_api(
|
|||||||
|
|
||||||
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
||||||
async def get_interest_distribution_by_user_api(
|
async def get_interest_distribution_by_user_api(
|
||||||
end_user_id: str = Query(..., description="用户ID(必填)"),
|
end_user_id: str = Query(..., description="用户ID(必填)"),
|
||||||
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取指定用户的兴趣分布标签
|
获取指定用户的兴趣分布标签
|
||||||
@@ -716,9 +717,9 @@ async def get_interest_distribution_by_user_api(
|
|||||||
|
|
||||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||||
async def get_user_profile_api(
|
async def get_user_profile_api(
|
||||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取用户详情,包含:
|
获取用户详情,包含:
|
||||||
@@ -756,17 +757,17 @@ async def get_user_profile_api(
|
|||||||
# ):
|
# ):
|
||||||
# """
|
# """
|
||||||
# Get parsed API documentation (Public endpoint - no authentication required)
|
# Get parsed API documentation (Public endpoint - no authentication required)
|
||||||
|
|
||||||
# Args:
|
# Args:
|
||||||
# file_path: Optional path to API docs file. If None, uses default path.
|
# file_path: Optional path to API docs file. If None, uses default path.
|
||||||
|
|
||||||
# Returns:
|
# Returns:
|
||||||
# Parsed API documentation including title, meta info, and sections
|
# Parsed API documentation including title, meta info, and sections
|
||||||
# """
|
# """
|
||||||
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
|
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
|
||||||
# try:
|
# try:
|
||||||
# result = await memory_agent_service.get_api_docs(file_path)
|
# result = await memory_agent_service.get_api_docs(file_path)
|
||||||
|
|
||||||
# if result.get("success"):
|
# if result.get("success"):
|
||||||
# return success(msg=result["msg"], data=result["data"])
|
# return success(msg=result["msg"], data=result["data"])
|
||||||
# else:
|
# else:
|
||||||
@@ -782,9 +783,9 @@ async def get_user_profile_api(
|
|||||||
|
|
||||||
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
|
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
|
||||||
async def get_end_user_connected_config(
|
async def get_end_user_connected_config(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取终端用户关联的记忆配置
|
获取终端用户关联的记忆配置
|
||||||
@@ -803,9 +804,9 @@ async def get_end_user_connected_config(
|
|||||||
from app.services.memory_agent_service import (
|
from app.services.memory_agent_service import (
|
||||||
get_end_user_connected_config as get_config,
|
get_end_user_connected_config as get_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
|
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = get_config(end_user_id, db)
|
result = get_config(end_user_id, db)
|
||||||
return success(data=result, msg="获取终端用户关联配置成功")
|
return success(data=result, msg="获取终端用户关联配置成功")
|
||||||
@@ -814,4 +815,4 @@ async def get_end_user_connected_config(
|
|||||||
return fail(BizCode.NOT_FOUND, str(e))
|
return fail(BizCode.NOT_FOUND, str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
|
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
|
||||||
|
|||||||
@@ -149,6 +149,21 @@ async def get_workspace_end_users(
|
|||||||
|
|
||||||
return {uid: {"total": 0} for uid in end_user_ids}
|
return {uid: {"total": 0} for uid in end_user_ids}
|
||||||
|
|
||||||
|
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||||
|
try:
|
||||||
|
from app.celery_app import celery_app as _celery_app
|
||||||
|
_celery_app.send_task(
|
||||||
|
"app.tasks.init_implicit_emotions_for_users",
|
||||||
|
kwargs={"end_user_ids": end_user_ids},
|
||||||
|
)
|
||||||
|
_celery_app.send_task(
|
||||||
|
"app.tasks.init_interest_distribution_for_users",
|
||||||
|
kwargs={"end_user_ids": end_user_ids},
|
||||||
|
)
|
||||||
|
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
|
||||||
|
|
||||||
# 并发执行配置查询和记忆数量查询
|
# 并发执行配置查询和记忆数量查询
|
||||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||||
get_memory_configs(),
|
get_memory_configs(),
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Optional
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
@@ -85,6 +85,7 @@ def create_config(
|
|||||||
payload: ConfigParamsCreate,
|
payload: ConfigParamsCreate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
@@ -99,7 +100,29 @@ def create_config(
|
|||||||
svc = DataConfigService(db)
|
svc = DataConfigService(db)
|
||||||
result = svc.create(payload)
|
result = svc.create(payload)
|
||||||
return success(data=result, msg="创建成功")
|
return success(data=result, msg="创建成功")
|
||||||
|
except ValueError as e:
|
||||||
|
err_str = str(e)
|
||||||
|
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
|
||||||
|
config_name = err_str.split(":", 1)[1]
|
||||||
|
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||||
|
lang = get_language_from_header(x_language_type)
|
||||||
|
if lang == "en":
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||||
|
else:
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||||
|
return JSONResponse(status_code=400, content=msg)
|
||||||
|
api_logger.error(f"Create config failed: {err_str}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
|
||||||
|
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||||
|
lang = get_language_from_header(x_language_type)
|
||||||
|
if lang == "en":
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||||
|
else:
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||||
|
return JSONResponse(status_code=400, content=msg)
|
||||||
api_logger.error(f"Create config failed: {str(e)}")
|
api_logger.error(f"Create config failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||||
|
|
||||||
@@ -521,10 +544,11 @@ async def clear_hot_memory_tags_cache(
|
|||||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||||
async def get_recent_activity_stats_api(
|
async def get_recent_activity_stats_api(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info("Recent activity stats requested")
|
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
||||||
|
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
||||||
try:
|
try:
|
||||||
result = await analytics_recent_activity_stats()
|
result = await analytics_recent_activity_stats(workspace_id=workspace_id)
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="查询成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||||
|
|||||||
@@ -371,6 +371,11 @@ def update_model(
|
|||||||
|
|
||||||
if model_data.type is not None or model_data.provider is not None:
|
if model_data.type is not None or model_data.provider is not None:
|
||||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
|
if model_data.is_active:
|
||||||
|
active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, model_config_id=model_id, is_active=model_data.is_active)
|
||||||
|
if not active_keys:
|
||||||
|
raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from typing import Dict, Optional, List
|
|||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
|
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
@@ -289,7 +289,8 @@ async def extract_ontology(
|
|||||||
async def create_scene(
|
async def create_scene(
|
||||||
request: SceneCreateRequest,
|
request: SceneCreateRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
|
||||||
):
|
):
|
||||||
"""创建本体场景
|
"""创建本体场景
|
||||||
|
|
||||||
@@ -360,8 +361,18 @@ async def create_scene(
|
|||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True)
|
err_str = str(e)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e))
|
if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str:
|
||||||
|
api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}")
|
||||||
|
from app.core.language_utils import get_language_from_header
|
||||||
|
lang = get_language_from_header(x_language_type)
|
||||||
|
if lang == "en":
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.")
|
||||||
|
else:
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称")
|
||||||
|
return JSONResponse(status_code=400, content=msg)
|
||||||
|
api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True)
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
|
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
|
||||||
@@ -661,7 +672,8 @@ async def get_scenes(
|
|||||||
async def create_class(
|
async def create_class(
|
||||||
request: ClassCreateRequest,
|
request: ClassCreateRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
|
||||||
):
|
):
|
||||||
"""创建本体类型
|
"""创建本体类型
|
||||||
|
|
||||||
@@ -676,7 +688,7 @@ async def create_class(
|
|||||||
ApiResponse: 包含创建的类型信息
|
ApiResponse: 包含创建的类型信息
|
||||||
"""
|
"""
|
||||||
from app.controllers.ontology_secondary_routes import create_class_handler
|
from app.controllers.ontology_secondary_routes import create_class_handler
|
||||||
return await create_class_handler(request, db, current_user)
|
return await create_class_handler(request, db, current_user, x_language_type)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/class/{class_id}", response_model=ApiResponse)
|
@router.put("/class/{class_id}", response_model=ApiResponse)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends, Header
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
@@ -58,7 +58,7 @@ async def scenes_handler(
|
|||||||
workspace_id: Optional[str] = None,
|
workspace_id: Optional[str] = None,
|
||||||
scene_name: Optional[str] = None,
|
scene_name: Optional[str] = None,
|
||||||
page: Optional[int] = None,
|
page: Optional[int] = None,
|
||||||
page_size: Optional[int] = None,
|
pagesize: Optional[int] = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
@@ -71,14 +71,14 @@ async def scenes_handler(
|
|||||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
||||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
page: 页码(可选,从1开始,仅在全量查询时有效)
|
||||||
page_size: 每页数量(可选,仅在全量查询时有效)
|
pagesize: 每页数量(可选,仅在全量查询时有效)
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
current_user: 当前用户
|
current_user: 当前用户
|
||||||
"""
|
"""
|
||||||
operation = "search" if scene_name else "list"
|
operation = "search" if scene_name else "list"
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"Scene {operation} requested by user {current_user.id}, "
|
f"Scene {operation} requested by user {current_user.id}, "
|
||||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
|
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -105,13 +105,13 @@ async def scenes_handler(
|
|||||||
api_logger.warning(f"Invalid page number: {page}")
|
api_logger.warning(f"Invalid page number: {page}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||||
|
|
||||||
if page_size is not None and page_size < 1:
|
if pagesize is not None and pagesize < 1:
|
||||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||||
|
|
||||||
# 如果只提供了page或page_size中的一个,返回错误
|
# 如果只提供了page或pagesize中的一个,返回错误
|
||||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||||
|
|
||||||
# 模糊搜索场景(支持分页)
|
# 模糊搜索场景(支持分页)
|
||||||
@@ -119,17 +119,15 @@ async def scenes_handler(
|
|||||||
total = len(scenes)
|
total = len(scenes)
|
||||||
|
|
||||||
# 如果提供了分页参数,进行分页处理
|
# 如果提供了分页参数,进行分页处理
|
||||||
if page is not None and page_size is not None:
|
if page is not None and pagesize is not None:
|
||||||
start_idx = (page - 1) * page_size
|
start_idx = (page - 1) * pagesize
|
||||||
end_idx = start_idx + page_size
|
end_idx = start_idx + pagesize
|
||||||
scenes = scenes[start_idx:end_idx]
|
scenes = scenes[start_idx:end_idx]
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
items = []
|
items = []
|
||||||
for scene in scenes:
|
for scene in scenes:
|
||||||
# 获取前3个class_name作为entity_type
|
|
||||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||||
# 动态计算 type_num
|
|
||||||
type_num = len(scene.classes) if scene.classes else 0
|
type_num = len(scene.classes) if scene.classes else 0
|
||||||
|
|
||||||
items.append(SceneResponse(
|
items.append(SceneResponse(
|
||||||
@@ -141,17 +139,16 @@ async def scenes_handler(
|
|||||||
workspace_id=scene.workspace_id,
|
workspace_id=scene.workspace_id,
|
||||||
created_at=scene.created_at,
|
created_at=scene.created_at,
|
||||||
updated_at=scene.updated_at,
|
updated_at=scene.updated_at,
|
||||||
classes_count=type_num
|
classes_count=type_num,
|
||||||
|
is_system_default=scene.is_system_default
|
||||||
))
|
))
|
||||||
|
|
||||||
# 构建响应(包含分页信息)
|
# 构建响应(包含分页信息)
|
||||||
if page is not None and page_size is not None:
|
if page is not None and pagesize is not None:
|
||||||
# 计算是否有下一页
|
hasnext = (page * pagesize) < total
|
||||||
hasnext = (page * page_size) < total
|
|
||||||
|
|
||||||
pagination_info = PaginationInfo(
|
pagination_info = PaginationInfo(
|
||||||
page=page,
|
page=page,
|
||||||
pagesize=page_size,
|
pagesize=pagesize,
|
||||||
total=total,
|
total=total,
|
||||||
hasnext=hasnext
|
hasnext=hasnext
|
||||||
)
|
)
|
||||||
@@ -165,28 +162,25 @@ async def scenes_handler(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 获取所有场景(支持分页)
|
# 获取所有场景(支持分页)
|
||||||
# 验证分页参数
|
|
||||||
if page is not None and page < 1:
|
if page is not None and page < 1:
|
||||||
api_logger.warning(f"Invalid page number: {page}")
|
api_logger.warning(f"Invalid page number: {page}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||||
|
|
||||||
if page_size is not None and page_size < 1:
|
if pagesize is not None and pagesize < 1:
|
||||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||||
|
|
||||||
# 如果只提供了page或page_size中的一个,返回错误
|
# 如果只提供了page或pagesize中的一个,返回错误
|
||||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||||
|
|
||||||
scenes, total = service.list_scenes(ws_uuid, page, page_size)
|
scenes, total = service.list_scenes(ws_uuid, page, pagesize)
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
items = []
|
items = []
|
||||||
for scene in scenes:
|
for scene in scenes:
|
||||||
# 获取前3个class_name作为entity_type
|
|
||||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||||
# 动态计算 type_num
|
|
||||||
type_num = len(scene.classes) if scene.classes else 0
|
type_num = len(scene.classes) if scene.classes else 0
|
||||||
|
|
||||||
items.append(SceneResponse(
|
items.append(SceneResponse(
|
||||||
@@ -198,17 +192,16 @@ async def scenes_handler(
|
|||||||
workspace_id=scene.workspace_id,
|
workspace_id=scene.workspace_id,
|
||||||
created_at=scene.created_at,
|
created_at=scene.created_at,
|
||||||
updated_at=scene.updated_at,
|
updated_at=scene.updated_at,
|
||||||
classes_count=type_num
|
classes_count=type_num,
|
||||||
|
is_system_default=scene.is_system_default
|
||||||
))
|
))
|
||||||
|
|
||||||
# 构建响应(包含分页信息)
|
# 构建响应(包含分页信息)
|
||||||
if page is not None and page_size is not None:
|
if page is not None and pagesize is not None:
|
||||||
# 计算是否有下一页
|
hasnext = (page * pagesize) < total
|
||||||
hasnext = (page * page_size) < total
|
|
||||||
|
|
||||||
pagination_info = PaginationInfo(
|
pagination_info = PaginationInfo(
|
||||||
page=page,
|
page=page,
|
||||||
pagesize=page_size,
|
pagesize=pagesize,
|
||||||
total=total,
|
total=total,
|
||||||
hasnext=hasnext
|
hasnext=hasnext
|
||||||
)
|
)
|
||||||
@@ -238,7 +231,8 @@ async def scenes_handler(
|
|||||||
async def create_class_handler(
|
async def create_class_handler(
|
||||||
request: ClassCreateRequest,
|
request: ClassCreateRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
x_language_type: Optional[str] = None
|
||||||
):
|
):
|
||||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
||||||
|
|
||||||
@@ -271,8 +265,11 @@ async def create_class_handler(
|
|||||||
]
|
]
|
||||||
|
|
||||||
if count == 1:
|
if count == 1:
|
||||||
# 单个创建
|
# 单个创建 - 先检查重名
|
||||||
class_data = classes_data[0]
|
class_data = classes_data[0]
|
||||||
|
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
|
||||||
|
if existing:
|
||||||
|
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
|
||||||
ontology_class = service.create_class(
|
ontology_class = service.create_class(
|
||||||
scene_id=request.scene_id,
|
scene_id=request.scene_id,
|
||||||
class_name=class_data["class_name"],
|
class_name=class_data["class_name"],
|
||||||
@@ -330,12 +327,36 @@ async def create_class_handler(
|
|||||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
api_logger.warning(f"Validation error in class creation: {str(e)}")
|
err_str = str(e)
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
|
||||||
|
class_name = err_str.split(":", 1)[1]
|
||||||
|
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
|
||||||
|
from app.core.language_utils import get_language_from_header
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
lang = get_language_from_header(x_language_type)
|
||||||
|
if lang == "en":
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||||
|
else:
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||||
|
return JSONResponse(status_code=400, content=msg)
|
||||||
|
api_logger.warning(f"Validation error in class creation: {err_str}")
|
||||||
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
|
err_str = str(e)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
|
||||||
|
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
|
||||||
|
from app.core.language_utils import get_language_from_header
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
lang = get_language_from_header(x_language_type)
|
||||||
|
class_name = request.classes[0].class_name if request.classes else ""
|
||||||
|
if lang == "en":
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||||
|
else:
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||||
|
return JSONResponse(status_code=400, content=msg)
|
||||||
|
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
||||||
@@ -615,6 +636,7 @@ async def classes_handler(
|
|||||||
scene_id=scene_uuid,
|
scene_id=scene_uuid,
|
||||||
scene_name=scene.scene_name,
|
scene_name=scene.scene_name,
|
||||||
scene_description=scene.scene_description,
|
scene_description=scene.scene_description,
|
||||||
|
is_system_default=scene.is_system_default,
|
||||||
items=items
|
items=items
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -97,6 +97,12 @@ async def create_tool(
|
|||||||
):
|
):
|
||||||
"""创建工具"""
|
"""创建工具"""
|
||||||
try:
|
try:
|
||||||
|
# 将 MCP 来源字段合并进 config
|
||||||
|
if request.tool_type == ToolType.MCP:
|
||||||
|
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
|
||||||
|
val = getattr(request, key, None)
|
||||||
|
if val is not None:
|
||||||
|
request.config[key] = val
|
||||||
tool_id = service.create_tool(
|
tool_id = service.create_tool(
|
||||||
name=request.name,
|
name=request.name,
|
||||||
tool_type=request.tool_type,
|
tool_type=request.tool_type,
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any, Dict, Optional
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import Field, TypeAdapter
|
from pydantic import Field, TypeAdapter
|
||||||
@@ -190,8 +189,12 @@ class Settings:
|
|||||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||||
|
|
||||||
# Celery configuration (internal)
|
# Celery configuration (internal)
|
||||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
||||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
# 详见 docs/celery-env-bug-report.md
|
||||||
|
# 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
|
||||||
|
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
|
||||||
|
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
||||||
|
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
||||||
|
|
||||||
# SMTP Email Configuration
|
# SMTP Email Configuration
|
||||||
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
||||||
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
PROJECT_ROOT_,
|
PROJECT_ROOT_,
|
||||||
ReadState,
|
ReadState,
|
||||||
@@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.db import get_db_context
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
db_session = next(get_db())
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务
|
# 使用优化的LLM服务
|
||||||
structured = await problem_service.call_llm_structured(
|
with get_db_context() as db_session:
|
||||||
state=state,
|
structured = await problem_service.call_llm_structured(
|
||||||
db_session=db_session,
|
state=state,
|
||||||
system_prompt=system_prompt,
|
db_session=db_session,
|
||||||
response_model=ProblemExtensionResponse,
|
system_prompt=system_prompt,
|
||||||
fallback_value=[]
|
response_model=ProblemExtensionResponse,
|
||||||
)
|
fallback_value=[]
|
||||||
|
)
|
||||||
|
|
||||||
# 添加更详细的日志记录
|
# 添加更详细的日志记录
|
||||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||||
@@ -111,7 +111,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
"error_type": type(e).__name__,
|
"error_type": type(e).__name__,
|
||||||
"error_message": str(e),
|
"error_message": str(e),
|
||||||
"content_length": len(content),
|
"content_length": len(content),
|
||||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||||
@@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务
|
# 使用优化的LLM服务
|
||||||
response_content = await problem_service.call_llm_structured(
|
with get_db_context() as db_session:
|
||||||
state=state,
|
response_content = await problem_service.call_llm_structured(
|
||||||
db_session=db_session,
|
state=state,
|
||||||
system_prompt=system_prompt,
|
db_session=db_session,
|
||||||
response_model=ProblemExtensionResponse,
|
system_prompt=system_prompt,
|
||||||
fallback_value=[]
|
response_model=ProblemExtensionResponse,
|
||||||
)
|
fallback_value=[]
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||||
|
|
||||||
@@ -220,7 +221,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
"error_type": type(e).__name__,
|
"error_type": type(e).__name__,
|
||||||
"error_message": str(e),
|
"error_message": str(e),
|
||||||
"questions_count": len(databasets),
|
"questions_count": len(databasets),
|
||||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.error(f"Problem_Extension error details: {error_details}")
|
logger.error(f"Problem_Extension error details: {error_details}")
|
||||||
|
|||||||
@@ -6,31 +6,26 @@ import os
|
|||||||
# ===== 第三方库 =====
|
# ===== 第三方库 =====
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.db import get_db, get_db_context
|
|
||||||
|
|
||||||
from app.schemas import model_schema
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
from app.services.model_service import ModelConfigService
|
|
||||||
|
|
||||||
from app.core.memory.agent.services.search_service import SearchService
|
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
|
||||||
COUNTState,
|
|
||||||
ReadState,
|
|
||||||
deduplicate_entries,
|
|
||||||
merge_to_key_value_pairs,
|
|
||||||
)
|
|
||||||
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
||||||
create_hybrid_retrieval_tool_sync,
|
create_hybrid_retrieval_tool_sync,
|
||||||
create_time_retrieval_tool,
|
create_time_retrieval_tool,
|
||||||
extract_tool_message_content,
|
extract_tool_message_content,
|
||||||
)
|
)
|
||||||
|
from app.core.memory.agent.services.search_service import SearchService
|
||||||
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
|
ReadState,
|
||||||
|
deduplicate_entries,
|
||||||
|
merge_to_key_value_pairs,
|
||||||
|
)
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.schemas import model_schema
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
db = next(get_db())
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def rag_config(state):
|
async def rag_config(state):
|
||||||
@@ -50,10 +45,12 @@ async def rag_config(state):
|
|||||||
"reranker_top_k": 10
|
"reranker_top_k": 10
|
||||||
}
|
}
|
||||||
return kb_config
|
return kb_config
|
||||||
async def rag_knowledge(state,question):
|
|
||||||
|
|
||||||
|
async def rag_knowledge(state, question):
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||||
try:
|
try:
|
||||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||||
@@ -61,13 +58,13 @@ async def rag_knowledge(state,question):
|
|||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
raw_results = clean_content
|
raw_results = clean_content
|
||||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||||
except Exception :
|
except Exception:
|
||||||
retrieval_knowledge=[]
|
retrieval_knowledge = []
|
||||||
clean_content = ''
|
clean_content = ''
|
||||||
raw_results = ''
|
raw_results = ''
|
||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||||
return retrieval_knowledge,clean_content,cleaned_query,raw_results
|
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||||
|
|
||||||
|
|
||||||
async def llm_infomation(state: ReadState) -> ReadState:
|
async def llm_infomation(state: ReadState) -> ReadState:
|
||||||
@@ -113,7 +110,7 @@ async def clean_databases(data) -> str:
|
|||||||
|
|
||||||
# 收集所有内容
|
# 收集所有内容
|
||||||
content_list = []
|
content_list = []
|
||||||
|
|
||||||
# 处理重排序结果
|
# 处理重排序结果
|
||||||
reranked = results.get('reranked_results', {})
|
reranked = results.get('reranked_results', {})
|
||||||
if reranked:
|
if reranked:
|
||||||
@@ -141,7 +138,6 @@ async def clean_databases(data) -> str:
|
|||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
text_parts.append(item)
|
text_parts.append(item)
|
||||||
|
|
||||||
|
|
||||||
return '\n'.join(text_parts).strip()
|
return '\n'.join(text_parts).strip()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -150,23 +146,23 @@ async def clean_databases(data) -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
模型信息
|
模型信息
|
||||||
'''
|
'''
|
||||||
|
|
||||||
problem_extension=state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
storage_type=state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
end_user_id=state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
original=state.get('data', '')
|
original = state.get('data', '')
|
||||||
problem_list=[]
|
problem_list = []
|
||||||
for key,values in problem_extension.items():
|
for key, values in problem_extension.items():
|
||||||
for data in values:
|
for data in values:
|
||||||
problem_list.append(data)
|
problem_list.append(data)
|
||||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
|
|
||||||
# 创建异步任务处理单个问题
|
# 创建异步任务处理单个问题
|
||||||
async def process_question_nodes(idx, question):
|
async def process_question_nodes(idx, question):
|
||||||
try:
|
try:
|
||||||
@@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
send_verify = []
|
send_verify = []
|
||||||
for i, j in zip(keys, val, strict=False):
|
for i, j in zip(keys, val, strict=False):
|
||||||
if j!=['']:
|
if j != ['']:
|
||||||
send_verify.append({
|
send_verify.append({
|
||||||
"Query_small": i,
|
"Query_small": i,
|
||||||
"Answer_Small": j
|
"Answer_Small": j
|
||||||
@@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||||
return {'retrieve':dup_databases}
|
return {'retrieve': dup_databases}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def retrieve(state: ReadState) -> ReadState:
|
async def retrieve(state: ReadState) -> ReadState:
|
||||||
# 从state中获取end_user_id
|
# 从state中获取end_user_id
|
||||||
import time
|
import time
|
||||||
start=time.time()
|
start = time.time()
|
||||||
problem_extension = state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
@@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
with get_db_context() as db: # 使用同步数据库上下文管理器
|
with get_db_context() as db: # 使用同步数据库上下文管理器
|
||||||
config_service = MemoryConfigService(db)
|
config_service = MemoryConfigService(db)
|
||||||
return await llm_infomation(state)
|
return await llm_infomation(state)
|
||||||
|
|
||||||
llm_config = await get_llm_info()
|
llm_config = await get_llm_info()
|
||||||
api_key_obj = llm_config.api_keys[0]
|
api_key_obj = llm_config.api_keys[0]
|
||||||
api_key = api_key_obj.api_key
|
api_key = api_key_obj.api_key
|
||||||
@@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
)
|
)
|
||||||
|
|
||||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||||
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
|
||||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||||
agent = create_agent(
|
agent = create_agent(
|
||||||
llm,
|
llm,
|
||||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
tools=[time_retrieval_tool, hybrid_retrieval],
|
||||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
async with SEMAPHORE: # 限制并发
|
async with SEMAPHORE: # 限制并发
|
||||||
try:
|
try:
|
||||||
if storage_type == "rag" and user_rag_memory_id:
|
if storage_type == "rag" and user_rag_memory_id:
|
||||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
|
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||||
|
question)
|
||||||
else:
|
else:
|
||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||||
@@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
# json.dump(dup_databases, f, indent=4)
|
# json.dump(dup_databases, f, indent=4)
|
||||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||||
return {'retrieve': dup_databases}
|
return {'retrieve': dup_databases}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -18,22 +16,24 @@ from app.core.memory.agent.utils.redis_tool import store
|
|||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
|
from app.db import get_db_context
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
db_session = next(get_db())
|
|
||||||
|
|
||||||
class SummaryNodeService(LLMServiceMixin):
|
class SummaryNodeService(LLMServiceMixin):
|
||||||
"""总结节点服务类"""
|
"""总结节点服务类"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# 创建全局服务实例
|
||||||
summary_service = SummaryNodeService()
|
summary_service = SummaryNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def rag_config(state):
|
async def rag_config(state):
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
kb_config = {
|
kb_config = {
|
||||||
@@ -51,10 +51,12 @@ async def rag_config(state):
|
|||||||
"reranker_top_k": 10
|
"reranker_top_k": 10
|
||||||
}
|
}
|
||||||
return kb_config
|
return kb_config
|
||||||
async def rag_knowledge(state,question):
|
|
||||||
|
|
||||||
|
async def rag_knowledge(state, question):
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||||
try:
|
try:
|
||||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||||
@@ -62,25 +64,28 @@ async def rag_knowledge(state,question):
|
|||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
raw_results = clean_content
|
raw_results = clean_content
|
||||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||||
except Exception :
|
except Exception:
|
||||||
retrieval_knowledge=[]
|
retrieval_knowledge = []
|
||||||
clean_content = ''
|
clean_content = ''
|
||||||
raw_results = ''
|
raw_results = ''
|
||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||||
return retrieval_knowledge,clean_content,cleaned_query,raw_results
|
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||||
|
|
||||||
|
|
||||||
async def summary_history(state: ReadState) -> ReadState:
|
async def summary_history(state: ReadState) -> ReadState:
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
|
||||||
|
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||||
|
search_mode) -> str:
|
||||||
"""
|
"""
|
||||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||||
"""
|
"""
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
|
|
||||||
# 构建系统提示词
|
# 构建系统提示词
|
||||||
if str(search_mode) == "0":
|
if str(search_mode) == "0":
|
||||||
system_prompt = await summary_service.template_service.render_template(
|
system_prompt = await summary_service.template_service.render_template(
|
||||||
@@ -99,18 +104,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务进行结构化输出
|
# 使用优化的LLM服务进行结构化输出
|
||||||
structured = await summary_service.call_llm_structured(
|
with get_db_context() as db_session:
|
||||||
state=state,
|
structured = await summary_service.call_llm_structured(
|
||||||
db_session=db_session,
|
state=state,
|
||||||
system_prompt=system_prompt,
|
db_session=db_session,
|
||||||
response_model=response_model,
|
system_prompt=system_prompt,
|
||||||
fallback_value=None
|
response_model=response_model,
|
||||||
)
|
fallback_value=None
|
||||||
|
)
|
||||||
# 验证结构化响应
|
# 验证结构化响应
|
||||||
if structured is None:
|
if structured is None:
|
||||||
logger.warning("LLM返回None,使用默认回答")
|
logger.warning("LLM返回None,使用默认回答")
|
||||||
return "信息不足,无法回答"
|
return "信息不足,无法回答"
|
||||||
|
|
||||||
# 根据操作类型提取答案
|
# 根据操作类型提取答案
|
||||||
if operation_name == "summary":
|
if operation_name == "summary":
|
||||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||||
@@ -121,16 +127,16 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
else:
|
else:
|
||||||
logger.warning("结构化响应缺少data字段")
|
logger.warning("结构化响应缺少data字段")
|
||||||
aimessages = "信息不足,无法回答"
|
aimessages = "信息不足,无法回答"
|
||||||
|
|
||||||
# 验证答案不为空
|
# 验证答案不为空
|
||||||
if not aimessages or aimessages.strip() == "":
|
if not aimessages or aimessages.strip() == "":
|
||||||
aimessages = "信息不足,无法回答"
|
aimessages = "信息不足,无法回答"
|
||||||
|
|
||||||
return aimessages
|
return aimessages
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||||
|
|
||||||
# 尝试非结构化输出作为fallback
|
# 尝试非结构化输出作为fallback
|
||||||
try:
|
try:
|
||||||
logger.info("尝试非结构化输出作为fallback")
|
logger.info("尝试非结构化输出作为fallback")
|
||||||
@@ -140,7 +146,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
fallback_message="信息不足,无法回答"
|
fallback_message="信息不足,无法回答"
|
||||||
)
|
)
|
||||||
|
|
||||||
if response and response.strip():
|
if response and response.strip():
|
||||||
# 简单清理响应
|
# 简单清理响应
|
||||||
cleaned_response = response.strip()
|
cleaned_response = response.strip()
|
||||||
@@ -148,16 +154,17 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
if cleaned_response.startswith('```'):
|
if cleaned_response.startswith('```'):
|
||||||
lines = cleaned_response.split('\n')
|
lines = cleaned_response.split('\n')
|
||||||
cleaned_response = '\n'.join(lines[1:-1])
|
cleaned_response = '\n'.join(lines[1:-1])
|
||||||
|
|
||||||
return cleaned_response
|
return cleaned_response
|
||||||
else:
|
else:
|
||||||
return "信息不足,无法回答"
|
return "信息不足,无法回答"
|
||||||
|
|
||||||
except Exception as fallback_error:
|
except Exception as fallback_error:
|
||||||
logger.error(f"Fallback也失败: {fallback_error}")
|
logger.error(f"Fallback也失败: {fallback_error}")
|
||||||
return "信息不足,无法回答"
|
return "信息不足,无法回答"
|
||||||
|
|
||||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
|
||||||
|
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
await SessionService(store).save_session(
|
await SessionService(store).save_session(
|
||||||
@@ -169,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
|||||||
)
|
)
|
||||||
await SessionService(store).cleanup_duplicates()
|
await SessionService(store).cleanup_duplicates()
|
||||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||||
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
|
||||||
storage_type=state.get("storage_type",'')
|
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||||
data=state.get("data", '')
|
storage_type = state.get("storage_type", '')
|
||||||
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
|
data = state.get("data", '')
|
||||||
input_summary = {
|
input_summary = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"summary_result": aimessages,
|
"summary_result": aimessages,
|
||||||
@@ -189,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
|||||||
"user_rag_memory_id": user_rag_memory_id
|
"user_rag_memory_id": user_rag_memory_id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
retrieve={
|
retrieve = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"summary_result": aimessages,
|
"summary_result": aimessages,
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
"user_rag_memory_id": user_rag_memory_id,
|
"user_rag_memory_id": user_rag_memory_id,
|
||||||
"_intermediate": {
|
"_intermediate": {
|
||||||
"type": "retrieval_summary",
|
"type": "retrieval_summary",
|
||||||
"title":"快速检索",
|
"title": "快速检索",
|
||||||
"summary": aimessages,
|
"summary": aimessages,
|
||||||
"query": data,
|
"query": data,
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
@@ -204,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return input_summary,retrieve
|
return input_summary, retrieve
|
||||||
|
|
||||||
|
|
||||||
async def Input_Summary(state: ReadState) -> ReadState:
|
async def Input_Summary(state: ReadState) -> ReadState:
|
||||||
start=time.time()
|
start = time.time()
|
||||||
storage_type=state.get("storage_type",'')
|
storage_type = state.get("storage_type", '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
data=state.get("data", '')
|
data = state.get("data", '')
|
||||||
end_user_id=state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
history = await summary_history( state)
|
history = await summary_history(state)
|
||||||
search_params = {
|
search_params = {
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"question": data,
|
"question": data,
|
||||||
@@ -223,12 +233,13 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if storage_type!="rag":
|
if storage_type != "rag":
|
||||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
|
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
|
||||||
|
memory_config=memory_config)
|
||||||
else:
|
else:
|
||||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
|
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
|
||||||
retrieve_info, question, raw_results = "", data, []
|
retrieve_info, question, raw_results = "", data, []
|
||||||
try:
|
try:
|
||||||
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
||||||
@@ -237,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
|
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
|
||||||
summary = summary_result[0]
|
summary = summary_result[0]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error( f"Input_Summary failed: {e}", exc_info=True )
|
logger.error(f"Input_Summary failed: {e}", exc_info=True)
|
||||||
summary= {
|
summary = {
|
||||||
"status": "fail",
|
"status": "fail",
|
||||||
"summary_result": "信息不足,无法回答",
|
"summary_result": "信息不足,无法回答",
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
@@ -251,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
except Exception:
|
except Exception:
|
||||||
duration = 0.0
|
duration = 0.0
|
||||||
log_time('检索', duration)
|
log_time('检索', duration)
|
||||||
return {"summary":summary}
|
return {"summary": summary}
|
||||||
|
|
||||||
async def Retrieve_Summary(state: ReadState)-> ReadState:
|
|
||||||
retrieve=state.get("retrieve", '')
|
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||||
history = await summary_history( state)
|
retrieve = state.get("retrieve", '')
|
||||||
|
history = await summary_history(state)
|
||||||
import json
|
import json
|
||||||
with open("检索.json","w",encoding='utf-8') as f:
|
with open("检索.json", "w", encoding='utf-8') as f:
|
||||||
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
|
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
|
||||||
retrieve=retrieve.get("Expansion_issue", [])
|
retrieve = retrieve.get("Expansion_issue", [])
|
||||||
start=time.time()
|
start = time.time()
|
||||||
retrieve_info_str=[]
|
retrieve_info_str = []
|
||||||
for data in retrieve:
|
for data in retrieve:
|
||||||
if data=='':
|
if data == '':
|
||||||
retrieve_info_str=''
|
retrieve_info_str = ''
|
||||||
else:
|
else:
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key=='Answer_Small':
|
if key == 'Answer_Small':
|
||||||
for i in value:
|
for i in value:
|
||||||
retrieve_info_str.append(i)
|
retrieve_info_str.append(i)
|
||||||
retrieve_info_str=list(set(retrieve_info_str))
|
retrieve_info_str = list(set(retrieve_info_str))
|
||||||
retrieve_info_str='\n'.join(retrieve_info_str)
|
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||||
|
|
||||||
aimessages=await summary_llm(state,history,retrieve_info_str,
|
aimessages = await summary_llm(state, history, retrieve_info_str,
|
||||||
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
|
||||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||||
await summary_redis_save(state, aimessages)
|
await summary_redis_save(state, aimessages)
|
||||||
if aimessages == '':
|
if aimessages == '':
|
||||||
@@ -286,33 +298,33 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
|
|||||||
except Exception:
|
except Exception:
|
||||||
duration = 0.0
|
duration = 0.0
|
||||||
log_time('Retrieval summary', duration)
|
log_time('Retrieval summary', duration)
|
||||||
|
|
||||||
# 修复协程调用 - 先await,然后访问返回值
|
# 修复协程调用 - 先await,然后访问返回值
|
||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary":summary}
|
return {"summary": summary}
|
||||||
|
|
||||||
|
|
||||||
async def Summary(state: ReadState)-> ReadState:
|
async def Summary(state: ReadState) -> ReadState:
|
||||||
start=time.time()
|
start = time.time()
|
||||||
query = state.get("data", '')
|
query = state.get("data", '')
|
||||||
verify=state.get("verify", '')
|
verify = state.get("verify", '')
|
||||||
verify_expansion_issue=verify.get("verified_data", '')
|
verify_expansion_issue = verify.get("verified_data", '')
|
||||||
retrieve_info_str=''
|
retrieve_info_str = ''
|
||||||
for data in verify_expansion_issue:
|
for data in verify_expansion_issue:
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key=='answer_small':
|
if key == 'answer_small':
|
||||||
for i in value:
|
for i in value:
|
||||||
retrieve_info_str+=i+'\n'
|
retrieve_info_str += i + '\n'
|
||||||
history=await summary_history(state)
|
history = await summary_history(state)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
"retrieve_info": retrieve_info_str
|
"retrieve_info": retrieve_info_str
|
||||||
}
|
}
|
||||||
aimessages=await summary_llm(state,history,data,
|
aimessages = await summary_llm(state, history, data,
|
||||||
'summary_prompt.jinja2','summary',SummaryResponse,0)
|
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||||
|
|
||||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||||
await summary_redis_save(state, aimessages)
|
await summary_redis_save(state, aimessages)
|
||||||
@@ -327,10 +339,12 @@ async def Summary(state: ReadState)-> ReadState:
|
|||||||
# 修复协程调用 - 先await,然后访问返回值
|
# 修复协程调用 - 先await,然后访问返回值
|
||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary":summary}
|
return {"summary": summary}
|
||||||
async def Summary_fails(state: ReadState)-> ReadState:
|
|
||||||
storage_type=state.get("storage_type", '')
|
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
async def Summary_fails(state: ReadState) -> ReadState:
|
||||||
|
storage_type = state.get("storage_type", '')
|
||||||
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
history = await summary_history(state)
|
history = await summary_history(state)
|
||||||
query = state.get("data", '')
|
query = state.get("data", '')
|
||||||
verify = state.get("verify", '')
|
verify = state.get("verify", '')
|
||||||
@@ -346,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState:
|
|||||||
"history": history,
|
"history": history,
|
||||||
"retrieve_info": retrieve_info_str
|
"retrieve_info": retrieve_info_str
|
||||||
}
|
}
|
||||||
aimessages = await summary_llm(state, history, data,
|
aimessages = await summary_llm(state, history, data,
|
||||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||||
result= {
|
result = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"summary_result": aimessages,
|
"summary_result": aimessages,
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
"user_rag_memory_id": user_rag_memory_id
|
"user_rag_memory_id": user_rag_memory_id
|
||||||
}
|
}
|
||||||
return {"summary":result}
|
return {"summary": result}
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.models.verification_models import VerificationResult
|
from app.core.memory.agent.models.verification_models import VerificationResult
|
||||||
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
PROJECT_ROOT_,
|
PROJECT_ROOT_,
|
||||||
ReadState,
|
ReadState,
|
||||||
@@ -10,28 +11,30 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.db import get_db_context
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
db_session = next(get_db())
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class VerificationNodeService(LLMServiceMixin):
|
class VerificationNodeService(LLMServiceMixin):
|
||||||
"""验证节点服务类"""
|
"""验证节点服务类"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# 创建全局服务实例
|
||||||
verification_service = VerificationNodeService()
|
verification_service = VerificationNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||||
"""处理验证结果并生成输出格式"""
|
"""处理验证结果并生成输出格式"""
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
data = state.get('data', '')
|
data = state.get('data', '')
|
||||||
|
|
||||||
# 将 VerificationItem 对象转换为字典列表
|
# 将 VerificationItem 对象转换为字典列表
|
||||||
verified_data = []
|
verified_data = []
|
||||||
if messages_deal.expansion_issue:
|
if messages_deal.expansion_issue:
|
||||||
@@ -40,7 +43,7 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
|||||||
verified_data.append(item.model_dump())
|
verified_data.append(item.model_dump())
|
||||||
elif isinstance(item, dict):
|
elif isinstance(item, dict):
|
||||||
verified_data.append(item)
|
verified_data.append(item)
|
||||||
|
|
||||||
Verify_result = {
|
Verify_result = {
|
||||||
"status": messages_deal.split_result,
|
"status": messages_deal.split_result,
|
||||||
"verified_data": verified_data,
|
"verified_data": verified_data,
|
||||||
@@ -58,34 +61,37 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Verify_result
|
return Verify_result
|
||||||
|
|
||||||
|
|
||||||
async def Verify(state: ReadState):
|
async def Verify(state: ReadState):
|
||||||
logger.info("=== Verify 节点开始执行 ===")
|
logger.info("=== Verify 节点开始执行 ===")
|
||||||
try:
|
try:
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
|
|
||||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
|
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
|
||||||
|
|
||||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||||
|
|
||||||
retrieve = state.get("retrieve", {})
|
retrieve = state.get("retrieve", {})
|
||||||
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
logger.info(
|
||||||
|
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||||
|
|
||||||
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
||||||
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
||||||
|
|
||||||
messages = {
|
messages = {
|
||||||
"Query": content,
|
"Query": content,
|
||||||
"Expansion_issue": retrieve_expansion
|
"Expansion_issue": retrieve_expansion
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info("Verify: 开始渲染模板")
|
logger.info("Verify: 开始渲染模板")
|
||||||
|
|
||||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||||
json_schema = VerificationResult.model_json_schema()
|
json_schema = VerificationResult.model_json_schema()
|
||||||
|
|
||||||
system_prompt = await verification_service.template_service.render_template(
|
system_prompt = await verification_service.template_service.render_template(
|
||||||
template_name='split_verify_prompt.jinja2',
|
template_name='split_verify_prompt.jinja2',
|
||||||
operation_name='split_verify_prompt',
|
operation_name='split_verify_prompt',
|
||||||
@@ -94,29 +100,30 @@ async def Verify(state: ReadState):
|
|||||||
json_schema=json_schema
|
json_schema=json_schema
|
||||||
)
|
)
|
||||||
logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}")
|
logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}")
|
||||||
|
|
||||||
# 使用优化的LLM服务,添加超时保护
|
# 使用优化的LLM服务,添加超时保护
|
||||||
logger.info("Verify: 开始调用 LLM")
|
logger.info("Verify: 开始调用 LLM")
|
||||||
try:
|
try:
|
||||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||||
import asyncio
|
|
||||||
structured = await asyncio.wait_for(
|
with get_db_context() as db_session:
|
||||||
verification_service.call_llm_structured(
|
structured = await asyncio.wait_for(
|
||||||
state=state,
|
verification_service.call_llm_structured(
|
||||||
db_session=db_session,
|
state=state,
|
||||||
system_prompt=system_prompt,
|
db_session=db_session,
|
||||||
response_model=VerificationResult,
|
system_prompt=system_prompt,
|
||||||
fallback_value={
|
response_model=VerificationResult,
|
||||||
"query": content,
|
fallback_value={
|
||||||
"history": history if isinstance(history, list) else [],
|
"query": content,
|
||||||
"expansion_issue": [],
|
"history": history if isinstance(history, list) else [],
|
||||||
"split_result": "failed",
|
"expansion_issue": [],
|
||||||
"reason": "验证失败或超时"
|
"split_result": "failed",
|
||||||
}
|
"reason": "验证失败或超时"
|
||||||
),
|
}
|
||||||
timeout=150.0 # 150秒超时
|
),
|
||||||
)
|
timeout=150.0 # 150秒超时
|
||||||
|
)
|
||||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
||||||
@@ -127,11 +134,11 @@ async def Verify(state: ReadState):
|
|||||||
split_result="failed",
|
split_result="failed",
|
||||||
reason="LLM调用超时"
|
reason="LLM调用超时"
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await Verify_prompt(state, structured)
|
result = await Verify_prompt(state, structured)
|
||||||
logger.info("=== Verify 节点执行完成 ===")
|
logger.info("=== Verify 节点执行完成 ===")
|
||||||
return {"verify": result}
|
return {"verify": result}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
|
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
|
||||||
# 返回失败的验证结果
|
# 返回失败的验证结果
|
||||||
@@ -152,4 +159,4 @@ async def Verify(state: ReadState):
|
|||||||
"user_rag_memory_id": state.get('user_rag_memory_id', '')
|
"user_rag_memory_id": state.get('user_rag_memory_id', '')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||||
from app.core.memory.agent.utils.write_tools import write
|
from app.core.memory.agent.utils.write_tools import write
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
@@ -40,6 +41,15 @@ async def write_node(state: WriteState) -> WriteState:
|
|||||||
)
|
)
|
||||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||||
|
|
||||||
|
# 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成
|
||||||
|
for lang in ["zh", "en"]:
|
||||||
|
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
language=lang,
|
||||||
|
)
|
||||||
|
if deleted:
|
||||||
|
logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||||
|
|
||||||
write_result = {
|
write_result = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"data": structured_messages,
|
"data": structured_messages,
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
|
|||||||
from langgraph.constants import START, END
|
from langgraph.constants import START, END
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
@@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def make_read_graph():
|
async def make_read_graph():
|
||||||
"""创建并返回 LangGraph 工作流"""
|
"""创建并返回 LangGraph 工作流"""
|
||||||
@@ -49,7 +47,7 @@ async def make_read_graph():
|
|||||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||||
workflow.add_node("Summary", Summary)
|
workflow.add_node("Summary", Summary)
|
||||||
workflow.add_node("Summary_fails", Summary_fails)
|
workflow.add_node("Summary_fails", Summary_fails)
|
||||||
|
|
||||||
# 添加边
|
# 添加边
|
||||||
workflow.add_edge(START, "content_input")
|
workflow.add_edge(START, "content_input")
|
||||||
workflow.add_conditional_edges("content_input", Split_continue)
|
workflow.add_conditional_edges("content_input", Split_continue)
|
||||||
@@ -62,20 +60,20 @@ async def make_read_graph():
|
|||||||
workflow.add_edge("Summary_fails", END)
|
workflow.add_edge("Summary_fails", END)
|
||||||
workflow.add_edge("Summary", END)
|
workflow.add_edge("Summary", END)
|
||||||
|
|
||||||
|
|
||||||
'''-----'''
|
'''-----'''
|
||||||
# workflow.add_edge("Retrieve", END)
|
# workflow.add_edge("Retrieve", END)
|
||||||
|
|
||||||
# 编译工作流
|
# 编译工作流
|
||||||
graph = workflow.compile()
|
graph = workflow.compile()
|
||||||
yield graph
|
yield graph
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"创建工作流失败: {e}")
|
print(f"创建工作流失败: {e}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
print("工作流创建完成")
|
print("工作流创建完成")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""主函数 - 运行工作流"""
|
"""主函数 - 运行工作流"""
|
||||||
message = "昨天有什么好看的电影"
|
message = "昨天有什么好看的电影"
|
||||||
@@ -92,17 +90,19 @@ async def main():
|
|||||||
service_name="MemoryAgentService"
|
service_name="MemoryAgentService"
|
||||||
)
|
)
|
||||||
import time
|
import time
|
||||||
start=time.time()
|
start = time.time()
|
||||||
try:
|
try:
|
||||||
async with make_read_graph() as graph:
|
async with make_read_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
"end_user_id": end_user_id
|
||||||
|
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||||
|
"memory_config": memory_config}
|
||||||
# 获取节点更新信息
|
# 获取节点更新信息
|
||||||
_intermediate_outputs = []
|
_intermediate_outputs = []
|
||||||
summary = ''
|
summary = ''
|
||||||
|
|
||||||
async for update_event in graph.astream(
|
async for update_event in graph.astream(
|
||||||
initial_state,
|
initial_state,
|
||||||
stream_mode="updates",
|
stream_mode="updates",
|
||||||
@@ -110,7 +110,7 @@ async def main():
|
|||||||
):
|
):
|
||||||
for node_name, node_data in update_event.items():
|
for node_name, node_data in update_event.items():
|
||||||
print(f"处理节点: {node_name}")
|
print(f"处理节点: {node_name}")
|
||||||
|
|
||||||
# 处理不同Summary节点的返回结构
|
# 处理不同Summary节点的返回结构
|
||||||
if 'Summary' in node_name:
|
if 'Summary' in node_name:
|
||||||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||||||
@@ -125,23 +125,22 @@ async def main():
|
|||||||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||||||
if spit_data and spit_data != [] and spit_data != {}:
|
if spit_data and spit_data != [] and spit_data != {}:
|
||||||
_intermediate_outputs.append(spit_data)
|
_intermediate_outputs.append(spit_data)
|
||||||
|
|
||||||
# Problem_Extension 节点
|
# Problem_Extension 节点
|
||||||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||||||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||||||
_intermediate_outputs.append(problem_extension)
|
_intermediate_outputs.append(problem_extension)
|
||||||
|
|
||||||
# Retrieve 节点
|
# Retrieve 节点
|
||||||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||||
_intermediate_outputs.extend(retrieve_node)
|
_intermediate_outputs.extend(retrieve_node)
|
||||||
|
|
||||||
# Verify 节点
|
# Verify 节点
|
||||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||||
if verify_n and verify_n != [] and verify_n != {}:
|
if verify_n and verify_n != [] and verify_n != {}:
|
||||||
_intermediate_outputs.append(verify_n)
|
_intermediate_outputs.append(verify_n)
|
||||||
|
|
||||||
|
|
||||||
# Summary 节点
|
# Summary 节点
|
||||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||||
if summary_n and summary_n != [] and summary_n != {}:
|
if summary_n and summary_n != [] and summary_n != {}:
|
||||||
@@ -161,17 +160,20 @@ async def main():
|
|||||||
#
|
#
|
||||||
print(f"=== 最终摘要 ===")
|
print(f"=== 最终摘要 ===")
|
||||||
print(summary)
|
print(summary)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
db_session.close()
|
||||||
|
|
||||||
end=time.time()
|
end = time.time()
|
||||||
print(100*'y')
|
print(100 * 'y')
|
||||||
print(f"总耗时: {end-start}s")
|
print(f"总耗时: {end - start}s")
|
||||||
print(100*'y')
|
print(100 * 'y')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -82,7 +82,9 @@ async def get_chunked_dialogs(
|
|||||||
pruning_config = PruningConfig(
|
pruning_config = PruningConfig(
|
||||||
pruning_switch=memory_config.pruning_enabled,
|
pruning_switch=memory_config.pruning_enabled,
|
||||||
pruning_scene=memory_config.pruning_scene or "education",
|
pruning_scene=memory_config.pruning_scene or "education",
|
||||||
pruning_threshold=memory_config.pruning_threshold
|
pruning_threshold=memory_config.pruning_threshold,
|
||||||
|
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||||
|
ontology_classes=memory_config.ontology_classes,
|
||||||
)
|
)
|
||||||
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,56 +0,0 @@
|
|||||||
|
|
||||||
import asyncio
|
|
||||||
from typing import Dict, Optional
|
|
||||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
|
|
||||||
from app.db import get_db
|
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
|
||||||
|
|
||||||
class LLMClientPool:
|
|
||||||
"""LLM客户端连接池"""
|
|
||||||
|
|
||||||
def __init__(self, max_size: int = 5):
|
|
||||||
self.max_size = max_size
|
|
||||||
self.pools: Dict[str, asyncio.Queue] = {}
|
|
||||||
self.active_clients: Dict[str, int] = {}
|
|
||||||
|
|
||||||
async def get_client(self, llm_model_id: str):
|
|
||||||
"""获取LLM客户端"""
|
|
||||||
if llm_model_id not in self.pools:
|
|
||||||
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
|
|
||||||
self.active_clients[llm_model_id] = 0
|
|
||||||
|
|
||||||
pool = self.pools[llm_model_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 尝试从池中获取客户端
|
|
||||||
client = pool.get_nowait()
|
|
||||||
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
|
|
||||||
return client
|
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
# 池为空,创建新客户端
|
|
||||||
if self.active_clients[llm_model_id] < self.max_size:
|
|
||||||
db_session = next(get_db())
|
|
||||||
client = get_llm_client_fast(llm_model_id, db_session)
|
|
||||||
self.active_clients[llm_model_id] += 1
|
|
||||||
logger.debug(f"创建新LLM客户端: {llm_model_id}")
|
|
||||||
return client
|
|
||||||
else:
|
|
||||||
# 等待可用客户端
|
|
||||||
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
|
|
||||||
return await pool.get()
|
|
||||||
|
|
||||||
async def return_client(self, llm_model_id: str, client):
|
|
||||||
"""归还LLM客户端到池中"""
|
|
||||||
if llm_model_id in self.pools:
|
|
||||||
try:
|
|
||||||
self.pools[llm_model_id].put_nowait(client)
|
|
||||||
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
# 池已满,丢弃客户端
|
|
||||||
self.active_clients[llm_model_id] -= 1
|
|
||||||
logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}")
|
|
||||||
|
|
||||||
# 全局客户端池
|
|
||||||
llm_client_pool = LLMClientPool()
|
|
||||||
@@ -225,5 +225,24 @@ async def write(
|
|||||||
with open(log_file, "a", encoding="utf-8") as f:
|
with open(log_file, "a", encoding="utf-8") as f:
|
||||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||||
|
|
||||||
|
# 将提取统计写入 Redis,按 workspace_id 存储
|
||||||
|
try:
|
||||||
|
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||||||
|
|
||||||
|
stats_to_cache = {
|
||||||
|
"chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0,
|
||||||
|
"statements_count": len(all_statement_nodes) if all_statement_nodes else 0,
|
||||||
|
"triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0,
|
||||||
|
"triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0,
|
||||||
|
"temporal_count": 0,
|
||||||
|
}
|
||||||
|
await ActivityStatsCache.set_activity_stats(
|
||||||
|
workspace_id=str(memory_config.workspace_id),
|
||||||
|
stats=stats_to_cache,
|
||||||
|
)
|
||||||
|
logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
||||||
|
except Exception as cache_err:
|
||||||
|
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||||
|
|
||||||
logger.info("=== Pipeline Complete ===")
|
logger.info("=== Pipeline Complete ===")
|
||||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||||
@@ -10,7 +10,7 @@ Classes:
|
|||||||
TemporalSearchParams: Parameters for temporal search queries
|
TemporalSearchParams: Parameters for temporal search queries
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@@ -55,17 +55,26 @@ class PruningConfig(BaseModel):
|
|||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
pruning_switch: Enable or disable semantic pruning
|
pruning_switch: Enable or disable semantic pruning
|
||||||
pruning_scene: Scene type for pruning ('education', 'online_service', 'outbound')
|
pruning_scene: Scene name for pruning, either a built-in key
|
||||||
|
('education', 'online_service', 'outbound') or a custom scene_name
|
||||||
|
from ontology_scene table
|
||||||
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
|
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
|
||||||
|
scene_id: Optional ontology scene UUID, used to load custom ontology classes
|
||||||
|
ontology_classes: List of class_name strings from ontology_class table,
|
||||||
|
injected into the prompt when pruning_scene is not a built-in scene
|
||||||
"""
|
"""
|
||||||
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
|
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
|
||||||
pruning_scene: str = Field(
|
pruning_scene: str = Field(
|
||||||
"education",
|
"education",
|
||||||
description="Scene for pruning: one of 'education', 'online_service', 'outbound'.",
|
description="Scene for pruning: built-in key or custom scene_name from ontology_scene.",
|
||||||
)
|
)
|
||||||
pruning_threshold: float = Field(
|
pruning_threshold: float = Field(
|
||||||
0.5, ge=0.0, le=0.9,
|
0.5, ge=0.0, le=0.9,
|
||||||
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
|
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
|
||||||
|
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
|
||||||
|
ontology_classes: Optional[List[str]] = Field(
|
||||||
|
None, description="Class names from ontology_class table for custom scenes."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TemporalSearchParams(BaseModel):
|
class TemporalSearchParams(BaseModel):
|
||||||
|
|||||||
@@ -86,19 +86,26 @@ class SemanticPruner:
|
|||||||
self._detailed_prune_logging = True # 是否启用详细日志
|
self._detailed_prune_logging = True # 是否启用详细日志
|
||||||
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
|
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
|
||||||
|
|
||||||
# 加载场景特定配置
|
# 加载场景特定配置(内置场景走专门规则,自定义场景 fallback 到通用规则)
|
||||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
|
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
|
||||||
self.config.pruning_scene,
|
self.config.pruning_scene,
|
||||||
fallback_to_generic=True
|
fallback_to_generic=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查场景是否有专门支持
|
# 判断是否为内置专门场景
|
||||||
is_supported = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
|
self._is_builtin_scene = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
|
||||||
if is_supported:
|
|
||||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用专门配置")
|
# 自定义场景的本体类型列表(用于注入提示词)
|
||||||
|
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
|
||||||
|
|
||||||
|
if self._is_builtin_scene:
|
||||||
|
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用内置专门配置")
|
||||||
else:
|
else:
|
||||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 未预定义,使用通用配置(保守策略)")
|
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 为自定义场景,使用通用规则 + 本体类型提示词注入")
|
||||||
self._log(f"[剪枝-初始化] 支持的场景: {SceneConfigRegistry.get_all_scenes()}")
|
if self._ontology_classes:
|
||||||
|
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
|
||||||
|
else:
|
||||||
|
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
||||||
|
|
||||||
# Load Jinja2 template
|
# Load Jinja2 template
|
||||||
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
||||||
@@ -424,12 +431,16 @@ class SemanticPruner:
|
|||||||
self._log(f"[剪枝-缓存] LRU缓存已满,删除最旧条目")
|
self._log(f"[剪枝-缓存] LRU缓存已满,删除最旧条目")
|
||||||
|
|
||||||
rendered = self.template.render(
|
rendered = self.template.render(
|
||||||
pruning_scene=self.config.pruning_scene,
|
pruning_scene=self.config.pruning_scene,
|
||||||
|
is_builtin_scene=self._is_builtin_scene,
|
||||||
|
ontology_classes=self._ontology_classes,
|
||||||
dialog_text=dialog_text,
|
dialog_text=dialog_text,
|
||||||
language=self.language
|
language=self.language
|
||||||
)
|
)
|
||||||
log_template_rendering("extracat_Pruning.jinja2", {
|
log_template_rendering("extracat_Pruning.jinja2", {
|
||||||
"pruning_scene": self.config.pruning_scene,
|
"pruning_scene": self.config.pruning_scene,
|
||||||
|
"is_builtin_scene": self._is_builtin_scene,
|
||||||
|
"ontology_classes_count": len(self._ontology_classes),
|
||||||
"language": self.language
|
"language": self.language
|
||||||
})
|
})
|
||||||
log_prompt_rendering("pruning-extract", rendered)
|
log_prompt_rendering("pruning-extract", rendered)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{#
|
{#
|
||||||
对话级抽取与相关性判定模板(用于剪枝加速)
|
对话级抽取与相关性判定模板(用于剪枝加速)
|
||||||
输入:pruning_scene, dialog_text
|
输入:pruning_scene, is_builtin_scene, ontology_classes, dialog_text, language
|
||||||
输出:严格 JSON(不要包含任何多余文本),字段:
|
输出:严格 JSON(不要包含任何多余文本),字段:
|
||||||
- is_related: bool,是否与所选场景相关
|
- is_related: bool,是否与所选场景相关
|
||||||
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
||||||
@@ -16,7 +16,8 @@
|
|||||||
- 仅输出上述键;避免多余解释或字段。
|
- 仅输出上述键;避免多余解释或字段。
|
||||||
#}
|
#}
|
||||||
|
|
||||||
{% set scene_instructions = {
|
{# ── 内置场景的固定说明 ── #}
|
||||||
|
{% set builtin_scene_instructions = {
|
||||||
'education': {
|
'education': {
|
||||||
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
||||||
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
|
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
|
||||||
@@ -31,16 +32,40 @@
|
|||||||
}
|
}
|
||||||
} %}
|
} %}
|
||||||
|
|
||||||
{% set scene_key = pruning_scene %}
|
{# ── 确定最终使用的场景说明 ── #}
|
||||||
{% if scene_key not in scene_instructions %}
|
{% if is_builtin_scene %}
|
||||||
{% set scene_key = 'education' %}
|
{# 内置专门场景:使用固定说明 #}
|
||||||
|
{% set scene_key = pruning_scene %}
|
||||||
|
{% if scene_key not in builtin_scene_instructions %}{% set scene_key = 'education' %}{% endif %}
|
||||||
|
{% set instruction = builtin_scene_instructions[scene_key][language] if language in ['zh', 'en'] else builtin_scene_instructions[scene_key]['zh'] %}
|
||||||
|
{% set custom_types_str = '' %}
|
||||||
|
{% else %}
|
||||||
|
{# 自定义场景:使用场景名称 + 本体类型列表构建说明 #}
|
||||||
|
{% if ontology_classes and ontology_classes | length > 0 %}
|
||||||
|
{% if language == 'en' %}
|
||||||
|
{% set custom_types_str = ontology_classes | join(', ') %}
|
||||||
|
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
|
||||||
|
{% else %}
|
||||||
|
{% set custom_types_str = ontology_classes | join('、') %}
|
||||||
|
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
|
||||||
|
{% endif %}
|
||||||
|
{% else %}
|
||||||
|
{# 无本体类型时退化为通用说明 #}
|
||||||
|
{% if language == 'en' %}
|
||||||
|
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||||
|
{% else %}
|
||||||
|
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||||
|
{% endif %}
|
||||||
|
{% set custom_types_str = '' %}
|
||||||
|
{% endif %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% set instruction = scene_instructions[scene_key][language] if language in ['zh', 'en'] else scene_instructions[scene_key]['zh'] %}
|
|
||||||
|
|
||||||
{% if language == "zh" %}
|
{% if language == "zh" %}
|
||||||
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
|
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
|
||||||
场景说明:{{ instruction }}
|
场景说明:{{ instruction }}
|
||||||
|
{% if not is_builtin_scene and custom_types_str %}
|
||||||
|
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }})相关的内容,即判定为相关(is_related=true)。
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
对话全文:
|
对话全文:
|
||||||
"""
|
"""
|
||||||
@@ -60,6 +85,9 @@
|
|||||||
{% else %}
|
{% else %}
|
||||||
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
|
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
|
||||||
Scenario Description: {{ instruction }}
|
Scenario Description: {{ instruction }}
|
||||||
|
{% if not is_builtin_scene and custom_types_str %}
|
||||||
|
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
Full Dialogue:
|
Full Dialogue:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -8,34 +8,60 @@ from typing import Any
|
|||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||||
from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \
|
from app.core.workflow.adapters.errors import (
|
||||||
|
UnsupportVariableType,
|
||||||
|
UnknowModelWarning,
|
||||||
|
ExceptionDefineition,
|
||||||
ExceptionType
|
ExceptionType
|
||||||
from app.core.workflow.nodes.assigner import AssignerNodeConfig
|
)
|
||||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||||
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
|
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
|
||||||
from app.core.workflow.nodes.code import CodeNodeConfig
|
|
||||||
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
||||||
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
|
from app.core.workflow.nodes.configs import (
|
||||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
StartNodeConfig,
|
||||||
from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \
|
LLMNodeConfig,
|
||||||
|
AssignerNodeConfig,
|
||||||
|
CodeNodeConfig,
|
||||||
|
LoopNodeConfig,
|
||||||
|
IterationNodeConfig,
|
||||||
|
EndNodeConfig,
|
||||||
|
HttpRequestNodeConfig,
|
||||||
|
IfElseNodeConfig,
|
||||||
|
JinjaRenderNodeConfig,
|
||||||
|
KnowledgeRetrievalNodeConfig,
|
||||||
|
NoteNodeConfig,
|
||||||
|
ParameterExtractorNodeConfig,
|
||||||
|
QuestionClassifierNodeConfig,
|
||||||
|
VariableAggregatorNodeConfig
|
||||||
|
)
|
||||||
|
from app.core.workflow.nodes.cycle_graph.config import (
|
||||||
|
ConditionDetail as LoopConditionDetail,
|
||||||
|
ConditionsConfig,
|
||||||
CycleVariable
|
CycleVariable
|
||||||
from app.core.workflow.nodes.end import EndNodeConfig
|
)
|
||||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \
|
from app.core.workflow.nodes.enums import (
|
||||||
HttpContentType, HttpErrorHandle
|
ValueInputType,
|
||||||
from app.core.workflow.nodes.http_request import HttpRequestNodeConfig
|
ComparisonOperator,
|
||||||
from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \
|
AssignmentOperator,
|
||||||
HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig
|
HttpAuthType,
|
||||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
HttpContentType,
|
||||||
|
HttpErrorHandle,
|
||||||
|
NodeType
|
||||||
|
)
|
||||||
|
from app.core.workflow.nodes.http_request.config import (
|
||||||
|
HttpAuthConfig,
|
||||||
|
HttpContentTypeConfig,
|
||||||
|
HttpFormData,
|
||||||
|
HttpTimeOutConfig,
|
||||||
|
HttpRetryConfig,
|
||||||
|
HttpErrorDefaultTamplete,
|
||||||
|
HttpErrorHandleConfig
|
||||||
|
)
|
||||||
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
||||||
from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig
|
|
||||||
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
|
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
|
||||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
|
||||||
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
|
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
|
||||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig
|
|
||||||
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
|
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
|
||||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig
|
|
||||||
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
|
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
|
||||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig
|
|
||||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||||
|
|
||||||
|
|
||||||
@@ -48,24 +74,24 @@ class DifyConverter(BaseConverter):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.CONFIG_CONVERT_MAP = {
|
self.CONFIG_CONVERT_MAP = {
|
||||||
"start": self.convert_start_node_config,
|
NodeType.START: self.convert_start_node_config,
|
||||||
"llm": self.convert_llm_node_config,
|
NodeType.LLM: self.convert_llm_node_config,
|
||||||
"answer": self.convert_end_node_config,
|
NodeType.END: self.convert_end_node_config,
|
||||||
"if-else": self.convert_if_else_node_config,
|
NodeType.IF_ELSE: self.convert_if_else_node_config,
|
||||||
"loop": self.convert_loop_node_config,
|
NodeType.LOOP: self.convert_loop_node_config,
|
||||||
"iteration": self.convert_iteration_node_config,
|
NodeType.ITERATION: self.convert_iteration_node_config,
|
||||||
"assigner": self.convert_assigner_node_config,
|
NodeType.ASSIGNER: self.convert_assigner_node_config,
|
||||||
"code": self.convert_code_node_config,
|
NodeType.CODE: self.convert_code_node_config,
|
||||||
"http-request": self.convert_http_node_config,
|
NodeType.HTTP_REQUEST: self.convert_http_node_config,
|
||||||
"template-transform": self.convert_jinja_render_node_config,
|
NodeType.JINJARENDER: self.convert_jinja_render_node_config,
|
||||||
"knowledge-retrieval": self.convert_knowledge_node_config,
|
NodeType.KNOWLEDGE_RETRIEVAL: self.convert_knowledge_node_config,
|
||||||
"parameter-extractor": self.convert_parameter_extractor_node_config,
|
NodeType.PARAMETER_EXTRACTOR: self.convert_parameter_extractor_node_config,
|
||||||
"question-classifier": self.convert_question_classifier_node_config,
|
NodeType.QUESTION_CLASSIFIER: self.convert_question_classifier_node_config,
|
||||||
"variable-aggregator": self.convert_variable_aggregator_node_config,
|
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
|
||||||
"tool": self.convert_tool_node_config,
|
NodeType.TOOL: self.convert_tool_node_config,
|
||||||
"loop-start": lambda x: {},
|
NodeType.NOTES: self.convert_notes_config,
|
||||||
"iteration-start": lambda x: {},
|
NodeType.CYCLE_START: lambda x: {},
|
||||||
"loop-end": lambda x: {},
|
NodeType.BREAK: lambda x: {},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_node_convert(self, node_type):
|
def get_node_convert(self, node_type):
|
||||||
@@ -129,11 +155,11 @@ class DifyConverter(BaseConverter):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_file(var):
|
def _convert_file(var):
|
||||||
pass
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_array_file(var):
|
def _convert_array_file(var):
|
||||||
pass
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def variable_type_map(source_type) -> VariableType | None:
|
def variable_type_map(source_type) -> VariableType | None:
|
||||||
@@ -185,6 +211,9 @@ class DifyConverter(BaseConverter):
|
|||||||
"not empty": ComparisonOperator.NOT_EMPTY,
|
"not empty": ComparisonOperator.NOT_EMPTY,
|
||||||
"start with": ComparisonOperator.START_WITH,
|
"start with": ComparisonOperator.START_WITH,
|
||||||
"end with": ComparisonOperator.END_WITH,
|
"end with": ComparisonOperator.END_WITH,
|
||||||
|
"not contains": ComparisonOperator.NOT_CONTAINS,
|
||||||
|
"exists": ComparisonOperator.NOT_EMPTY,
|
||||||
|
"not exists": ComparisonOperator.EMPTY
|
||||||
}
|
}
|
||||||
return operator_map.get(operator, operator)
|
return operator_map.get(operator, operator)
|
||||||
|
|
||||||
@@ -198,7 +227,7 @@ class DifyConverter(BaseConverter):
|
|||||||
"over-write": AssignmentOperator.COVER,
|
"over-write": AssignmentOperator.COVER,
|
||||||
"remove-last": AssignmentOperator.REMOVE_LAST,
|
"remove-last": AssignmentOperator.REMOVE_LAST,
|
||||||
"remove-first": AssignmentOperator.REMOVE_FIRST,
|
"remove-first": AssignmentOperator.REMOVE_FIRST,
|
||||||
|
"set": AssignmentOperator.ASSIGN,
|
||||||
}
|
}
|
||||||
return operator_map.get(operator, operator)
|
return operator_map.get(operator, operator)
|
||||||
|
|
||||||
@@ -267,10 +296,10 @@ class DifyConverter(BaseConverter):
|
|||||||
type=var_type,
|
type=var_type,
|
||||||
required=var["required"],
|
required=var["required"],
|
||||||
default=self.convert_variable_type(
|
default=self.convert_variable_type(
|
||||||
var_type, var["default"]
|
var_type, var.get("default")
|
||||||
),
|
),
|
||||||
description=var["label"],
|
description=var["label"],
|
||||||
max_length=var.get("max_length"),
|
max_length=var.get("max_length", 50),
|
||||||
)
|
)
|
||||||
start_vars.append(var_def)
|
start_vars.append(var_def)
|
||||||
result = StartNodeConfig.model_construct(
|
result = StartNodeConfig.model_construct(
|
||||||
@@ -333,7 +362,7 @@ class DifyConverter(BaseConverter):
|
|||||||
MessageConfig(
|
MessageConfig(
|
||||||
role="user",
|
role="user",
|
||||||
content=self.trans_variable_format(
|
content=self.trans_variable_format(
|
||||||
node_data["memory"].get("query_prompt_template", "{{#sys.query#}}")
|
node_data["memory"].get("query_prompt_template") or "{{#sys.query#}}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -364,7 +393,7 @@ class DifyConverter(BaseConverter):
|
|||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
cases = []
|
cases = []
|
||||||
for case in node_data["cases"]:
|
for case in node_data["cases"]:
|
||||||
case_id = case["id"]
|
case_id = case.get("id") or case.get("case_id")
|
||||||
logical_operator = case["logical_operator"]
|
logical_operator = case["logical_operator"]
|
||||||
conditions = []
|
conditions = []
|
||||||
for condition in case["conditions"]:
|
for condition in case["conditions"]:
|
||||||
@@ -540,7 +569,8 @@ class DifyConverter(BaseConverter):
|
|||||||
] = self.trans_variable_format(content["value"])
|
] = self.trans_variable_format(content["value"])
|
||||||
else:
|
else:
|
||||||
if node_data["body"]["data"]:
|
if node_data["body"]["data"]:
|
||||||
body_content = node_data["body"]["data"][0]["value"]
|
body_content = (node_data["body"]["data"][0].get("value") or
|
||||||
|
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file")))
|
||||||
else:
|
else:
|
||||||
body_content = ""
|
body_content = ""
|
||||||
|
|
||||||
@@ -612,7 +642,7 @@ class DifyConverter(BaseConverter):
|
|||||||
),
|
),
|
||||||
headers=headers,
|
headers=headers,
|
||||||
params=params,
|
params=params,
|
||||||
verify_ssl=node_data["ssl_verify"],
|
verify_ssl=node_data.get("ssl_verify", False),
|
||||||
timeouts=HttpTimeOutConfig.model_construct(
|
timeouts=HttpTimeOutConfig.model_construct(
|
||||||
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
|
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
|
||||||
read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
|
read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
|
||||||
@@ -696,7 +726,7 @@ class DifyConverter(BaseConverter):
|
|||||||
group_variables = {}
|
group_variables = {}
|
||||||
group_type = {}
|
group_type = {}
|
||||||
if not advanced_settings or not advanced_settings["group_enabled"]:
|
if not advanced_settings or not advanced_settings["group_enabled"]:
|
||||||
group_variables["output"] = [
|
group_variables = [
|
||||||
self._process_list_variable_litearl(variable)
|
self._process_list_variable_litearl(variable)
|
||||||
for variable in node_data["variables"]
|
for variable in node_data["variables"]
|
||||||
]
|
]
|
||||||
@@ -728,3 +758,16 @@ class DifyConverter(BaseConverter):
|
|||||||
detail=f"Please reconfigure the tool node.",
|
detail=f"Please reconfigure the tool node.",
|
||||||
))
|
))
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_notes_config(node: dict):
|
||||||
|
node_data = node["data"]
|
||||||
|
result = NoteNodeConfig.model_construct(
|
||||||
|
author=node_data.get("author", ""),
|
||||||
|
text=node_data.get("text", ""),
|
||||||
|
width=node_data.get("width", 80),
|
||||||
|
height=node_data.get("height", 80),
|
||||||
|
theme=node_data.get("theme", "blue"),
|
||||||
|
show_author=node_data.get("showAuthor", True)
|
||||||
|
).model_dump()
|
||||||
|
return result
|
||||||
|
|||||||
@@ -44,12 +44,13 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
|
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
|
||||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||||
"variable-aggregator": NodeType.VAR_AGGREGATOR,
|
"variable-aggregator": NodeType.VAR_AGGREGATOR,
|
||||||
"tool": NodeType.TOOL
|
"tool": NodeType.TOOL,
|
||||||
|
"": NodeType.NOTES
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, config: dict[str, Any]):
|
def __init__(self, config: dict[str, Any]):
|
||||||
DifyConverter.__init__(self)
|
DifyConverter.__init__(self)
|
||||||
BasePlatformAdapter.__init__(self, config)
|
BasePlatformAdapter.__init__(self, config)
|
||||||
|
|
||||||
def get_metadata(self) -> PlatformMetadata:
|
def get_metadata(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
@@ -58,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||||
)
|
)
|
||||||
|
|
||||||
def map_node_type(self, platform_node_type) -> str:
|
def map_node_type(self, platform_node_type) -> NodeType:
|
||||||
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -83,6 +84,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
|
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
|
||||||
if not all(field in self.config for field in require_fields):
|
if not all(field in self.config for field in require_fields):
|
||||||
return False
|
return False
|
||||||
|
if self.config.get("app", {}).get("mode") == "workflow":
|
||||||
|
self.errors.append(ExceptionDefineition(
|
||||||
|
type=ExceptionType.PLATFORM,
|
||||||
|
detail="workflow mode is not supported"
|
||||||
|
))
|
||||||
|
return False
|
||||||
|
|
||||||
for node in self.origin_nodes:
|
for node in self.origin_nodes:
|
||||||
if not self._valid_nodes(node):
|
if not self._valid_nodes(node):
|
||||||
@@ -134,6 +141,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
for node in self.origin_nodes:
|
for node in self.origin_nodes:
|
||||||
if self.map_node_type(node["data"]["type"]) == NodeType.LLM:
|
if self.map_node_type(node["data"]["type"]) == NodeType.LLM:
|
||||||
self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output"
|
self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output"
|
||||||
|
elif self.map_node_type(node["data"]["type"]) == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||||
|
self.node_output_map[f"{node['id']}.result"] = f"{node['id']}.output"
|
||||||
|
|
||||||
def _convert_cycle_node_position(self, node_id: str, position: dict):
|
def _convert_cycle_node_position(self, node_id: str, position: dict):
|
||||||
for node in self.origin_nodes:
|
for node in self.origin_nodes:
|
||||||
@@ -154,13 +163,14 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
try:
|
try:
|
||||||
|
node_type = self.map_node_type(node_data["type"])
|
||||||
return NodeDefinition(
|
return NodeDefinition(
|
||||||
id=node["id"],
|
id=node["id"],
|
||||||
type=self.map_node_type(node_data["type"]),
|
type=node_type,
|
||||||
name=node_data.get("title"),
|
name=node_data.get("title") or "notes",
|
||||||
cycle=node.get("parentId"),
|
cycle=node.get("parentId"),
|
||||||
description=None,
|
description=None,
|
||||||
config=self._convert_node_config(node),
|
config=self._convert_node_config(node_type, node),
|
||||||
position={
|
position={
|
||||||
"x": node["position"]["x"],
|
"x": node["position"]["x"],
|
||||||
"y": node["position"]["y"]
|
"y": node["position"]["y"]
|
||||||
@@ -174,17 +184,16 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"convert node error - {e}", exc_info=True)
|
logger.debug(f"convert node error - {e}", exc_info=True)
|
||||||
|
|
||||||
def _convert_node_config(self, node: dict):
|
def _convert_node_config(self, node_type: NodeType, node: dict):
|
||||||
node_data = node["data"]
|
|
||||||
node_type = node_data["type"]
|
|
||||||
try:
|
try:
|
||||||
|
node_data = node["data"]
|
||||||
converter = self.get_node_convert(node_type)
|
converter = self.get_node_convert(node_type)
|
||||||
if node_type not in self.CONFIG_CONVERT_MAP:
|
if node_type == NodeType.UNKNOWN:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefineition(
|
||||||
type=ExceptionType.NODE,
|
type=ExceptionType.NODE,
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node["data"]["title"],
|
node_name=node["data"]["title"],
|
||||||
detail=f"node type {node_type} is unsupported",
|
detail=f"node type {node_data.get('type')} is unsupported",
|
||||||
))
|
))
|
||||||
return converter(node)
|
return converter(node)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -201,16 +210,15 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
|
|
||||||
source = edge["source"]
|
source = edge["source"]
|
||||||
target = edge["target"]
|
target = edge["target"]
|
||||||
edge_id = edge["id"]
|
|
||||||
label = None
|
label = None
|
||||||
if source in self.branch_node_cache:
|
if source in self.branch_node_cache:
|
||||||
case_id = "-".join(edge_id.split("-")[1:-2])
|
case_id = edge["sourceHandle"]
|
||||||
if case_id == "false":
|
if case_id == "false":
|
||||||
label = f'CASE{len(self.branch_node_cache[source])+1}'
|
label = f'CASE{len(self.branch_node_cache[source]) + 1}'
|
||||||
else:
|
else:
|
||||||
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
|
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
|
||||||
if source in self.error_branch_node_cache:
|
if source in self.error_branch_node_cache:
|
||||||
case_id = "-".join(edge_id.split("-")[1:-2])
|
case_id = edge["sourceHandle"]
|
||||||
if case_id == "source":
|
if case_id == "source":
|
||||||
label = "SUCCESS"
|
label = "SUCCESS"
|
||||||
else:
|
else:
|
||||||
@@ -235,6 +243,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
name=variable["name"],
|
name=variable["name"],
|
||||||
default=variable["value"],
|
default=variable["value"],
|
||||||
type=self.variable_type_map(variable["value_type"]),
|
type=self.variable_type_map(variable["value_type"]),
|
||||||
|
description=variable.get("description")
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefineition(
|
||||||
@@ -248,5 +257,3 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
|
|
||||||
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
|
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
|
||||||
return ExecutionConfig()
|
return ExecutionConfig()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -292,6 +292,8 @@ class GraphBuilder:
|
|||||||
"""
|
"""
|
||||||
for node in self.nodes:
|
for node in self.nodes:
|
||||||
node_type = node.get("type")
|
node_type = node.get("type")
|
||||||
|
if node_type == NodeType.NOTES:
|
||||||
|
continue
|
||||||
node_id = node.get("id")
|
node_id = node.get("id")
|
||||||
cycle_node = node.get("cycle")
|
cycle_node = node.get("cycle")
|
||||||
if cycle_node:
|
if cycle_node:
|
||||||
@@ -320,7 +322,7 @@ class GraphBuilder:
|
|||||||
# Used later to determine which branch to take based on the node's output
|
# Used later to determine which branch to take based on the node's output
|
||||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
# Assumes node output `node.<node_id>.output` matches the edge's label
|
||||||
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
||||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'"
|
||||||
|
|
||||||
if node_instance:
|
if node_instance:
|
||||||
# Wrap node's run method to avoid closure issues
|
# Wrap node's run method to avoid closure issues
|
||||||
|
|||||||
@@ -158,18 +158,36 @@ class WorkflowExecutor:
|
|||||||
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||||
|
|
||||||
# Append messages for user and assistant
|
# Append messages for user and assistant
|
||||||
result["messages"].extend(
|
if input_data.get("files"):
|
||||||
[
|
result["messages"].extend(
|
||||||
{
|
[
|
||||||
"role": "user",
|
{
|
||||||
"content": input_data.get("message", '')
|
"role": "user",
|
||||||
},
|
"content": input_data.get("message", '')
|
||||||
{
|
},
|
||||||
"role": "assistant",
|
{
|
||||||
"content": full_content
|
"role": "user",
|
||||||
}
|
"content": input_data.get("files")
|
||||||
]
|
},
|
||||||
)
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_content
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result["messages"].extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": input_data.get("message", '')
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_content
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
# Calculate elapsed time
|
# Calculate elapsed time
|
||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
elapsed_time = (end_time - start_time).total_seconds()
|
||||||
@@ -308,18 +326,36 @@ class WorkflowExecutor:
|
|||||||
elapsed_time = (end_time - start_time).total_seconds()
|
elapsed_time = (end_time - start_time).total_seconds()
|
||||||
|
|
||||||
# Append messages for user and assistant
|
# Append messages for user and assistant
|
||||||
result["messages"].extend(
|
if input_data.get("files"):
|
||||||
[
|
result["messages"].extend(
|
||||||
{
|
[
|
||||||
"role": "user",
|
{
|
||||||
"content": input_data.get("message", '')
|
"role": "user",
|
||||||
},
|
"content": input_data.get("message", '')
|
||||||
{
|
},
|
||||||
"role": "assistant",
|
{
|
||||||
"content": full_content
|
"role": "user",
|
||||||
}
|
"content": input_data.get("files")
|
||||||
]
|
},
|
||||||
)
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_content
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result["messages"].extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": input_data.get("message", '')
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_content
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Workflow execution completed (streaming), "
|
f"Workflow execution completed (streaming), "
|
||||||
f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}"
|
f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}"
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
|
|||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
from app.db import get_db
|
from app.db import get_db_context
|
||||||
from app.models import AppRelease
|
from app.models import AppRelease
|
||||||
from app.services.draft_run_service import AgentRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
|
|||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {"output": VariableType.STRING}
|
return {"output": VariableType.STRING}
|
||||||
|
|
||||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]:
|
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]:
|
||||||
"""准备 Agent(公共逻辑)
|
"""准备 Agent(公共逻辑)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -57,17 +57,17 @@ class AgentNode(BaseNode):
|
|||||||
if not agent_id:
|
if not agent_id:
|
||||||
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
|
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
|
||||||
|
|
||||||
db = next(get_db())
|
with get_db_context() as db:
|
||||||
release = db.query(AppRelease).filter(
|
release = db.query(AppRelease).filter(
|
||||||
AppRelease.id == agent_id
|
AppRelease.id == agent_id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not release:
|
if not release:
|
||||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||||
|
|
||||||
draft_service = AgentRunService(db)
|
|
||||||
|
|
||||||
return draft_service, release, message
|
return release, message
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""非流式执行
|
"""非流式执行
|
||||||
@@ -79,19 +79,21 @@ class AgentNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
状态更新字典
|
状态更新字典
|
||||||
"""
|
"""
|
||||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
release, message = self._prepare_agent(variable_pool)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
||||||
|
with get_db_context() as db:
|
||||||
# 执行 Agent(非流式)
|
draft_service = AgentRunService(db)
|
||||||
result = await draft_service.run(
|
|
||||||
agent_config=release.config,
|
# 执行 Agent(非流式)
|
||||||
model_config=None,
|
result = await draft_service.run(
|
||||||
message=message,
|
agent_config=release.config,
|
||||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
model_config=None,
|
||||||
user_id=state.get("user_id"),
|
message=message,
|
||||||
variables=variable_pool.get_all_conversation_vars()
|
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||||
)
|
user_id=state.get("user_id"),
|
||||||
|
variables=variable_pool.get_all_conversation_vars()
|
||||||
|
)
|
||||||
|
|
||||||
response = result.get("response", "")
|
response = result.get("response", "")
|
||||||
|
|
||||||
@@ -118,34 +120,35 @@ class AgentNode(BaseNode):
|
|||||||
Yields:
|
Yields:
|
||||||
流式事件字典
|
流式事件字典
|
||||||
"""
|
"""
|
||||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
release, message = self._prepare_agent(variable_pool)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
||||||
|
|
||||||
# 累积完整响应
|
# 累积完整响应
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
with get_db_context() as db:
|
||||||
|
draft_service = AgentRunService(db)
|
||||||
# 执行 Agent(流式)
|
# 执行 Agent(流式)
|
||||||
async for chunk in draft_service.run_stream(
|
async for chunk in draft_service.run_stream(
|
||||||
agent_config=release.config,
|
agent_config=release.config,
|
||||||
model_config=None,
|
model_config=None,
|
||||||
message=message,
|
message=message,
|
||||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||||
user_id=state.get("user_id"),
|
user_id=state.get("user_id"),
|
||||||
variables=variable_pool.get_all_conversation_vars()
|
variables=variable_pool.get_all_conversation_vars()
|
||||||
):
|
):
|
||||||
# 提取内容
|
# 提取内容
|
||||||
content = chunk.get("content", "")
|
content = chunk.get("content", "")
|
||||||
full_response += content
|
full_response += content
|
||||||
|
|
||||||
# 流式返回每个 chunk
|
# 流式返回每个 chunk
|
||||||
yield {
|
yield {
|
||||||
"type": "chunk",
|
"type": "chunk",
|
||||||
"node_id": self.node_id,
|
"node_id": self.node_id,
|
||||||
"content": content,
|
"content": content,
|
||||||
"full_content": full_response,
|
"full_content": full_response,
|
||||||
"meta_data": chunk.get("meta_data", {})
|
"meta_data": chunk.get("meta_data", {})
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
|
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
|
||||||
|
|
||||||
|
|||||||
@@ -85,20 +85,20 @@ class BaseNodeConfig(BaseModel):
|
|||||||
- tags: 节点标签(用于分类和搜索)
|
- tags: 节点标签(用于分类和搜索)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str | None = Field(
|
# name: str | None = Field(
|
||||||
default=None,
|
# default=None,
|
||||||
description="节点名称(显示名称),如果不设置则使用节点 ID"
|
# description="节点名称(显示名称),如果不设置则使用节点 ID"
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
description: str | None = Field(
|
# description: str | None = Field(
|
||||||
default=None,
|
# default=None,
|
||||||
description="节点描述,说明节点的作用"
|
# description="节点描述,说明节点的作用"
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
tags: list[str] = Field(
|
# tags: list[str] = Field(
|
||||||
default_factory=list,
|
# default_factory=list,
|
||||||
description="节点标签,用于分类和搜索"
|
# description="节点标签,用于分类和搜索"
|
||||||
)
|
# )
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Pydantic 配置"""
|
"""Pydantic 配置"""
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, AsyncGenerator
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
@@ -13,6 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||||
from app.db import get_db_read
|
from app.db import get_db_read
|
||||||
|
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
|
||||||
from app.schemas import FileInput
|
from app.schemas import FileInput
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
|
||||||
@@ -617,17 +618,31 @@ class BaseNode(ABC):
|
|||||||
return variable_pool.has(selector)
|
return variable_pool.has(selector)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None:
|
async def process_message(
|
||||||
|
provider: str,
|
||||||
|
is_omni: bool,
|
||||||
|
content: str | dict | FileObject,
|
||||||
|
enable_file=False
|
||||||
|
) -> list | str | None:
|
||||||
|
if isinstance(content, dict):
|
||||||
|
content = FileObject(
|
||||||
|
type=content.get("type"),
|
||||||
|
url=content.get("url"),
|
||||||
|
transfer_method=content.get("transfer_method"),
|
||||||
|
origin_file_type=content.get("origin_file_type"),
|
||||||
|
file_id=content.get("file_id"),
|
||||||
|
is_file=True
|
||||||
|
)
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
if enable_file:
|
if enable_file:
|
||||||
return {"text": content}
|
return [{"type": "text", "text": content}]
|
||||||
return content
|
return content
|
||||||
|
|
||||||
elif isinstance(content, FileObject):
|
elif isinstance(content, FileObject):
|
||||||
if content.content_cache.get(provider):
|
if content.content_cache.get(provider):
|
||||||
return content.content_cache[provider]
|
return content.content_cache[provider]
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
multimodel_service = MultimodalService(db, provider)
|
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
||||||
message = await multimodel_service.process_files(
|
message = await multimodel_service.process_files(
|
||||||
[FileInput.model_construct(
|
[FileInput.model_construct(
|
||||||
type=content.type,
|
type=content.type,
|
||||||
@@ -637,10 +652,9 @@ class BaseNode(ABC):
|
|||||||
upload_file_id=content.file_id
|
upload_file_id=content.file_id
|
||||||
)]
|
)]
|
||||||
)
|
)
|
||||||
|
|
||||||
if message:
|
if message:
|
||||||
content.content_cache[provider] = message[0]
|
content.content_cache[provider] = message
|
||||||
return message[0]
|
return message
|
||||||
return None
|
return None
|
||||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||||
|
|
||||||
@@ -658,3 +672,12 @@ class BaseNode(ABC):
|
|||||||
elif isinstance(content, str):
|
elif isinstance(content, str):
|
||||||
return content
|
return content
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def model_balance(model_config: ModelConfig) -> ModelApiKey:
|
||||||
|
api_keys = [key for key in model_config.api_keys if key.is_active]
|
||||||
|
if not api_keys:
|
||||||
|
raise ValueError("No active API keys available for model")
|
||||||
|
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||||
|
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||||
|
return api_keys[0]
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie
|
|||||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||||
|
from app.core.workflow.nodes.notes.config import NoteNodeConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 基础类
|
# 基础类
|
||||||
@@ -47,5 +48,6 @@ __all__ = [
|
|||||||
"ToolNodeConfig",
|
"ToolNodeConfig",
|
||||||
"MemoryReadNodeConfig",
|
"MemoryReadNodeConfig",
|
||||||
"MemoryWriteNodeConfig",
|
"MemoryWriteNodeConfig",
|
||||||
"CodeNodeConfig"
|
"CodeNodeConfig",
|
||||||
|
"NoteNodeConfig"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ class NodeType(StrEnum):
|
|||||||
MEMORY_WRITE = "memory-write"
|
MEMORY_WRITE = "memory-write"
|
||||||
|
|
||||||
UNKNOWN = "unknown"
|
UNKNOWN = "unknown"
|
||||||
|
NOTES = "notes"
|
||||||
|
|
||||||
|
|
||||||
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
|
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
|
||||||
|
|||||||
@@ -180,6 +180,8 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
RuntimeError: If no valid knowledge base is found or access is denied.
|
RuntimeError: If no valid knowledge base is found or access is denied.
|
||||||
"""
|
"""
|
||||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||||
|
if not self.typed_config.knowledge_bases:
|
||||||
|
return []
|
||||||
query = self._render_template(self.typed_config.query, variable_pool)
|
query = self._render_template(self.typed_config.query, variable_pool)
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
knowledge_bases = self.typed_config.knowledge_bases
|
knowledge_bases = self.typed_config.knowledge_bases
|
||||||
|
|||||||
@@ -112,11 +112,12 @@ class LLMNode(BaseNode):
|
|||||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
# 在 Session 关闭前提取所有需要的数据
|
# 在 Session 关闭前提取所有需要的数据
|
||||||
api_config = config.api_keys[0]
|
api_config = self.model_balance(config)
|
||||||
model_name = api_config.model_name
|
model_name = api_config.model_name
|
||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
api_key = api_config.api_key
|
api_key = api_config.api_key
|
||||||
api_base = api_config.api_base
|
api_base = api_config.api_base
|
||||||
|
is_omni = api_config.is_omni
|
||||||
model_type = config.type
|
model_type = config.type
|
||||||
|
|
||||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||||
@@ -129,7 +130,8 @@ class LLMNode(BaseNode):
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
extra_params=extra_params
|
extra_params=extra_params,
|
||||||
|
is_omni=is_omni
|
||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
@@ -151,39 +153,53 @@ class LLMNode(BaseNode):
|
|||||||
if role == "system":
|
if role == "system":
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": content
|
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
elif role in ["user", "human"]:
|
elif role in ["user", "human"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": content
|
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
elif role in ["ai", "assistant"]:
|
elif role in ["ai", "assistant"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": content
|
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": content
|
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
|
|
||||||
if self.typed_config.vision_input and self.typed_config.vision:
|
if self.typed_config.vision_input and self.typed_config.vision:
|
||||||
file_content = []
|
file_content = []
|
||||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||||
for file in files.value:
|
for file in files.value:
|
||||||
content = await self.process_message(provider, file.value, self.typed_config.vision)
|
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.append(content)
|
file_content.extend(content)
|
||||||
if messages and messages[-1]["role"] == 'user':
|
if messages and messages[-1]["role"] == 'user':
|
||||||
messages[-1]['content'] = [messages[-1]["content"]] + file_content
|
messages[-1]['content'] = messages[-1]["content"] + file_content
|
||||||
else:
|
else:
|
||||||
messages.append({"role": "user", "content": file_content})
|
messages.append({"role": "user", "content": file_content})
|
||||||
|
|
||||||
if self.typed_config.memory.enable:
|
if self.typed_config.memory.enable:
|
||||||
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
history_message = []
|
||||||
|
for message in state["messages"][-self.typed_config.memory.window_size:]:
|
||||||
|
if isinstance(message["content"], list):
|
||||||
|
file_content = []
|
||||||
|
for file in message["content"]:
|
||||||
|
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
|
||||||
|
if content:
|
||||||
|
file_content.extend(content)
|
||||||
|
history_message.append(
|
||||||
|
{"role": message["role"], "content": file_content}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
|
||||||
|
history_message.append(message)
|
||||||
|
messages = messages[:-1] + history_message + messages[-1:]
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
else:
|
else:
|
||||||
# 使用简单的 prompt 格式(向后兼容)
|
# 使用简单的 prompt 格式(向后兼容)
|
||||||
|
|||||||
0
api/app/core/workflow/nodes/notes/__init__.py
Normal file
0
api/app/core/workflow/nodes/notes/__init__.py
Normal file
12
api/app/core/workflow/nodes/notes/config.py
Normal file
12
api/app/core/workflow/nodes/notes/config.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
|
||||||
|
|
||||||
|
class NoteNodeConfig(BaseNodeConfig):
|
||||||
|
author: str = Field(default="", description="author")
|
||||||
|
text: str = Field(default="", description="note content")
|
||||||
|
width: int = Field(default=80)
|
||||||
|
height: int = Field(default=80)
|
||||||
|
theme: str = Field(default="blue")
|
||||||
|
show_author: bool = Field(default=True)
|
||||||
@@ -95,11 +95,12 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
if not config.api_keys or len(config.api_keys) == 0:
|
if not config.api_keys or len(config.api_keys) == 0:
|
||||||
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
|
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
api_config = config.api_keys[0]
|
api_config = self.model_balance(config)
|
||||||
model_name = api_config.model_name
|
model_name = api_config.model_name
|
||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
api_key = api_config.api_key
|
api_key = api_config.api_key
|
||||||
api_base = api_config.api_base
|
api_base = api_config.api_base
|
||||||
|
is_omni = api_config.is_omni
|
||||||
model_type = config.type
|
model_type = config.type
|
||||||
|
|
||||||
llm = RedBearLLM(
|
llm = RedBearLLM(
|
||||||
@@ -108,6 +109,7 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
|
is_omni=is_omni
|
||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,11 +56,12 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
if not config.api_keys or len(config.api_keys) == 0:
|
if not config.api_keys or len(config.api_keys) == 0:
|
||||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
api_config = config.api_keys[0]
|
api_config = self.model_balance(config)
|
||||||
model_name = api_config.model_name
|
model_name = api_config.model_name
|
||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
api_key = api_config.api_key
|
api_key = api_config.api_key
|
||||||
base_url = api_config.api_base
|
base_url = api_config.api_base
|
||||||
|
is_omni = api_config.is_omni
|
||||||
model_type = config.type
|
model_type = config.type
|
||||||
|
|
||||||
return RedBearLLM(
|
return RedBearLLM(
|
||||||
@@ -69,6 +70,7 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
|
is_omni=is_omni
|
||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ class WorkflowValidator:
|
|||||||
errors.append("工作流必须至少有一个 end 节点")
|
errors.append("工作流必须至少有一个 end 节点")
|
||||||
|
|
||||||
# 3. 验证节点 ID 唯一性
|
# 3. 验证节点 ID 唯一性
|
||||||
node_ids = [n.get("id") for n in nodes]
|
node_ids = [n.get("id") for n in nodes if n.get("type") != NodeType.NOTES]
|
||||||
if len(node_ids) != len(set(node_ids)):
|
if len(node_ids) != len(set(node_ids)):
|
||||||
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
|
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
|
||||||
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
|
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import uuid
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean
|
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean, text
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
@@ -163,6 +163,17 @@ class CustomToolConfig(Base):
|
|||||||
return f"<CustomToolConfig(id={self.id}, auth_type={self.auth_type})>"
|
return f"<CustomToolConfig(id={self.id}, auth_type={self.auth_type})>"
|
||||||
|
|
||||||
|
|
||||||
|
class MCPSourceChannel(StrEnum):
|
||||||
|
"""MCP来源渠道枚举"""
|
||||||
|
ALIYUN_BAILIAN = "aliyun_bailian" # 阿里云百炼
|
||||||
|
MODELSCOPE = "modelscope" # ModelScope
|
||||||
|
TOKENFLUX = "tokenflux" # TokenFlux
|
||||||
|
LANGENG = "langeng" # 蓝耕科技
|
||||||
|
AI_302 = "302ai" # 302.AI
|
||||||
|
MCP_ROUTER = "mcp_router" # MCP Router
|
||||||
|
SELF_HOSTED = "self_hosted" # 自建
|
||||||
|
|
||||||
|
|
||||||
class MCPToolConfig(Base):
|
class MCPToolConfig(Base):
|
||||||
"""MCP工具配置模型"""
|
"""MCP工具配置模型"""
|
||||||
__tablename__ = "mcp_tool_configs"
|
__tablename__ = "mcp_tool_configs"
|
||||||
@@ -170,6 +181,13 @@ class MCPToolConfig(Base):
|
|||||||
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
|
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
|
||||||
server_url = Column(String(1000), nullable=False) # MCP服务器URL
|
server_url = Column(String(1000), nullable=False) # MCP服务器URL
|
||||||
connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息)
|
connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息)
|
||||||
|
|
||||||
|
# 来源渠道
|
||||||
|
source_channel = Column(String(50), default=MCPSourceChannel.SELF_HOSTED,
|
||||||
|
server_default=text(f"'{MCPSourceChannel.SELF_HOSTED}'"), nullable=False, comment="来源渠道")
|
||||||
|
market_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场id")
|
||||||
|
market_config_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场配置id")
|
||||||
|
mcp_service_id = Column(String(255), nullable=True, comment="mcp服务id")
|
||||||
|
|
||||||
# 服务状态
|
# 服务状态
|
||||||
last_health_check = Column(DateTime)
|
last_health_check = Column(DateTime)
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from sqlalchemy.orm import Session
|
|
||||||
from typing import List, Optional
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from app.models.app_model import App
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.logging_config import get_db_logger
|
from app.core.logging_config import get_db_logger
|
||||||
|
from app.models.app_model import App
|
||||||
|
|
||||||
# 获取数据库专用日志器
|
# 获取数据库专用日志器
|
||||||
db_logger = get_db_logger()
|
db_logger = get_db_logger()
|
||||||
@@ -35,11 +36,27 @@ class AppRepository:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def get_apps_by_name(self, app_name: str, app_type: str, workspace_id: uuid.UUID) -> List[App]:
|
||||||
|
try:
|
||||||
|
stmt = select(App).where(
|
||||||
|
App.name == app_name,
|
||||||
|
App.workspace_id == workspace_id,
|
||||||
|
App.type == app_type,
|
||||||
|
App.is_active.is_(True),
|
||||||
|
)
|
||||||
|
apps = self.db.execute(stmt).scalars().all()
|
||||||
|
return list(apps)
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"查询名称 {app_name} 应用异常: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
|
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
|
||||||
"""根据工作空间ID查询应用"""
|
"""根据工作空间ID查询应用"""
|
||||||
repo = AppRepository(db)
|
repo = AppRepository(db)
|
||||||
return repo.get_apps_by_workspace_id(workspace_id)
|
return repo.get_apps_by_workspace_id(workspace_id)
|
||||||
|
|
||||||
|
|
||||||
def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App:
|
def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App:
|
||||||
"""根据工作空间ID查询应用"""
|
"""根据工作空间ID查询应用"""
|
||||||
repo = AppRepository(db)
|
repo = AppRepository(db)
|
||||||
|
|||||||
@@ -5,13 +5,22 @@ Implicit Emotions Storage Repository
|
|||||||
事务由调用方控制,仓储层只使用 flush/refresh
|
事务由调用方控制,仓储层只使用 flush/refresh
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, date, timezone, timedelta
|
from datetime import date, datetime, timedelta, timezone
|
||||||
from typing import Optional, Generator
|
from typing import Generator, Optional
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from sqlalchemy import select, not_, exists
|
|
||||||
|
class TimeFilterUnavailableError(Exception):
|
||||||
|
"""redis_client 不可用,无法执行时间轴筛选。
|
||||||
|
|
||||||
|
调用方捕获此异常后可选择回退到 get_all_user_ids 进行全量处理。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import redis
|
||||||
|
from sqlalchemy import exists, not_, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
|
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -111,6 +120,88 @@ class ImplicitEmotionsStorageRepository:
|
|||||||
logger.error(f"分批获取用户ID失败: offset={offset}, error={e}")
|
logger.error(f"分批获取用户ID失败: offset={offset}, error={e}")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def get_users_needing_refresh(self, redis_client: redis.StrictRedis, batch_size: int = 100) -> Generator[str, None, None]:
|
||||||
|
"""分批次获取需要刷新隐性记忆/情绪数据的存量用户ID。
|
||||||
|
|
||||||
|
筛选逻辑:
|
||||||
|
- 查询 implicit_emotions_storage 中所有用户的 end_user_id 和 updated_at
|
||||||
|
- 从 Redis 读取 write_message:last_done:{end_user_id} 的时间戳
|
||||||
|
- 若 Redis 中无记录(该用户从未写入过记忆),跳过
|
||||||
|
- 若 last_done > updated_at,说明上次刷新后又有新记忆写入,需要刷新
|
||||||
|
- 若 last_done <= updated_at,说明已是最新,跳过
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis_client: 同步 redis.StrictRedis 实例(连接 CELERY_BACKEND DB)
|
||||||
|
batch_size: 每批次加载的数量
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TimeFilterUnavailableError: redis_client 为 None 时抛出,调用方可捕获并回退到 get_all_user_ids
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
需要刷新的用户ID字符串
|
||||||
|
"""
|
||||||
|
if redis_client is None:
|
||||||
|
raise TimeFilterUnavailableError("redis_client 不可用,无法执行时间轴筛选")
|
||||||
|
|
||||||
|
from redis.exceptions import RedisError
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
stmt = (
|
||||||
|
select(ImplicitEmotionsStorage.end_user_id, ImplicitEmotionsStorage.updated_at)
|
||||||
|
.order_by(ImplicitEmotionsStorage.end_user_id)
|
||||||
|
.limit(batch_size)
|
||||||
|
.offset(offset)
|
||||||
|
)
|
||||||
|
batch = self.db.execute(stmt).all()
|
||||||
|
if not batch:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 批量获取当前批次所有用户的 last_done 时间戳(一次网络往返)
|
||||||
|
keys = [f"write_message:last_done:{end_user_id}" for end_user_id, _ in batch]
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw_values = redis_client.mget(keys)
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(
|
||||||
|
f"Redis mget 操作失败: {e},当前批次降级为处理所有用户",
|
||||||
|
extra={"offset": offset, "batch_size": len(batch)}
|
||||||
|
)
|
||||||
|
# Redis 操作失败,降级为返回当前批次所有用户
|
||||||
|
yield from (end_user_id for end_user_id, _ in batch)
|
||||||
|
offset += batch_size
|
||||||
|
continue
|
||||||
|
|
||||||
|
for (end_user_id, updated_at), raw in zip(batch, raw_values):
|
||||||
|
if raw is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
CST = timezone(timedelta(hours=8))
|
||||||
|
last_done = datetime.fromisoformat(raw)
|
||||||
|
# last_done 写入时已是 CST naive,直接使用,无需转换
|
||||||
|
if last_done.tzinfo is not None:
|
||||||
|
last_done = last_done.astimezone(CST).replace(tzinfo=None)
|
||||||
|
|
||||||
|
if updated_at is None:
|
||||||
|
yield end_user_id
|
||||||
|
continue
|
||||||
|
# updated_at 数据库存的是 UTC naive,转为 CST naive 再比较
|
||||||
|
if updated_at.tzinfo is None:
|
||||||
|
updated_at_cst = updated_at.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None)
|
||||||
|
else:
|
||||||
|
updated_at_cst = updated_at.astimezone(CST).replace(tzinfo=None)
|
||||||
|
|
||||||
|
if last_done > updated_at_cst:
|
||||||
|
yield end_user_id
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"解析 last_done 时间戳失败: end_user_id={end_user_id}, raw={raw}, error={e}")
|
||||||
|
|
||||||
|
offset += batch_size
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"get_users_needing_refresh 分批查询失败: offset={offset}, error={e}")
|
||||||
|
break
|
||||||
|
|
||||||
def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]:
|
def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]:
|
||||||
"""分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID
|
"""分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID
|
||||||
|
|
||||||
@@ -124,7 +215,8 @@ class ImplicitEmotionsStorageRepository:
|
|||||||
Yields:
|
Yields:
|
||||||
用户ID字符串
|
用户ID字符串
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import cast, String as SAString
|
from sqlalchemy import String as SAString
|
||||||
|
from sqlalchemy import cast
|
||||||
CST = timezone(timedelta(hours=8))
|
CST = timezone(timedelta(hours=8))
|
||||||
now_cst = datetime.now(CST)
|
now_cst = datetime.now(CST)
|
||||||
today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None)
|
today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None)
|
||||||
|
|||||||
@@ -233,6 +233,7 @@ class MemoryConfigRepository:
|
|||||||
config_desc=params.config_desc,
|
config_desc=params.config_desc,
|
||||||
workspace_id=params.workspace_id,
|
workspace_id=params.workspace_id,
|
||||||
scene_id=params.scene_id,
|
scene_id=params.scene_id,
|
||||||
|
pruning_scene=params.pruning_scene,
|
||||||
llm_id=params.llm_id,
|
llm_id=params.llm_id,
|
||||||
embedding_id=params.embedding_id,
|
embedding_id=params.embedding_id,
|
||||||
rerank_id=params.rerank_id,
|
rerank_id=params.rerank_id,
|
||||||
|
|||||||
@@ -374,7 +374,7 @@ class OntologySceneRepository:
|
|||||||
|
|
||||||
count = self.db.query(OntologyScene).filter(
|
count = self.db.query(OntologyScene).filter(
|
||||||
OntologyScene.scene_id == scene_id,
|
OntologyScene.scene_id == scene_id,
|
||||||
OntologyScene.workspace_id == workspace_id
|
(OntologyScene.workspace_id == workspace_id) | (OntologyScene.is_system_default == True)
|
||||||
).count()
|
).count()
|
||||||
|
|
||||||
is_owner = count > 0
|
is_owner = count > 0
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
from sqlalchemy.orm import Session, joinedload
|
|
||||||
from app.models.user_model import User
|
|
||||||
from typing import List, Optional
|
|
||||||
import uuid
|
import uuid
|
||||||
from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole
|
from typing import List, Optional
|
||||||
from app.schemas.workspace_schema import WorkspaceCreate, WorkspaceUpdate
|
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.core.logging_config import get_db_logger
|
from app.core.logging_config import get_db_logger
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole
|
||||||
|
from app.schemas.workspace_schema import WorkspaceCreate
|
||||||
|
|
||||||
# 获取数据库专用日志器
|
# 获取数据库专用日志器
|
||||||
db_logger = get_db_logger()
|
db_logger = get_db_logger()
|
||||||
@@ -19,7 +22,7 @@ class WorkspaceRepository:
|
|||||||
def create_workspace(self, workspace_data: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace:
|
def create_workspace(self, workspace_data: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace:
|
||||||
"""创建工作空间"""
|
"""创建工作空间"""
|
||||||
db_logger.debug(f"创建工作空间记录: name={workspace_data.name}, tenant_id={tenant_id}")
|
db_logger.debug(f"创建工作空间记录: name={workspace_data.name}, tenant_id={tenant_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db_workspace = Workspace(
|
db_workspace = Workspace(
|
||||||
name=workspace_data.name,
|
name=workspace_data.name,
|
||||||
@@ -34,7 +37,8 @@ class WorkspaceRepository:
|
|||||||
)
|
)
|
||||||
self.db.add(db_workspace)
|
self.db.add(db_workspace)
|
||||||
self.db.flush()
|
self.db.flush()
|
||||||
db_logger.info(f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}")
|
db_logger.info(
|
||||||
|
f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}")
|
||||||
return db_workspace
|
return db_workspace
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"创建工作空间记录失败: name={workspace_data.name} - {str(e)}")
|
db_logger.error(f"创建工作空间记录失败: name={workspace_data.name} - {str(e)}")
|
||||||
@@ -43,7 +47,7 @@ class WorkspaceRepository:
|
|||||||
def get_workspace_by_id(self, workspace_id: uuid.UUID) -> Optional[Workspace]:
|
def get_workspace_by_id(self, workspace_id: uuid.UUID) -> Optional[Workspace]:
|
||||||
"""根据ID获取工作空间"""
|
"""根据ID获取工作空间"""
|
||||||
db_logger.debug(f"根据ID查询工作空间: workspace_id={workspace_id}")
|
db_logger.debug(f"根据ID查询工作空间: workspace_id={workspace_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||||
if workspace:
|
if workspace:
|
||||||
@@ -65,7 +69,7 @@ class WorkspaceRepository:
|
|||||||
包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None
|
包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None
|
||||||
"""
|
"""
|
||||||
db_logger.debug(f"查询工作空间模型配置: workspace_id={workspace_id}")
|
db_logger.debug(f"查询工作空间模型配置: workspace_id={workspace_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||||
if workspace:
|
if workspace:
|
||||||
@@ -89,7 +93,7 @@ class WorkspaceRepository:
|
|||||||
def get_workspaces_by_user(self, user_id: uuid.UUID) -> List[Workspace]:
|
def get_workspaces_by_user(self, user_id: uuid.UUID) -> List[Workspace]:
|
||||||
"""获取用户参与的所有工作空间(包括用户创建的和作为成员的)"""
|
"""获取用户参与的所有工作空间(包括用户创建的和作为成员的)"""
|
||||||
db_logger.debug(f"查询用户参与的工作空间: user_id={user_id}")
|
db_logger.debug(f"查询用户参与的工作空间: user_id={user_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 首先获取用户信息以获取 tenant_id
|
# 首先获取用户信息以获取 tenant_id
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
@@ -97,7 +101,7 @@ class WorkspaceRepository:
|
|||||||
if not user:
|
if not user:
|
||||||
db_logger.warning(f"用户不存在: user_id={user_id}")
|
db_logger.warning(f"用户不存在: user_id={user_id}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if user.is_superuser:
|
if user.is_superuser:
|
||||||
# 超级用户获取对应tenantid所有工作空间
|
# 超级用户获取对应tenantid所有工作空间
|
||||||
workspaces = (
|
workspaces = (
|
||||||
@@ -109,7 +113,7 @@ class WorkspaceRepository:
|
|||||||
)
|
)
|
||||||
db_logger.debug(f"超用户查询所有工作空间: user_id={user_id}, 数量={len(workspaces)}")
|
db_logger.debug(f"超用户查询所有工作空间: user_id={user_id}, 数量={len(workspaces)}")
|
||||||
return workspaces
|
return workspaces
|
||||||
|
|
||||||
# 获取用户作为成员的工作空间
|
# 获取用户作为成员的工作空间
|
||||||
member_workspaces = (
|
member_workspaces = (
|
||||||
self.db.query(Workspace)
|
self.db.query(Workspace)
|
||||||
@@ -120,7 +124,7 @@ class WorkspaceRepository:
|
|||||||
.order_by(Workspace.updated_at.desc())
|
.order_by(Workspace.updated_at.desc())
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
db_logger.debug(f"用户工作空间查询成功: user_id={user_id}, 数量={len(member_workspaces)}")
|
db_logger.debug(f"用户工作空间查询成功: user_id={user_id}, 数量={len(member_workspaces)}")
|
||||||
return member_workspaces
|
return member_workspaces
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -130,7 +134,7 @@ class WorkspaceRepository:
|
|||||||
def get_workspaces_by_tenant(self, tenant_id: uuid.UUID) -> List[Workspace]:
|
def get_workspaces_by_tenant(self, tenant_id: uuid.UUID) -> List[Workspace]:
|
||||||
"""获取租户的所有工作空间"""
|
"""获取租户的所有工作空间"""
|
||||||
db_logger.debug(f"查询租户的工作空间: tenant_id={tenant_id}")
|
db_logger.debug(f"查询租户的工作空间: tenant_id={tenant_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
workspaces = (
|
workspaces = (
|
||||||
self.db.query(Workspace)
|
self.db.query(Workspace)
|
||||||
@@ -144,14 +148,32 @@ class WorkspaceRepository:
|
|||||||
db_logger.error(f"查询租户工作空间失败: tenant_id={tenant_id} - {str(e)}")
|
db_logger.error(f"查询租户工作空间失败: tenant_id={tenant_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember:
|
def get_workspaces_by_name(self, tenant_id: uuid.UUID, workspace_name: str) -> List[Workspace]:
|
||||||
|
try:
|
||||||
|
stmt = (
|
||||||
|
select(Workspace)
|
||||||
|
.where(
|
||||||
|
Workspace.tenant_id == tenant_id,
|
||||||
|
Workspace.name == workspace_name,
|
||||||
|
Workspace.is_active.is_(True)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
workspaces = self.db.execute(stmt).scalars().all()
|
||||||
|
return list(workspaces)
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"查询工作空间失败: workspace_name={workspace_name} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID,
|
||||||
|
role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember:
|
||||||
"""添加工作空间成员"""
|
"""添加工作空间成员"""
|
||||||
db_logger.debug(f"添加工作空间成员: user_id={user_id}, workspace_id={workspace_id}, role={role}")
|
db_logger.debug(f"添加工作空间成员: user_id={user_id}, workspace_id={workspace_id}, role={role}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db_member = WorkspaceMember(
|
db_member = WorkspaceMember(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
role=role
|
role=role
|
||||||
)
|
)
|
||||||
self.db.add(db_member)
|
self.db.add(db_member)
|
||||||
@@ -165,7 +187,7 @@ class WorkspaceRepository:
|
|||||||
def get_member(self, user_id: uuid.UUID, workspace_id: uuid.UUID) -> Optional[WorkspaceMember]:
|
def get_member(self, user_id: uuid.UUID, workspace_id: uuid.UUID) -> Optional[WorkspaceMember]:
|
||||||
"""获取工作空间成员"""
|
"""获取工作空间成员"""
|
||||||
db_logger.debug(f"查询工作空间成员: user_id={user_id}, workspace_id={workspace_id}")
|
db_logger.debug(f"查询工作空间成员: user_id={user_id}, workspace_id={workspace_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
member = self.db.query(WorkspaceMember).filter(
|
member = self.db.query(WorkspaceMember).filter(
|
||||||
WorkspaceMember.user_id == user_id,
|
WorkspaceMember.user_id == user_id,
|
||||||
@@ -173,7 +195,8 @@ class WorkspaceRepository:
|
|||||||
WorkspaceMember.is_active.is_(True),
|
WorkspaceMember.is_active.is_(True),
|
||||||
).first()
|
).first()
|
||||||
if member:
|
if member:
|
||||||
db_logger.debug(f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}")
|
db_logger.debug(
|
||||||
|
f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}")
|
||||||
else:
|
else:
|
||||||
db_logger.debug(f"工作空间成员不存在: user_id={user_id}, workspace_id={workspace_id}")
|
db_logger.debug(f"工作空间成员不存在: user_id={user_id}, workspace_id={workspace_id}")
|
||||||
return member
|
return member
|
||||||
@@ -199,7 +222,7 @@ class WorkspaceRepository:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"查询成员列表失败: workspace_id={workspace_id} - {str(e)}")
|
db_logger.error(f"查询成员列表失败: workspace_id={workspace_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_member_by_id(self, member_id: uuid.UUID) -> WorkspaceMember:
|
def get_member_by_id(self, member_id: uuid.UUID) -> WorkspaceMember:
|
||||||
"""按成员ID获取工作空间成员,并预加载 user 与 workspace 关系"""
|
"""按成员ID获取工作空间成员,并预加载 user 与 workspace 关系"""
|
||||||
db_logger.debug(f"查询成员的工作空间: member_id={member_id}")
|
db_logger.debug(f"查询成员的工作空间: member_id={member_id}")
|
||||||
@@ -214,7 +237,8 @@ class WorkspaceRepository:
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if member:
|
if member:
|
||||||
db_logger.debug(f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}")
|
db_logger.debug(
|
||||||
|
f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}")
|
||||||
else:
|
else:
|
||||||
db_logger.debug(f"成员不存在: member_id={member_id}")
|
db_logger.debug(f"成员不存在: member_id={member_id}")
|
||||||
return member
|
return member
|
||||||
@@ -222,7 +246,8 @@ class WorkspaceRepository:
|
|||||||
db_logger.error(f"查询成员列表失败: member_id={member_id} - {str(e)}")
|
db_logger.error(f"查询成员列表失败: member_id={member_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
|
def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[
|
||||||
|
WorkspaceMember]:
|
||||||
try:
|
try:
|
||||||
member = self.db.query(WorkspaceMember).filter(
|
member = self.db.query(WorkspaceMember).filter(
|
||||||
WorkspaceMember.workspace_id == workspace_id,
|
WorkspaceMember.workspace_id == workspace_id,
|
||||||
@@ -255,7 +280,7 @@ class WorkspaceRepository:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"删除成员失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}")
|
db_logger.error(f"删除成员失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def delete_member_by_id(self, member_id: uuid.UUID) -> Optional[WorkspaceMember]:
|
def delete_member_by_id(self, member_id: uuid.UUID) -> Optional[WorkspaceMember]:
|
||||||
try:
|
try:
|
||||||
member = self.db.query(WorkspaceMember).filter(
|
member = self.db.query(WorkspaceMember).filter(
|
||||||
@@ -271,7 +296,7 @@ class WorkspaceRepository:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"删除成员失败: id={member_id} - {str(e)}")
|
db_logger.error(f"删除成员失败: id={member_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def update_member_role_by_id(self, id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
|
def update_member_role_by_id(self, id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]:
|
||||||
try:
|
try:
|
||||||
member = self.db.query(WorkspaceMember).filter(
|
member = self.db.query(WorkspaceMember).filter(
|
||||||
@@ -288,12 +313,18 @@ class WorkspaceRepository:
|
|||||||
db_logger.error(f"更新成员角色失败: id={id} - {str(e)}")
|
db_logger.error(f"更新成员角色失败: id={id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
# 保持向后兼容的函数
|
# 保持向后兼容的函数
|
||||||
def get_workspace_by_id(db: Session, workspace_id: uuid.UUID) -> Workspace | None:
|
def get_workspace_by_id(db: Session, workspace_id: uuid.UUID) -> Workspace | None:
|
||||||
repo = WorkspaceRepository(db)
|
repo = WorkspaceRepository(db)
|
||||||
return repo.get_workspace_by_id(workspace_id)
|
return repo.get_workspace_by_id(workspace_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_workspaces_by_name(db: Session, tenant_id: uuid.UUID, name: str) -> List[Workspace]:
|
||||||
|
repo = WorkspaceRepository(db)
|
||||||
|
return repo.get_workspaces_by_name(tenant_id, name)
|
||||||
|
|
||||||
|
|
||||||
def get_workspaces_by_user(db: Session, user_id: uuid.UUID) -> List[Workspace]:
|
def get_workspaces_by_user(db: Session, user_id: uuid.UUID) -> List[Workspace]:
|
||||||
repo = WorkspaceRepository(db)
|
repo = WorkspaceRepository(db)
|
||||||
return repo.get_workspaces_by_user(user_id)
|
return repo.get_workspaces_by_user(user_id)
|
||||||
@@ -315,7 +346,7 @@ def create_workspace(db: Session, workspace: WorkspaceCreate, tenant_id: uuid.UU
|
|||||||
|
|
||||||
|
|
||||||
def add_member_to_workspace(
|
def add_member_to_workspace(
|
||||||
db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole
|
db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole
|
||||||
) -> WorkspaceMember:
|
) -> WorkspaceMember:
|
||||||
repo = WorkspaceRepository(db)
|
repo = WorkspaceRepository(db)
|
||||||
return repo.add_member(workspace_id, user_id, role)
|
return repo.add_member(workspace_id, user_id, role)
|
||||||
@@ -325,39 +356,43 @@ def get_members_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[Works
|
|||||||
repo = WorkspaceRepository(db)
|
repo = WorkspaceRepository(db)
|
||||||
return repo.get_members_by_workspace(workspace_id)
|
return repo.get_members_by_workspace(workspace_id)
|
||||||
|
|
||||||
|
|
||||||
def get_member_by_id(db: Session, member_id: uuid.UUID) -> WorkspaceMember | None:
|
def get_member_by_id(db: Session, member_id: uuid.UUID) -> WorkspaceMember | None:
|
||||||
repo = WorkspaceRepository(db)
|
repo = WorkspaceRepository(db)
|
||||||
return repo.get_member_by_id(member_id)
|
return repo.get_member_by_id(member_id)
|
||||||
|
|
||||||
|
|
||||||
def update_member_role_in_workspace(
|
def update_member_role_in_workspace(
|
||||||
db: Session,
|
db: Session,
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
role: WorkspaceRole,
|
role: WorkspaceRole,
|
||||||
) -> Optional[WorkspaceMember]:
|
) -> Optional[WorkspaceMember]:
|
||||||
repo = WorkspaceRepository(db)
|
repo = WorkspaceRepository(db)
|
||||||
return repo.update_member_role(workspace_id, user_id, role)
|
return repo.update_member_role(workspace_id, user_id, role)
|
||||||
|
|
||||||
|
|
||||||
def remove_member_from_workspace(
|
def remove_member_from_workspace(
|
||||||
db: Session,
|
db: Session,
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
) -> Optional[WorkspaceMember]:
|
) -> Optional[WorkspaceMember]:
|
||||||
repo = WorkspaceRepository(db)
|
repo = WorkspaceRepository(db)
|
||||||
return repo.deactivate_member(workspace_id, user_id)
|
return repo.deactivate_member(workspace_id, user_id)
|
||||||
|
|
||||||
|
|
||||||
def remove_member_from_workspace_by_id(
|
def remove_member_from_workspace_by_id(
|
||||||
db: Session,
|
db: Session,
|
||||||
member_id: uuid.UUID,
|
member_id: uuid.UUID,
|
||||||
) -> Optional[WorkspaceMember]:
|
) -> Optional[WorkspaceMember]:
|
||||||
repo = WorkspaceRepository(db)
|
repo = WorkspaceRepository(db)
|
||||||
return repo.delete_member_by_id(member_id)
|
return repo.delete_member_by_id(member_id)
|
||||||
|
|
||||||
|
|
||||||
def update_member_role_by_id(
|
def update_member_role_by_id(
|
||||||
db: Session,
|
db: Session,
|
||||||
id: uuid.UUID,
|
id: uuid.UUID,
|
||||||
role: WorkspaceRole,
|
role: WorkspaceRole,
|
||||||
) -> Optional[WorkspaceMember]:
|
) -> Optional[WorkspaceMember]:
|
||||||
repo = WorkspaceRepository(db)
|
repo = WorkspaceRepository(db)
|
||||||
return repo.update_member_role_by_id(id, role)
|
return repo.update_member_role_by_id(id, role)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class ApiKeyCreate(BaseModel):
|
|||||||
type: ApiKeyType = Field(..., description="API Key 类型")
|
type: ApiKeyType = Field(..., description="API Key 类型")
|
||||||
scopes: List[str] = Field(default_factory=list, description="权限范围列表")
|
scopes: List[str] = Field(default_factory=list, description="权限范围列表")
|
||||||
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
|
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
|
||||||
rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS限制(请求/秒)")
|
rate_limit: Optional[int] = Field(100, ge=1, le=1000, description="QPS限制(请求/秒)")
|
||||||
daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1)
|
daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1)
|
||||||
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
|
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
|
||||||
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
|
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
|
||||||
|
|||||||
@@ -86,6 +86,7 @@ class ChatResponse(BaseModel):
|
|||||||
"""聊天响应(非流式)"""
|
"""聊天响应(非流式)"""
|
||||||
conversation_id: uuid.UUID
|
conversation_id: uuid.UUID
|
||||||
message: str
|
message: str
|
||||||
|
message_id: str
|
||||||
usage: Optional[Dict[str, Any]] = None
|
usage: Optional[Dict[str, Any]] = None
|
||||||
elapsed_time: Optional[float] = None
|
elapsed_time: Optional[float] = None
|
||||||
|
|
||||||
|
|||||||
@@ -417,6 +417,7 @@ class MemoryConfig:
|
|||||||
|
|
||||||
# Ontology scene association
|
# Ontology scene association
|
||||||
scene_id: Optional[UUID] = None
|
scene_id: Optional[UUID] = None
|
||||||
|
ontology_classes: Optional[list] = field(default=None)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Validate configuration after initialization."""
|
"""Validate configuration after initialization."""
|
||||||
|
|||||||
@@ -232,14 +232,15 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
|
|||||||
# 本体场景关联(可选)
|
# 本体场景关联(可选)
|
||||||
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表")
|
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表")
|
||||||
|
|
||||||
|
# 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name,前端无需传入)
|
||||||
|
pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充")
|
||||||
|
|
||||||
# 模型配置字段(可选,用于手动指定或自动填充)
|
# 模型配置字段(可选,用于手动指定或自动填充)
|
||||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||||
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
||||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
||||||
|
|
||||||
|
|
||||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||||
@@ -274,8 +275,8 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
|
|||||||
|
|
||||||
# 剪枝配置:与 runtime.json 中 pruning 段对应
|
# 剪枝配置:与 runtime.json 中 pruning 段对应
|
||||||
pruning_enabled: Optional[bool] = Field(None, description="是否启动智能语义剪枝")
|
pruning_enabled: Optional[bool] = Field(None, description="是否启动智能语义剪枝")
|
||||||
pruning_scene: Optional[Literal["education", "online_service", "outbound"]] = Field(
|
pruning_scene: Optional[str] = Field(
|
||||||
None, description="智能剪枝场景:education/online_service/outbound"
|
None, description="智能剪枝场景:education/online_service/outbound 或本体工程自定义场景"
|
||||||
)
|
)
|
||||||
pruning_threshold: Optional[float] = Field(
|
pruning_threshold: Optional[float] = Field(
|
||||||
None, ge=0.0, le=0.9, description="智能语义剪枝阈值(0-0.9)"
|
None, ge=0.0, le=0.9, description="智能语义剪枝阈值(0-0.9)"
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略")
|
load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略")
|
||||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
||||||
is_omni: bool = Field(False, description="是否为Omni模型")
|
is_omni: bool = Field(False, description="是否为Omni模型")
|
||||||
|
model_id: Optional[uuid.UUID] = Field(None, description="基础模型ID")
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyCreateNested(BaseModel):
|
class ApiKeyCreateNested(BaseModel):
|
||||||
@@ -116,8 +117,8 @@ class ModelApiKeyBase(BaseModel):
|
|||||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
capability: Optional[List[str]] = Field(None, description="模型能力列表")
|
||||||
is_omni: bool = Field(False, description="是否为Omni模型")
|
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
|
||||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
||||||
is_active: bool = Field(True, description="是否激活")
|
is_active: bool = Field(True, description="是否激活")
|
||||||
priority: str = Field("1", description="优先级", max_length=10)
|
priority: str = Field("1", description="优先级", max_length=10)
|
||||||
|
|||||||
@@ -241,6 +241,7 @@ class SceneResponse(BaseModel):
|
|||||||
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
|
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
|
||||||
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
|
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
|
||||||
classes_count: int = Field(0, description="类型数量")
|
classes_count: int = Field(0, description="类型数量")
|
||||||
|
is_system_default: bool = Field(False, description="是否为系统默认场景")
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
@@ -462,6 +463,7 @@ class ClassListResponse(BaseModel):
|
|||||||
scene_id: UUID = Field(..., description="所属场景ID")
|
scene_id: UUID = Field(..., description="所属场景ID")
|
||||||
scene_name: str = Field(..., description="场景名称")
|
scene_name: str = Field(..., description="场景名称")
|
||||||
scene_description: Optional[str] = Field(None, description="场景描述")
|
scene_description: Optional[str] = Field(None, description="场景描述")
|
||||||
|
is_system_default: bool = Field(False, description="是否为系统默认场景")
|
||||||
items: List[ClassResponse] = Field(..., description="类型列表")
|
items: List[ClassResponse] = Field(..., description="类型列表")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -155,6 +155,10 @@ class MCPToolConfigSchema(BaseModel):
|
|||||||
health_status: str = "unknown"
|
health_status: str = "unknown"
|
||||||
error_message: Optional[str] = None
|
error_message: Optional[str] = None
|
||||||
available_tools: List[Dict[str, Dict[str, Any]]] = Field(default_factory=list, description="工具列表,格式: [{'tool_name': str, 'arguments': dict}]")
|
available_tools: List[Dict[str, Dict[str, Any]]] = Field(default_factory=list, description="工具列表,格式: [{'tool_name': str, 'arguments': dict}]")
|
||||||
|
source_channel: Optional[str] = Field(None, description="来源渠道")
|
||||||
|
market_id: Optional[str] = Field(None, description="渠道市场id")
|
||||||
|
market_config_id: Optional[str] = Field(None, description="渠道市场配置id")
|
||||||
|
mcp_service_id: Optional[str] = Field(None, description="mcp服务id")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
@@ -192,6 +196,10 @@ class ToolCreateRequest(BaseModel):
|
|||||||
tool_type: ToolType
|
tool_type: ToolType
|
||||||
config: Dict[str, Any] = Field(default_factory=dict)
|
config: Dict[str, Any] = Field(default_factory=dict)
|
||||||
tags: List[str] = Field(default_factory=list)
|
tags: List[str] = Field(default_factory=list)
|
||||||
|
source_channel: Optional[str] = Field(None, description="来源渠道(仅MCP工具)")
|
||||||
|
market_id: Optional[str] = Field(None, description="渠道市场id(仅MCP工具)")
|
||||||
|
market_config_id: Optional[str] = Field(None, description="渠道市场配置id(仅MCP工具)")
|
||||||
|
mcp_service_id: Optional[str] = Field(None, description="mcp服务id(仅MCP工具)")
|
||||||
|
|
||||||
|
|
||||||
class ToolUpdateRequest(BaseModel):
|
class ToolUpdateRequest(BaseModel):
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ class AppChatService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 保存消息
|
# 保存消息
|
||||||
self.conversation_service.save_conversation_messages(
|
message_id = self.conversation_service.save_conversation_messages(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
user_message=message,
|
user_message=message,
|
||||||
assistant_message=result["content"],
|
assistant_message=result["content"],
|
||||||
@@ -163,6 +163,7 @@ class AppChatService:
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
|
"message_id": str(message_id),
|
||||||
"message": result["content"],
|
"message": result["content"],
|
||||||
"usage": result.get("usage", {
|
"usage": result.get("usage", {
|
||||||
"prompt_tokens": 0,
|
"prompt_tokens": 0,
|
||||||
@@ -191,7 +192,11 @@ class AppChatService:
|
|||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = None
|
config_id = None
|
||||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
message_id = uuid.uuid4()
|
||||||
|
yield f"event: start\ndata: {json.dumps({
|
||||||
|
'conversation_id': str(conversation_id),
|
||||||
|
"message_id": str(message_id)
|
||||||
|
}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
variables = self.agent_service.prepare_variables(variables, config.variables)
|
variables = self.agent_service.prepare_variables(variables, config.variables)
|
||||||
# 获取模型配置ID
|
# 获取模型配置ID
|
||||||
@@ -296,6 +301,7 @@ class AppChatService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
|
message_id=message_id,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=full_content,
|
content=full_content,
|
||||||
@@ -373,7 +379,7 @@ class AppChatService:
|
|||||||
content=message
|
content=message
|
||||||
)
|
)
|
||||||
|
|
||||||
self.conversation_service.add_message(
|
ai_message = self.conversation_service.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=result.get("message", ""),
|
content=result.get("message", ""),
|
||||||
@@ -391,6 +397,7 @@ class AppChatService:
|
|||||||
return {
|
return {
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
"message": result.get("message", ""),
|
"message": result.get("message", ""),
|
||||||
|
"message_id": str(ai_message.id),
|
||||||
"usage": {
|
"usage": {
|
||||||
"prompt_tokens": 0,
|
"prompt_tokens": 0,
|
||||||
"completion_tokens": 0,
|
"completion_tokens": 0,
|
||||||
@@ -419,9 +426,9 @@ class AppChatService:
|
|||||||
variables = {}
|
variables = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
message_id = uuid.uuid4()
|
||||||
# 发送开始事件
|
# 发送开始事件
|
||||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), "message_id": str(message_id)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
full_content = ""
|
full_content = ""
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
@@ -429,6 +436,7 @@ class AppChatService:
|
|||||||
# 2. 创建编排器
|
# 2. 创建编排器
|
||||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||||
|
|
||||||
|
|
||||||
# 3. 流式执行任务
|
# 3. 流式执行任务
|
||||||
async for event in orchestrator.execute_stream(
|
async for event in orchestrator.execute_stream(
|
||||||
message=message,
|
message=message,
|
||||||
@@ -472,6 +480,7 @@ class AppChatService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
|
message_id=message_id,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=full_content,
|
content=full_content,
|
||||||
|
|||||||
390
api/app/services/app_dsl_service.py
Normal file
390
api/app/services/app_dsl_service.py
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
"""应用 DSL 导入导出服务"""
|
||||||
|
import uuid
|
||||||
|
import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||||
|
from app.models import AgentConfig, MultiAgentConfig
|
||||||
|
from app.models.app_model import App, AppType
|
||||||
|
from app.models.app_release_model import AppRelease
|
||||||
|
from app.models.knowledge_model import Knowledge
|
||||||
|
from app.models.models_model import ModelConfig
|
||||||
|
from app.models.tool_model import ToolConfig as ToolConfigModel
|
||||||
|
from app.models.workflow_model import WorkflowConfig
|
||||||
|
from app.services.workflow_service import WorkflowService
|
||||||
|
|
||||||
|
|
||||||
|
class AppDslService:
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
# ==================== 导出 ====================
|
||||||
|
|
||||||
|
def export_dsl(self, app_id: uuid.UUID, release_id: Optional[uuid.UUID] = None) -> tuple[str, str]:
|
||||||
|
"""构建应用 DSL yaml 字符串,返回 (yaml_str, filename)"""
|
||||||
|
app = self.db.query(App).filter(App.id == app_id, App.is_active.is_(True)).first()
|
||||||
|
if not app:
|
||||||
|
raise ResourceNotFoundException("应用", str(app_id))
|
||||||
|
|
||||||
|
meta = {
|
||||||
|
"version": settings.SYSTEM_VERSION,
|
||||||
|
"platform": "MemoryBear",
|
||||||
|
"exported_at": datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S"),
|
||||||
|
}
|
||||||
|
app_meta = {
|
||||||
|
"name": app.name,
|
||||||
|
"description": app.description,
|
||||||
|
"icon": app.icon,
|
||||||
|
"icon_type": app.icon_type,
|
||||||
|
"type": app.type,
|
||||||
|
"tags": app.tags or [],
|
||||||
|
}
|
||||||
|
|
||||||
|
if release_id is not None:
|
||||||
|
return self._export_release(app, release_id, meta, app_meta)
|
||||||
|
|
||||||
|
return self._export_draft(app, meta, app_meta)
|
||||||
|
|
||||||
|
def _export_release(self, app: App, release_id: uuid.UUID, meta: dict, app_meta: dict) -> tuple[str, str]:
|
||||||
|
release = self.db.query(AppRelease).filter(
|
||||||
|
AppRelease.app_id == app.id,
|
||||||
|
AppRelease.id == release_id,
|
||||||
|
AppRelease.is_active.is_(True)
|
||||||
|
).first()
|
||||||
|
if not release:
|
||||||
|
raise ResourceNotFoundException("版本", str(release_id))
|
||||||
|
|
||||||
|
meta["release_version"] = release.version
|
||||||
|
meta["release_name"] = release.version_name
|
||||||
|
app_meta["name"] = release.name
|
||||||
|
app_meta["description"] = release.description
|
||||||
|
config_key = {
|
||||||
|
AppType.AGENT: "agent_config",
|
||||||
|
AppType.MULTI_AGENT: "multi_agent_config",
|
||||||
|
AppType.WORKFLOW: "workflow"
|
||||||
|
}.get(app.type, "config")
|
||||||
|
config_data = self._enrich_release_config(app.type, release.config or {})
|
||||||
|
dsl = {**meta, "app": app_meta, config_key: config_data}
|
||||||
|
return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{release.name}_v{release.version_name}.yaml"
|
||||||
|
|
||||||
|
def _enrich_release_config(self, app_type: str, cfg: dict) -> dict:
|
||||||
|
if app_type == AppType.AGENT:
|
||||||
|
enriched = {**cfg}
|
||||||
|
if "default_model_config_id" in cfg:
|
||||||
|
enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"])
|
||||||
|
if "knowledge_retrieval" in cfg:
|
||||||
|
enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"])
|
||||||
|
if "tools" in cfg:
|
||||||
|
enriched["tools"] = self._enrich_tools(cfg["tools"])
|
||||||
|
return enriched
|
||||||
|
if app_type == AppType.MULTI_AGENT:
|
||||||
|
enriched = {**cfg}
|
||||||
|
if "default_model_config_id" in cfg:
|
||||||
|
enriched["default_model_config_ref"] = self._model_ref(cfg["default_model_config_id"])
|
||||||
|
if "master_agent_id" in cfg:
|
||||||
|
enriched["master_agent_ref"] = self._release_ref(cfg["master_agent_id"])
|
||||||
|
if "sub_agents" in cfg:
|
||||||
|
enriched["sub_agents"] = self._enrich_sub_agents(cfg["sub_agents"])
|
||||||
|
if "routing_rules" in cfg:
|
||||||
|
enriched["routing_rules"] = [
|
||||||
|
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
|
||||||
|
]
|
||||||
|
return enriched
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
|
||||||
|
if app.type == AppType.WORKFLOW:
|
||||||
|
config = self.db.query(WorkflowConfig).filter(WorkflowConfig.app_id == app.id).first()
|
||||||
|
config_data = {
|
||||||
|
"variables": config.variables if config else [],
|
||||||
|
"edges": config.edges if config else [],
|
||||||
|
"nodes": config.nodes if config else [],
|
||||||
|
"execution_config": config.execution_config if config else {},
|
||||||
|
"triggers": config.triggers if config else [],
|
||||||
|
} if config else {}
|
||||||
|
dsl = {**meta, "app": app_meta, "workflow": config_data}
|
||||||
|
|
||||||
|
elif app.type == AppType.AGENT:
|
||||||
|
config = self.db.query(AgentConfig).filter(AgentConfig.app_id == app.id).first()
|
||||||
|
config_data = {
|
||||||
|
"system_prompt": config.system_prompt if config else None,
|
||||||
|
"model_parameters": self._to_dict(config.model_parameters) if config else None,
|
||||||
|
"default_model_config_ref": self._model_ref(config.default_model_config_id) if config else None,
|
||||||
|
"knowledge_retrieval": self._enrich_knowledge_retrieval(config.knowledge_retrieval) if config else None,
|
||||||
|
"memory": config.memory if config else None,
|
||||||
|
"variables": config.variables if config else [],
|
||||||
|
"tools": self._enrich_tools(config.tools) if config else [],
|
||||||
|
"skills": config.skills if config else {},
|
||||||
|
} if config else {}
|
||||||
|
dsl = {**meta, "app": app_meta, "agent_config": config_data}
|
||||||
|
|
||||||
|
elif app.type == AppType.MULTI_AGENT:
|
||||||
|
config = self.db.query(MultiAgentConfig).filter(MultiAgentConfig.app_id == app.id).first()
|
||||||
|
config_data = {
|
||||||
|
"orchestration_mode": config.orchestration_mode if config else None,
|
||||||
|
"master_agent_name": config.master_agent_name if config else None,
|
||||||
|
"model_parameters": self._to_dict(config.model_parameters) if config else None,
|
||||||
|
"default_model_config_ref": self._model_ref(config.default_model_config_id) if config else None,
|
||||||
|
"master_agent_ref": self._release_ref(config.master_agent_id) if config else None,
|
||||||
|
"sub_agents": self._enrich_sub_agents(config.sub_agents) if config else [],
|
||||||
|
"routing_rules": [
|
||||||
|
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (config.routing_rules or [])
|
||||||
|
] if config else [],
|
||||||
|
|
||||||
|
"execution_config": config.execution_config if config else {},
|
||||||
|
"aggregation_strategy": config.aggregation_strategy if config else "merge",
|
||||||
|
} if config else {}
|
||||||
|
dsl = {**meta, "app": app_meta, "multi_agent_config": config_data}
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise BusinessException(f"不支持的应用类型: {app.type}", BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
|
return yaml.dump(dsl, default_flow_style=False, allow_unicode=True), f"{app.name}.yaml"
|
||||||
|
|
||||||
|
def _to_dict(self, value):
|
||||||
|
"""将 Pydantic 对象转为普通 dict,供 yaml.dump 安全序列化"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if hasattr(value, "model_dump"):
|
||||||
|
return value.model_dump()
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _model_ref(self, model_config_id) -> Optional[dict]:
|
||||||
|
if not model_config_id:
|
||||||
|
return None
|
||||||
|
m = self.db.query(ModelConfig).filter(ModelConfig.id == model_config_id).first()
|
||||||
|
return {"id": str(model_config_id), "name": m.name, "provider": m.provider, "type": m.type} if m else {"id": str(model_config_id)}
|
||||||
|
|
||||||
|
def _kb_ref(self, kb_id) -> Optional[dict]:
|
||||||
|
if not kb_id:
|
||||||
|
return None
|
||||||
|
kb = self.db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
||||||
|
return {"id": str(kb_id), "name": kb.name} if kb else {"id": str(kb_id)}
|
||||||
|
|
||||||
|
def _tool_ref(self, tool_id) -> Optional[dict]:
|
||||||
|
if not tool_id:
|
||||||
|
return None
|
||||||
|
t = self.db.query(ToolConfigModel).filter(ToolConfigModel.id == tool_id).first()
|
||||||
|
return {"id": str(tool_id), "name": t.name, "tool_type": t.tool_type} if t else {"id": str(tool_id)}
|
||||||
|
|
||||||
|
def _enrich_knowledge_retrieval(self, kr: Optional[dict]) -> Optional[dict]:
|
||||||
|
if not kr:
|
||||||
|
return kr
|
||||||
|
kbs = [{**kb, "_ref": self._kb_ref(kb.get("kb_id"))} for kb in kr.get("knowledge_bases", [])]
|
||||||
|
return {**kr, "knowledge_bases": kbs}
|
||||||
|
|
||||||
|
def _enrich_tools(self, tools: list) -> list:
|
||||||
|
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
||||||
|
|
||||||
|
def _agent_ref(self, agent_id) -> Optional[dict]:
|
||||||
|
if not agent_id:
|
||||||
|
return None
|
||||||
|
a = self.db.query(App).filter(App.id == agent_id).first()
|
||||||
|
return {"id": str(agent_id), "name": a.name} if a else {"id": str(agent_id)}
|
||||||
|
|
||||||
|
def _release_ref(self, release_id) -> Optional[dict]:
|
||||||
|
if not release_id:
|
||||||
|
return None
|
||||||
|
r = self.db.query(AppRelease).filter(AppRelease.id == release_id).first()
|
||||||
|
return {"id": str(release_id), "name": r.name, "version": r.version, "app_id": str(r.app_id)} if r else {"id": str(release_id)}
|
||||||
|
|
||||||
|
def _enrich_sub_agents(self, sub_agents: list) -> list:
|
||||||
|
return [{**s, "_ref": self._agent_ref(s.get("agent_id"))} for s in (sub_agents or [])]
|
||||||
|
|
||||||
|
# ==================== 导入 ====================
|
||||||
|
|
||||||
|
def import_dsl(
|
||||||
|
self,
|
||||||
|
dsl: dict,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
|
tenant_id: uuid.UUID,
|
||||||
|
user_id: uuid.UUID,
|
||||||
|
) -> tuple[App, list[str]]:
|
||||||
|
"""解析 DSL,创建应用及配置,返回 (new_app, warnings)"""
|
||||||
|
app_meta = dsl.get("app", {})
|
||||||
|
app_type = app_meta.get("type")
|
||||||
|
if app_type not in (AppType.AGENT, AppType.MULTI_AGENT, AppType.WORKFLOW):
|
||||||
|
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
|
warnings: list[str] = []
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
|
||||||
|
new_app = App(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
created_by=user_id,
|
||||||
|
name=app_meta.get("name", "导入应用"),
|
||||||
|
description=app_meta.get("description"),
|
||||||
|
icon=app_meta.get("icon"),
|
||||||
|
icon_type=app_meta.get("icon_type"),
|
||||||
|
type=app_type,
|
||||||
|
visibility="private",
|
||||||
|
status="draft",
|
||||||
|
tags=app_meta.get("tags", []),
|
||||||
|
is_active=True,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
self.db.add(new_app)
|
||||||
|
self.db.flush()
|
||||||
|
|
||||||
|
if app_type == AppType.AGENT:
|
||||||
|
cfg = dsl.get("agent_config") or {}
|
||||||
|
self.db.add(AgentConfig(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
app_id=new_app.id,
|
||||||
|
system_prompt=cfg.get("system_prompt"),
|
||||||
|
model_parameters=cfg.get("model_parameters"),
|
||||||
|
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
|
||||||
|
knowledge_retrieval=self._resolve_knowledge_retrieval(cfg.get("knowledge_retrieval"), workspace_id, warnings),
|
||||||
|
memory=cfg.get("memory"),
|
||||||
|
variables=cfg.get("variables", []),
|
||||||
|
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
|
||||||
|
skills=cfg.get("skills", {}),
|
||||||
|
is_active=True,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
))
|
||||||
|
|
||||||
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
|
cfg = dsl.get("multi_agent_config") or {}
|
||||||
|
self.db.add(MultiAgentConfig(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
app_id=new_app.id,
|
||||||
|
orchestration_mode=cfg.get("orchestration_mode", "collaboration"),
|
||||||
|
master_agent_name=cfg.get("master_agent_name"),
|
||||||
|
model_parameters=cfg.get("model_parameters"),
|
||||||
|
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
|
||||||
|
master_agent_id=self._resolve_release(cfg.get("master_agent_ref"), warnings),
|
||||||
|
sub_agents=self._resolve_sub_agents(cfg.get("sub_agents", []), warnings),
|
||||||
|
routing_rules=self._resolve_routing_rules(cfg.get("routing_rules"), warnings),
|
||||||
|
execution_config=cfg.get("execution_config", {}),
|
||||||
|
aggregation_strategy=cfg.get("aggregation_strategy", "merge"),
|
||||||
|
is_active=True,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
))
|
||||||
|
|
||||||
|
elif app_type == AppType.WORKFLOW:
|
||||||
|
wf = dsl.get("workflow") or {}
|
||||||
|
WorkflowService(self.db).create_workflow_config(
|
||||||
|
app_id=new_app.id,
|
||||||
|
nodes=wf.get("nodes", []),
|
||||||
|
edges=wf.get("edges", []),
|
||||||
|
variables=wf.get("variables", []),
|
||||||
|
execution_config=wf.get("execution_config", {}),
|
||||||
|
triggers=wf.get("triggers", []),
|
||||||
|
validate=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(new_app)
|
||||||
|
return new_app, warnings
|
||||||
|
|
||||||
|
def _resolve_model(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[uuid.UUID]:
|
||||||
|
if not ref:
|
||||||
|
return None
|
||||||
|
q = self.db.query(ModelConfig).filter(
|
||||||
|
ModelConfig.tenant_id == tenant_id,
|
||||||
|
ModelConfig.name == ref.get("name"),
|
||||||
|
ModelConfig.is_active.is_(True)
|
||||||
|
)
|
||||||
|
if ref.get("provider"):
|
||||||
|
q = q.filter(ModelConfig.provider == ref["provider"])
|
||||||
|
if ref.get("type"):
|
||||||
|
q = q.filter(ModelConfig.type == ref["type"])
|
||||||
|
m = q.first()
|
||||||
|
if not m:
|
||||||
|
warnings.append(f"模型 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||||
|
return m.id if m else None
|
||||||
|
|
||||||
|
def _resolve_kb(self, ref: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[str]:
|
||||||
|
if not ref:
|
||||||
|
return None
|
||||||
|
kb = self.db.query(Knowledge).filter(
|
||||||
|
Knowledge.workspace_id == workspace_id,
|
||||||
|
Knowledge.name == ref.get("name")
|
||||||
|
).first()
|
||||||
|
if not kb:
|
||||||
|
warnings.append(f"知识库 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||||
|
return str(kb.id) if kb else None
|
||||||
|
|
||||||
|
def _resolve_tool(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]:
|
||||||
|
if not ref:
|
||||||
|
return None
|
||||||
|
q = self.db.query(ToolConfigModel).filter(
|
||||||
|
ToolConfigModel.tenant_id == tenant_id,
|
||||||
|
ToolConfigModel.name == ref.get("name")
|
||||||
|
)
|
||||||
|
if ref.get("tool_type"):
|
||||||
|
q = q.filter(ToolConfigModel.tool_type == ref["tool_type"])
|
||||||
|
t = q.first()
|
||||||
|
if not t:
|
||||||
|
warnings.append(f"工具 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||||
|
return str(t.id) if t else None
|
||||||
|
|
||||||
|
def _resolve_release(self, ref: Optional[dict], warnings: list) -> Optional[uuid.UUID]:
|
||||||
|
if not ref:
|
||||||
|
return None
|
||||||
|
r = self.db.query(AppRelease).filter(
|
||||||
|
AppRelease.app_id == ref.get("app_id"),
|
||||||
|
AppRelease.version == ref.get("version"),
|
||||||
|
AppRelease.is_active.is_(True)
|
||||||
|
).first()
|
||||||
|
if not r:
|
||||||
|
warnings.append(f"主 Agent 发布版本 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||||
|
return r.id if r else None
|
||||||
|
|
||||||
|
def _resolve_sub_agents(self, sub_agents: list, warnings: list) -> list:
|
||||||
|
result = []
|
||||||
|
for s in (sub_agents or []):
|
||||||
|
ref = s.get("_ref")
|
||||||
|
entry = {k: v for k, v in s.items() if k != "_ref"}
|
||||||
|
if ref:
|
||||||
|
a = self.db.query(App).filter(App.name == ref.get("name"), App.is_active.is_(True)).first()
|
||||||
|
if not a:
|
||||||
|
warnings.append(f"子 Agent '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||||
|
entry["agent_id"] = str(a.id) if a else None
|
||||||
|
result.append(entry)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _resolve_routing_rules(self, rules: Optional[list], warnings: list) -> Optional[list]:
|
||||||
|
if rules is None:
|
||||||
|
return None
|
||||||
|
result = []
|
||||||
|
for r in rules:
|
||||||
|
ref = r.get("_ref")
|
||||||
|
entry = {k: v for k, v in r.items() if k != "_ref"}
|
||||||
|
if ref:
|
||||||
|
a = self.db.query(App).filter(App.name == ref.get("name"), App.is_active.is_(True)).first()
|
||||||
|
if not a:
|
||||||
|
warnings.append(f"路由目标 Agent '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
|
||||||
|
entry["target_agent_id"] = str(a.id) if a else None
|
||||||
|
result.append(entry)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _resolve_knowledge_retrieval(self, kr: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]:
|
||||||
|
if not kr:
|
||||||
|
return kr
|
||||||
|
resolved_kbs = []
|
||||||
|
for kb in kr.get("knowledge_bases", []):
|
||||||
|
ref = kb.get("_ref") or ({"name": kb.get("kb_id")} if kb.get("kb_id") else None)
|
||||||
|
entry = {k: v for k, v in kb.items() if k != "_ref"}
|
||||||
|
entry["kb_id"] = self._resolve_kb(ref, workspace_id, warnings)
|
||||||
|
resolved_kbs.append(entry)
|
||||||
|
return {k: v for k, v in kr.items() if k != "knowledge_bases"} | {"knowledge_bases": resolved_kbs}
|
||||||
|
|
||||||
|
def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list:
|
||||||
|
result = []
|
||||||
|
for t in (tools or []):
|
||||||
|
ref = t.get("_ref") or ({"name": t.get("tool_id")} if t.get("tool_id") else None)
|
||||||
|
entry = {k: v for k, v in t.items() if k != "_ref"}
|
||||||
|
entry["tool_id"] = self._resolve_tool(ref, tenant_id, warnings)
|
||||||
|
result.append(entry)
|
||||||
|
return result
|
||||||
@@ -33,7 +33,7 @@ from app.models import (
|
|||||||
Workspace,
|
Workspace,
|
||||||
)
|
)
|
||||||
from app.models.app_model import AppStatus, AppType
|
from app.models.app_model import AppStatus, AppType
|
||||||
from app.repositories.app_repository import get_apps_by_id
|
from app.repositories.app_repository import get_apps_by_id, AppRepository
|
||||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||||
from app.schemas import app_schema
|
from app.schemas import app_schema
|
||||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||||
@@ -59,6 +59,7 @@ class AppService:
|
|||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
"""
|
"""
|
||||||
self.db = db
|
self.db = db
|
||||||
|
self.app_repo = AppRepository(self.db)
|
||||||
|
|
||||||
# ==================== 私有辅助方法 ====================
|
# ==================== 私有辅助方法 ====================
|
||||||
|
|
||||||
@@ -521,6 +522,9 @@ class AppService:
|
|||||||
"创建应用",
|
"创建应用",
|
||||||
extra={"app_name": data.name, "type": data.type, "workspace_id": str(workspace_id)}
|
extra={"app_name": data.name, "type": data.type, "workspace_id": str(workspace_id)}
|
||||||
)
|
)
|
||||||
|
apps = self.app_repo.get_apps_by_name(data.name, data.type, workspace_id)
|
||||||
|
if apps:
|
||||||
|
raise BusinessException(message="已存在同名应用", code=BizCode.RESOURCE_ALREADY_EXISTS)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
@@ -703,7 +707,7 @@ class AppService:
|
|||||||
self.db.flush()
|
self.db.flush()
|
||||||
|
|
||||||
# 如果是 agent 类型,复制 AgentConfig
|
# 如果是 agent 类型,复制 AgentConfig
|
||||||
if source_app.type == "agent":
|
if source_app.type == AppType.AGENT:
|
||||||
source_config = self.db.query(AgentConfig).filter(
|
source_config = self.db.query(AgentConfig).filter(
|
||||||
AgentConfig.app_id == source_app.id
|
AgentConfig.app_id == source_app.id
|
||||||
).first()
|
).first()
|
||||||
@@ -725,6 +729,50 @@ class AppService:
|
|||||||
)
|
)
|
||||||
self.db.add(new_config)
|
self.db.add(new_config)
|
||||||
|
|
||||||
|
elif source_app.type == AppType.WORKFLOW:
|
||||||
|
source_config = self.db.query(WorkflowConfig).filter(
|
||||||
|
WorkflowConfig.app_id == source_app.id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if source_config:
|
||||||
|
new_config = WorkflowConfig(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
app_id=new_app.id,
|
||||||
|
nodes=source_config.nodes.copy() if source_config.nodes else [],
|
||||||
|
edges=source_config.edges.copy() if source_config.edges else [],
|
||||||
|
variables=source_config.variables.copy() if source_config.variables else [],
|
||||||
|
execution_config=source_config.execution_config.copy() if source_config.execution_config else {},
|
||||||
|
triggers=source_config.triggers.copy() if source_config.triggers else [],
|
||||||
|
is_active=True,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
self.db.add(new_config)
|
||||||
|
|
||||||
|
elif source_app.type == AppType.MULTI_AGENT:
|
||||||
|
source_config = self.db.query(MultiAgentConfig).filter(
|
||||||
|
MultiAgentConfig.app_id == source_app.id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if source_config:
|
||||||
|
new_config = MultiAgentConfig(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
app_id=new_app.id,
|
||||||
|
master_agent_id=source_config.master_agent_id,
|
||||||
|
master_agent_name=source_config.master_agent_name,
|
||||||
|
default_model_config_id=source_config.default_model_config_id,
|
||||||
|
model_parameters=source_config.model_parameters,
|
||||||
|
orchestration_mode=source_config.orchestration_mode,
|
||||||
|
sub_agents=source_config.sub_agents.copy() if source_config.sub_agents else [],
|
||||||
|
routing_rules=source_config.routing_rules.copy() if source_config.routing_rules else None,
|
||||||
|
execution_config=source_config.execution_config.copy() if source_config.execution_config else {},
|
||||||
|
aggregation_strategy=source_config.aggregation_strategy,
|
||||||
|
is_active=True,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
self.db.add(new_config)
|
||||||
|
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.db.refresh(new_app)
|
self.db.refresh(new_app)
|
||||||
|
|
||||||
@@ -1324,6 +1372,15 @@ class AppService:
|
|||||||
if not agent_cfg:
|
if not agent_cfg:
|
||||||
raise BusinessException("Agent 应用缺少配置,无法发布", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("Agent 应用缺少配置,无法发布", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
|
miss_params = []
|
||||||
|
if agent_cfg.default_model_config_id is None:
|
||||||
|
miss_params.append("model config")
|
||||||
|
|
||||||
|
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
|
||||||
|
miss_params.append("memory config")
|
||||||
|
if miss_params:
|
||||||
|
raise BusinessException(f"{', '.join(miss_params)} is required")
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"system_prompt": agent_cfg.system_prompt,
|
"system_prompt": agent_cfg.system_prompt,
|
||||||
"model_parameters": model_parameters_to_dict(agent_cfg.model_parameters),
|
"model_parameters": model_parameters_to_dict(agent_cfg.model_parameters),
|
||||||
|
|||||||
@@ -178,7 +178,8 @@ class ConversationService:
|
|||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
role: str,
|
role: str,
|
||||||
content: str,
|
content: str,
|
||||||
meta_data: Optional[dict] = None
|
meta_data: Optional[dict] = None,
|
||||||
|
message_id: Optional[uuid.UUID] = None,
|
||||||
) -> Message:
|
) -> Message:
|
||||||
"""
|
"""
|
||||||
Add a message to a conversation using UnitOfWork.
|
Add a message to a conversation using UnitOfWork.
|
||||||
@@ -188,6 +189,7 @@ class ConversationService:
|
|||||||
role (str): Role of the message sender ('user' or 'assistant').
|
role (str): Role of the message sender ('user' or 'assistant').
|
||||||
content (str): Message content.
|
content (str): Message content.
|
||||||
meta_data (Optional[dict]): Optional metadata.
|
meta_data (Optional[dict]): Optional metadata.
|
||||||
|
message_id (Optional[uuid.UUID]): Optional custom message UUID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Message: Newly created Message instance.
|
Message: Newly created Message instance.
|
||||||
@@ -198,6 +200,7 @@ class ConversationService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
message = Message(
|
message = Message(
|
||||||
|
id=message_id if message_id else uuid.uuid4(),
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role=role,
|
role=role,
|
||||||
content=content,
|
content=content,
|
||||||
@@ -317,7 +320,7 @@ class ConversationService:
|
|||||||
content=user_message
|
content=user_message
|
||||||
)
|
)
|
||||||
|
|
||||||
self.add_message(
|
ai_message = self.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=assistant_message,
|
content=assistant_message,
|
||||||
@@ -332,6 +335,7 @@ class ConversationService:
|
|||||||
"assistant_message_length": len(assistant_message)
|
"assistant_message_length": len(assistant_message)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
return ai_message.id
|
||||||
|
|
||||||
def delete_conversation(
|
def delete_conversation(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from app.core.error_codes import BizCode
|
|||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
|
from app.db import get_db_context
|
||||||
from app.models import AgentConfig, ModelConfig
|
from app.models import AgentConfig, ModelConfig
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.schemas.app_schema import FileInput
|
from app.schemas.app_schema import FileInput
|
||||||
@@ -103,9 +104,7 @@ def create_long_term_memory_tool(
|
|||||||
"""
|
"""
|
||||||
logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}")
|
logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}")
|
||||||
try:
|
try:
|
||||||
from app.db import get_db
|
with get_db_context() as db:
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
memory_content = asyncio.run(
|
memory_content = asyncio.run(
|
||||||
MemoryAgentService().read_memory(
|
MemoryAgentService().read_memory(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -127,9 +126,6 @@ def create_long_term_memory_tool(
|
|||||||
logger.info(f"读取任务状态:{status}")
|
logger.info(f"读取任务状态:{status}")
|
||||||
if memory_content:
|
if memory_content:
|
||||||
memory_content = memory_content['answer']
|
memory_content = memory_content['answer']
|
||||||
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
logger.info(f'用户ID:Agent:{end_user_id}')
|
logger.info(f'用户ID:Agent:{end_user_id}')
|
||||||
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
|
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
|
|||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
|
|||||||
reorder_output_results,
|
reorder_output_results,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_agent_schema import Write_UserInput
|
from app.schemas.memory_agent_schema import Write_UserInput
|
||||||
from app.schemas.memory_config_schema import ConfigurationError
|
from app.schemas.memory_config_schema import ConfigurationError
|
||||||
@@ -69,7 +66,8 @@ class MemoryAgentService:
|
|||||||
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||||
# 记录成功的操作
|
# 记录成功的操作
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||||
|
success=True,
|
||||||
duration=duration, details={"message_length": len(message)})
|
duration=duration, details={"message_length": len(message)})
|
||||||
return context
|
return context
|
||||||
else:
|
else:
|
||||||
@@ -88,8 +86,6 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
raise ValueError(f"写入失败: {messages}")
|
raise ValueError(f"写入失败: {messages}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def extract_tool_call_info(self, event: Dict) -> bool:
|
def extract_tool_call_info(self, event: Dict) -> bool:
|
||||||
"""Extract tool call information from event"""
|
"""Extract tool call information from event"""
|
||||||
last_message = event["messages"][-1]
|
last_message = event["messages"][-1]
|
||||||
@@ -271,7 +267,8 @@ class MemoryAgentService:
|
|||||||
logger.info("Log streaming completed, cleaning up resources")
|
logger.info("Log streaming completed, cleaning up resources")
|
||||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||||
|
|
||||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
|
||||||
|
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
||||||
"""
|
"""
|
||||||
Process write operation with config_id
|
Process write operation with config_id
|
||||||
|
|
||||||
@@ -300,7 +297,8 @@ class MemoryAgentService:
|
|||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||||
if config_id is None and workspace_id is None:
|
if config_id is None and workspace_id is None:
|
||||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
raise ValueError(
|
||||||
|
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "No memory configuration found" in str(e):
|
if "No memory configuration found" in str(e):
|
||||||
raise # Re-raise our specific error
|
raise # Re-raise our specific error
|
||||||
@@ -331,7 +329,8 @@ class MemoryAgentService:
|
|||||||
# Log failed operation
|
# Log failed operation
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||||
|
success=False, duration=duration, error=error_msg)
|
||||||
|
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
@@ -351,9 +350,9 @@ class MemoryAgentService:
|
|||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||||
elif msg['role'] == 'assistant':
|
elif msg['role'] == 'assistant':
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
langchain_messages.append(AIMessage(content=msg['content']))
|
||||||
print(100*'-')
|
print(100 * '-')
|
||||||
print(langchain_messages)
|
print(langchain_messages)
|
||||||
print(100*'-')
|
print(100 * '-')
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {
|
initial_state = {
|
||||||
"messages": langchain_messages,
|
"messages": langchain_messages,
|
||||||
@@ -375,29 +374,28 @@ class MemoryAgentService:
|
|||||||
contents = massages.get('write_result')
|
contents = massages.get('write_result')
|
||||||
# Convert messages back to string for logging
|
# Convert messages back to string for logging
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
|
||||||
|
contents)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Ensure proper error handling and logging
|
# Ensure proper error handling and logging
|
||||||
error_msg = f"Write operation failed: {str(e)}"
|
error_msg = f"Write operation failed: {str(e)}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||||
|
success=False, duration=duration, error=error_msg)
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def read_memory(
|
async def read_memory(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
history: List[Dict],
|
history: List[Dict],
|
||||||
search_switch: str,
|
search_switch: str,
|
||||||
config_id: Optional[uuid.UUID]|int,
|
config_id: Optional[uuid.UUID] | int,
|
||||||
db: Session,
|
db: Session,
|
||||||
storage_type: str,
|
storage_type: str,
|
||||||
user_rag_memory_id: str) -> Dict:
|
user_rag_memory_id: str) -> Dict:
|
||||||
"""
|
"""
|
||||||
Process read operation with config_id
|
Process read operation with config_id
|
||||||
|
|
||||||
@@ -425,7 +423,7 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
ori_message= message
|
ori_message = message
|
||||||
|
|
||||||
# Resolve config_id and workspace_id
|
# Resolve config_id and workspace_id
|
||||||
# Always get workspace_id from end_user for fallback, even if config_id is provided
|
# Always get workspace_id from end_user for fallback, even if config_id is provided
|
||||||
@@ -437,7 +435,8 @@ class MemoryAgentService:
|
|||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||||
if config_id is None and workspace_id is None:
|
if config_id is None and workspace_id is None:
|
||||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
raise ValueError(
|
||||||
|
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "No memory configuration found" in str(e):
|
if "No memory configuration found" in str(e):
|
||||||
raise # Re-raise our specific error
|
raise # Re-raise our specific error
|
||||||
@@ -454,7 +453,6 @@ class MemoryAgentService:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
audit_logger = None
|
audit_logger = None
|
||||||
|
|
||||||
|
|
||||||
config_load_start = time.time()
|
config_load_start = time.time()
|
||||||
try:
|
try:
|
||||||
# Use a separate database session to avoid transaction failures
|
# Use a separate database session to avoid transaction failures
|
||||||
@@ -562,34 +560,35 @@ class MemoryAgentService:
|
|||||||
from app.repositories.memory_short_repository import (
|
from app.repositories.memory_short_repository import (
|
||||||
ShortTermMemoryRepository,
|
ShortTermMemoryRepository,
|
||||||
)
|
)
|
||||||
|
|
||||||
retrieved_content = []
|
retrieved_content = []
|
||||||
repo = ShortTermMemoryRepository(db)
|
repo = ShortTermMemoryRepository(db)
|
||||||
|
|
||||||
if str(search_switch) != "2":
|
if str(search_switch) != "2":
|
||||||
for intermediate in _intermediate_outputs:
|
for intermediate in _intermediate_outputs:
|
||||||
logger.debug(f"处理中间结果: {intermediate}")
|
logger.debug(f"处理中间结果: {intermediate}")
|
||||||
intermediate_type = intermediate.get('type', '')
|
intermediate_type = intermediate.get('type', '')
|
||||||
|
|
||||||
if intermediate_type == "search_result":
|
if intermediate_type == "search_result":
|
||||||
query = intermediate.get('query', '')
|
query = intermediate.get('query', '')
|
||||||
raw_results = intermediate.get('raw_results', {})
|
raw_results = intermediate.get('raw_results', {})
|
||||||
try:
|
try:
|
||||||
reranked_results = raw_results.get('reranked_results', [])
|
reranked_results = raw_results.get('reranked_results', [])
|
||||||
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
|
statements = [statement['statement'] for statement in
|
||||||
|
reranked_results.get('statements', [])]
|
||||||
except Exception:
|
except Exception:
|
||||||
statements = []
|
statements = []
|
||||||
|
|
||||||
# 去重
|
# 去重
|
||||||
statements = list(set(statements))
|
statements = list(set(statements))
|
||||||
|
|
||||||
if query and statements:
|
if query and statements:
|
||||||
retrieved_content.append({query: statements})
|
retrieved_content.append({query: statements})
|
||||||
|
|
||||||
# 如果 retrieved_content 为空,设置为空字符串
|
# 如果 retrieved_content 为空,设置为空字符串
|
||||||
if retrieved_content == []:
|
if retrieved_content == []:
|
||||||
retrieved_content = ''
|
retrieved_content = ''
|
||||||
|
|
||||||
# 只有当回答不是"信息不足"且不是快速检索时才保存
|
# 只有当回答不是"信息不足"且不是快速检索时才保存
|
||||||
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
|
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
|
||||||
# 使用 upsert 方法
|
# 使用 upsert 方法
|
||||||
@@ -602,15 +601,17 @@ class MemoryAgentService:
|
|||||||
)
|
)
|
||||||
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
|
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
logger.debug(
|
||||||
|
f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
||||||
|
|
||||||
except Exception as save_error:
|
except Exception as save_error:
|
||||||
# 保存失败不应该影响主流程,只记录错误
|
# 保存失败不应该影响主流程,只记录错误
|
||||||
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
|
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
|
||||||
|
|
||||||
# Log successful operation
|
# Log successful operation
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
logger.info(
|
||||||
|
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
@@ -641,7 +642,6 @@ class MemoryAgentService:
|
|||||||
)
|
)
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Get standardized message list from user input.
|
Get standardized message list from user input.
|
||||||
@@ -657,41 +657,43 @@ class MemoryAgentService:
|
|||||||
"""
|
"""
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
logger = get_api_logger()
|
logger = get_api_logger()
|
||||||
|
|
||||||
if len(user_input.messages) == 0:
|
if len(user_input.messages) == 0:
|
||||||
logger.error("Validation failed: Message list cannot be empty")
|
logger.error("Validation failed: Message list cannot be empty")
|
||||||
raise ValueError("Message list cannot be empty")
|
raise ValueError("Message list cannot be empty")
|
||||||
|
|
||||||
for idx, msg in enumerate(user_input.messages):
|
for idx, msg in enumerate(user_input.messages):
|
||||||
if not isinstance(msg, dict):
|
if not isinstance(msg, dict):
|
||||||
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
|
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
|
||||||
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
raise ValueError(
|
||||||
|
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
||||||
|
|
||||||
if 'role' not in msg:
|
if 'role' not in msg:
|
||||||
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
|
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
|
||||||
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
|
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
|
||||||
|
|
||||||
if 'content' not in msg:
|
if 'content' not in msg:
|
||||||
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
|
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
|
||||||
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
raise ValueError(
|
||||||
|
f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
||||||
|
|
||||||
if msg['role'] not in ['user', 'assistant']:
|
if msg['role'] not in ['user', 'assistant']:
|
||||||
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
|
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
|
||||||
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
|
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
|
||||||
|
|
||||||
if not msg['content'] or not msg['content'].strip():
|
if not msg['content'] or not msg['content'].strip():
|
||||||
logger.error(f"Validation failed: Message {idx} content is empty")
|
logger.error(f"Validation failed: Message {idx} content is empty")
|
||||||
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
|
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
|
||||||
|
|
||||||
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
|
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
|
||||||
return user_input.messages
|
return user_input.messages
|
||||||
|
|
||||||
async def classify_message_type(
|
async def classify_message_type(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
config_id: UUID,
|
config_id: UUID,
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: Optional[UUID] = None
|
workspace_id: Optional[UUID] = None
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Determine the type of user message (read or write)
|
Determine the type of user message (read or write)
|
||||||
@@ -719,14 +721,15 @@ class MemoryAgentService:
|
|||||||
status = await status_typle(message, memory_config.llm_model_id)
|
status = await status_typle(message, memory_config.llm_model_id)
|
||||||
logger.debug(f"Message type: {status}")
|
logger.debug(f"Message type: {status}")
|
||||||
return status
|
return status
|
||||||
|
|
||||||
async def generate_summary_from_retrieve(
|
async def generate_summary_from_retrieve(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
retrieve_info: str,
|
retrieve_info: str,
|
||||||
history: List[Dict],
|
history: List[Dict],
|
||||||
query: str,
|
query: str,
|
||||||
config_id: str,
|
config_id: str,
|
||||||
db: Session
|
db: Session
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
基于检索信息、历史对话和查询生成最终答案
|
基于检索信息、历史对话和查询生成最终答案
|
||||||
@@ -761,9 +764,9 @@ class MemoryAgentService:
|
|||||||
if config_id is None:
|
if config_id is None:
|
||||||
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
|
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
|
||||||
# If config_id was provided, continue without workspace_id fallback
|
# If config_id was provided, continue without workspace_id fallback
|
||||||
|
|
||||||
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
|
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 加载配置
|
# 加载配置
|
||||||
config_service = MemoryConfigService(db)
|
config_service = MemoryConfigService(db)
|
||||||
@@ -772,7 +775,7 @@ class MemoryAgentService:
|
|||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
service_name="MemoryAgentService"
|
service_name="MemoryAgentService"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 导入必要的模块
|
# 导入必要的模块
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||||
summary_llm,
|
summary_llm,
|
||||||
@@ -780,13 +783,13 @@ class MemoryAgentService:
|
|||||||
from app.core.memory.agent.models.summary_models import (
|
from app.core.memory.agent.models.summary_models import (
|
||||||
RetrieveSummaryResponse,
|
RetrieveSummaryResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建状态对象
|
# 构建状态对象
|
||||||
state = {
|
state = {
|
||||||
"data": query,
|
"data": query,
|
||||||
"memory_config": memory_config
|
"memory_config": memory_config
|
||||||
}
|
}
|
||||||
|
|
||||||
# 直接调用 summary_llm 函数
|
# 直接调用 summary_llm 函数
|
||||||
answer = await summary_llm(
|
answer = await summary_llm(
|
||||||
state=state,
|
state=state,
|
||||||
@@ -797,21 +800,20 @@ class MemoryAgentService:
|
|||||||
response_model=RetrieveSummaryResponse,
|
response_model=RetrieveSummaryResponse,
|
||||||
search_mode="1"
|
search_mode="1"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
|
logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
|
||||||
return answer if answer else "信息不足,无法回答。"
|
return answer if answer else "信息不足,无法回答。"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
|
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
|
||||||
return "信息不足,无法回答。"
|
return "信息不足,无法回答。"
|
||||||
|
|
||||||
|
|
||||||
async def get_knowledge_type_stats(
|
async def get_knowledge_type_stats(
|
||||||
self,
|
self,
|
||||||
end_user_id: Optional[str] = None,
|
db: Session,
|
||||||
only_active: bool = True,
|
end_user_id: Optional[str] = None,
|
||||||
current_workspace_id: Optional[uuid.UUID] = None,
|
only_active: bool = True,
|
||||||
db: Session = None
|
current_workspace_id: Optional[uuid.UUID] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
统计知识库类型分布,包含:
|
统计知识库类型分布,包含:
|
||||||
@@ -837,11 +839,6 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
# 1. 统计 PostgreSQL 中的知识库类型
|
# 1. 统计 PostgreSQL 中的知识库类型
|
||||||
try:
|
try:
|
||||||
if db is None:
|
|
||||||
from app.db import get_db
|
|
||||||
db_gen = get_db()
|
|
||||||
db = next(db_gen)
|
|
||||||
|
|
||||||
# 初始化所有标准类型为 0
|
# 初始化所有标准类型为 0
|
||||||
for kb_type in KnowledgeType:
|
for kb_type in KnowledgeType:
|
||||||
result[kb_type.value] = 0
|
result[kb_type.value] = 0
|
||||||
@@ -881,21 +878,19 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
# 3. 计算知识库类型总和(不包括 memory)
|
# 3. 计算知识库类型总和(不包括 memory)
|
||||||
result["total"] = (
|
result["total"] = (
|
||||||
result.get("General", 0) +
|
result.get("General", 0) +
|
||||||
result.get("Web", 0) +
|
result.get("Web", 0) +
|
||||||
result.get("Third-party", 0) +
|
result.get("Third-party", 0) +
|
||||||
result.get("Folder", 0)
|
result.get("Folder", 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def get_interest_distribution_by_user(
|
async def get_interest_distribution_by_user(
|
||||||
self,
|
self,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
language: str = "zh"
|
language: str = "zh"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取指定用户的兴趣分布标签。
|
获取指定用户的兴趣分布标签。
|
||||||
@@ -921,13 +916,12 @@ class MemoryAgentService:
|
|||||||
logger.error(f"兴趣分布标签查询失败: {e}")
|
logger.error(f"兴趣分布标签查询失败: {e}")
|
||||||
raise Exception(f"兴趣分布标签查询失败: {e}")
|
raise Exception(f"兴趣分布标签查询失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def get_user_profile(
|
async def get_user_profile(
|
||||||
self,
|
self,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user_id: Optional[str] = None,
|
current_user_id: Optional[str] = None,
|
||||||
llm_id: Optional[str] = None,
|
llm_id: Optional[str] = None,
|
||||||
db: Session = None
|
db: Session = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取用户详情,包含:
|
获取用户详情,包含:
|
||||||
@@ -1017,7 +1011,8 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
# 定义标签提取的结构
|
# 定义标签提取的结构
|
||||||
class UserTags(BaseModel):
|
class UserTags(BaseModel):
|
||||||
tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")
|
tags: list[str] = Field(...,
|
||||||
|
description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
ValueError: 当终端用户不存在或应用未发布时
|
ValueError: 当终端用户不存在或应用未发布时
|
||||||
"""
|
"""
|
||||||
import json as json_module
|
import json as json_module
|
||||||
import uuid
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
@@ -1171,6 +1165,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
|
|
||||||
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||||
|
|
||||||
|
# TODO: check sources for enduserid, should be one of these three: chat, draft, apikey
|
||||||
# 1. 获取 end_user 及其 app_id
|
# 1. 获取 end_user 及其 app_id
|
||||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||||
if not end_user:
|
if not end_user:
|
||||||
@@ -1185,21 +1180,21 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
if not app:
|
if not app:
|
||||||
logger.warning(f"App not found: {app_id}")
|
logger.warning(f"App not found: {app_id}")
|
||||||
raise ValueError(f"应用不存在: {app_id}")
|
raise ValueError(f"应用不存在: {app_id}")
|
||||||
|
# TODO: temp fix for draft run
|
||||||
if not app.current_release_id:
|
# if not app.current_release_id:
|
||||||
logger.warning(f"No current release for app: {app_id}")
|
# logger.warning(f"No current release for app: {app_id}")
|
||||||
raise ValueError(f"应用未发布: {app_id}")
|
# raise ValueError(f"应用未发布: {app_id}")
|
||||||
|
|
||||||
# 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填
|
# 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填
|
||||||
memory_config_id_to_use = end_user.memory_config_id
|
memory_config_id_to_use = end_user.memory_config_id
|
||||||
|
|
||||||
# 如果已有 memory_config_id,直接使用
|
# 如果已有 memory_config_id,直接使用
|
||||||
# 如果新创建enduser,enduser.memory_config_id 必定为none
|
# 如果新创建enduser,enduser.memory_config_id 必定为none
|
||||||
# 那么使用从release中获取memory_config_id为预期行为,并且回填到
|
# 那么使用从release中获取memory_config_id为预期行为,并且回填到
|
||||||
# end_user.memory_config_id
|
# end_user.memory_config_id
|
||||||
if not memory_config_id_to_use:
|
if not memory_config_id_to_use:
|
||||||
logger.info(f"end_user.memory_config_id is None, migrating from AppRelease.config")
|
logger.info(f"end_user.memory_config_id is None, migrating from AppRelease.config")
|
||||||
|
|
||||||
# 获取最新发布版本
|
# 获取最新发布版本
|
||||||
stmt = (
|
stmt = (
|
||||||
select(AppRelease)
|
select(AppRelease)
|
||||||
@@ -1208,10 +1203,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
)
|
)
|
||||||
# TODO: change to current_release_id
|
# TODO: change to current_release_id
|
||||||
latest_release = db.scalars(stmt).first()
|
latest_release = db.scalars(stmt).first()
|
||||||
|
|
||||||
if latest_release:
|
if latest_release:
|
||||||
config = latest_release.config or {}
|
config = latest_release.config or {}
|
||||||
|
|
||||||
# 如果 config 是字符串,解析为字典
|
# 如果 config 是字符串,解析为字典
|
||||||
if isinstance(config, str):
|
if isinstance(config, str):
|
||||||
try:
|
try:
|
||||||
@@ -1219,22 +1214,24 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
except json_module.JSONDecodeError:
|
except json_module.JSONDecodeError:
|
||||||
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
|
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
|
||||||
config = {}
|
config = {}
|
||||||
|
|
||||||
# 使用 MemoryConfigService 的提取方法
|
# 使用 MemoryConfigService 的提取方法
|
||||||
memory_config_service = MemoryConfigService(db)
|
memory_config_service = MemoryConfigService(db)
|
||||||
legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id(
|
legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id(
|
||||||
app_type=app.type,
|
app_type=app.type,
|
||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
if legacy_config_id:
|
if legacy_config_id:
|
||||||
# 验证提取的 config_id 是否存在于数据库中
|
# 验证提取的 config_id 是否存在于数据库中
|
||||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
from app.models.memory_config_model import (
|
||||||
|
MemoryConfig as MemoryConfigModel,
|
||||||
|
)
|
||||||
existing_config = db.get(MemoryConfigModel, legacy_config_id)
|
existing_config = db.get(MemoryConfigModel, legacy_config_id)
|
||||||
|
|
||||||
if existing_config:
|
if existing_config:
|
||||||
memory_config_id_to_use = legacy_config_id
|
memory_config_id_to_use = legacy_config_id
|
||||||
|
|
||||||
# 回填到 end_user 表(lazy update)
|
# 回填到 end_user 表(lazy update)
|
||||||
end_user.memory_config_id = memory_config_id_to_use
|
end_user.memory_config_id = memory_config_id_to_use
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -1263,12 +1260,13 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
result = {
|
result = {
|
||||||
"end_user_id": str(end_user_id),
|
"end_user_id": str(end_user_id),
|
||||||
"app_id": str(app_id),
|
"app_id": str(app_id),
|
||||||
"release_id": str(app.current_release_id),
|
"release_id": str(app.current_release_id) if app.current_release_id else None,
|
||||||
"memory_config_id": memory_config_id,
|
"memory_config_id": memory_config_id,
|
||||||
"workspace_id": str(app.workspace_id)
|
"workspace_id": str(app.workspace_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
|
logger.info(
|
||||||
|
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -1312,7 +1310,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
|||||||
|
|
||||||
# 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id
|
# 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id
|
||||||
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
|
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
|
||||||
|
|
||||||
# 创建映射 - 保留 EndUser 对象引用以便回填
|
# 创建映射 - 保留 EndUser 对象引用以便回填
|
||||||
end_user_map = {str(eu.id): eu for eu in end_users}
|
end_user_map = {str(eu.id): eu for eu in end_users}
|
||||||
user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users}
|
user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users}
|
||||||
@@ -1336,15 +1334,15 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
|||||||
|
|
||||||
# 3. 对于没有 memory_config_id 的用户,尝试从 AppRelease.config 提取
|
# 3. 对于没有 memory_config_id 的用户,尝试从 AppRelease.config 提取
|
||||||
users_needing_migration = [
|
users_needing_migration = [
|
||||||
(end_user_id, data["app_id"])
|
(end_user_id, data["app_id"])
|
||||||
for end_user_id, data in user_data.items()
|
for end_user_id, data in user_data.items()
|
||||||
if not data["memory_config_id"]
|
if not data["memory_config_id"]
|
||||||
]
|
]
|
||||||
|
|
||||||
if users_needing_migration:
|
if users_needing_migration:
|
||||||
# 批量获取相关应用的最新发布版本
|
# 批量获取相关应用的最新发布版本
|
||||||
migration_app_ids = list(set(app_id for _, app_id in users_needing_migration))
|
migration_app_ids = list(set(app_id for _, app_id in users_needing_migration))
|
||||||
|
|
||||||
# 查询每个应用的最新活跃发布版本
|
# 查询每个应用的最新活跃发布版本
|
||||||
app_latest_releases = {}
|
app_latest_releases = {}
|
||||||
for app_id in migration_app_ids:
|
for app_id in migration_app_ids:
|
||||||
@@ -1357,18 +1355,18 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
|||||||
latest_release = db.scalars(stmt).first()
|
latest_release = db.scalars(stmt).first()
|
||||||
if latest_release:
|
if latest_release:
|
||||||
app_latest_releases[app_id] = latest_release
|
app_latest_releases[app_id] = latest_release
|
||||||
|
|
||||||
# 为每个需要迁移的用户提取 memory_config_id
|
# 为每个需要迁移的用户提取 memory_config_id
|
||||||
config_service = MemoryConfigService(db)
|
config_service = MemoryConfigService(db)
|
||||||
users_to_backfill = [] # [(end_user, memory_config_id), ...]
|
users_to_backfill = [] # [(end_user, memory_config_id), ...]
|
||||||
|
|
||||||
for end_user_id, app_id in users_needing_migration:
|
for end_user_id, app_id in users_needing_migration:
|
||||||
latest_release = app_latest_releases.get(app_id)
|
latest_release = app_latest_releases.get(app_id)
|
||||||
if not latest_release:
|
if not latest_release:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
config = latest_release.config or {}
|
config = latest_release.config or {}
|
||||||
|
|
||||||
# 如果 config 是字符串,解析为字典
|
# 如果 config 是字符串,解析为字典
|
||||||
if isinstance(config, str):
|
if isinstance(config, str):
|
||||||
try:
|
try:
|
||||||
@@ -1376,21 +1374,21 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
|||||||
except json_module.JSONDecodeError:
|
except json_module.JSONDecodeError:
|
||||||
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
|
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 使用 MemoryConfigService 的提取方法
|
# 使用 MemoryConfigService 的提取方法
|
||||||
app = app_map.get(app_id)
|
app = app_map.get(app_id)
|
||||||
if not app:
|
if not app:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
legacy_config_id, is_legacy_int = config_service.extract_memory_config_id(
|
legacy_config_id, is_legacy_int = config_service.extract_memory_config_id(
|
||||||
app_type=app.type,
|
app_type=app.type,
|
||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
if legacy_config_id:
|
if legacy_config_id:
|
||||||
# 更新 user_data 中的 memory_config_id
|
# 更新 user_data 中的 memory_config_id
|
||||||
user_data[end_user_id]["memory_config_id"] = legacy_config_id
|
user_data[end_user_id]["memory_config_id"] = legacy_config_id
|
||||||
|
|
||||||
# 记录需要回填的用户(稍后验证配置存在后再回填)
|
# 记录需要回填的用户(稍后验证配置存在后再回填)
|
||||||
end_user = end_user_map.get(end_user_id)
|
end_user = end_user_map.get(end_user_id)
|
||||||
if end_user:
|
if end_user:
|
||||||
@@ -1399,7 +1397,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Legacy int config detected for end_user {end_user_id}, will use workspace default"
|
f"Legacy int config detected for end_user {end_user_id}, will use workspace default"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 验证提取的 config_id 是否存在于数据库中
|
# 验证提取的 config_id 是否存在于数据库中
|
||||||
if users_to_backfill:
|
if users_to_backfill:
|
||||||
config_ids_to_validate = list(set(cid for _, cid in users_to_backfill))
|
config_ids_to_validate = list(set(cid for _, cid in users_to_backfill))
|
||||||
@@ -1407,17 +1405,17 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
|||||||
MemoryConfig.config_id.in_(config_ids_to_validate)
|
MemoryConfig.config_id.in_(config_ids_to_validate)
|
||||||
).all()
|
).all()
|
||||||
valid_config_ids = {mc.config_id for mc in existing_configs}
|
valid_config_ids = {mc.config_id for mc in existing_configs}
|
||||||
|
|
||||||
# 只回填存在的配置
|
# 只回填存在的配置
|
||||||
valid_backfills = [
|
valid_backfills = [
|
||||||
(eu, cid) for eu, cid in users_to_backfill
|
(eu, cid) for eu, cid in users_to_backfill
|
||||||
if cid in valid_config_ids
|
if cid in valid_config_ids
|
||||||
]
|
]
|
||||||
invalid_backfills = [
|
invalid_backfills = [
|
||||||
(eu, cid) for eu, cid in users_to_backfill
|
(eu, cid) for eu, cid in users_to_backfill
|
||||||
if cid not in valid_config_ids
|
if cid not in valid_config_ids
|
||||||
]
|
]
|
||||||
|
|
||||||
if invalid_backfills:
|
if invalid_backfills:
|
||||||
invalid_ids = [str(cid) for _, cid in invalid_backfills]
|
invalid_ids = [str(cid) for _, cid in invalid_backfills]
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -1426,7 +1424,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
|||||||
# 清除 user_data 中无效的 config_id
|
# 清除 user_data 中无效的 config_id
|
||||||
for eu, cid in invalid_backfills:
|
for eu, cid in invalid_backfills:
|
||||||
user_data[str(eu.id)]["memory_config_id"] = None
|
user_data[str(eu.id)]["memory_config_id"] = None
|
||||||
|
|
||||||
# 批量回填 end_user.memory_config_id
|
# 批量回填 end_user.memory_config_id
|
||||||
if valid_backfills:
|
if valid_backfills:
|
||||||
for end_user, memory_config_id in valid_backfills:
|
for end_user, memory_config_id in valid_backfills:
|
||||||
@@ -1437,7 +1435,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
|||||||
# 4. 收集需要查询的 memory_config_id 和需要回退的 workspace_id
|
# 4. 收集需要查询的 memory_config_id 和需要回退的 workspace_id
|
||||||
direct_config_ids = []
|
direct_config_ids = []
|
||||||
workspace_fallback_users = [] # [(end_user_id, workspace_id), ...]
|
workspace_fallback_users = [] # [(end_user_id, workspace_id), ...]
|
||||||
|
|
||||||
for end_user_id, data in user_data.items():
|
for end_user_id, data in user_data.items():
|
||||||
if data["memory_config_id"]:
|
if data["memory_config_id"]:
|
||||||
direct_config_ids.append(data["memory_config_id"])
|
direct_config_ids.append(data["memory_config_id"])
|
||||||
@@ -1455,7 +1453,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
|||||||
# 6. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑)
|
# 6. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑)
|
||||||
workspace_default_configs = {}
|
workspace_default_configs = {}
|
||||||
unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users))
|
unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users))
|
||||||
|
|
||||||
if unique_workspace_ids:
|
if unique_workspace_ids:
|
||||||
config_service = MemoryConfigService(db)
|
config_service = MemoryConfigService(db)
|
||||||
for workspace_id in unique_workspace_ids:
|
for workspace_id in unique_workspace_ids:
|
||||||
@@ -1466,11 +1464,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
|||||||
# 7. 构建最终结果
|
# 7. 构建最终结果
|
||||||
for end_user_id, data in user_data.items():
|
for end_user_id, data in user_data.items():
|
||||||
memory_config = None
|
memory_config = None
|
||||||
|
|
||||||
# 优先使用 end_user 直接分配的配置
|
# 优先使用 end_user 直接分配的配置
|
||||||
if data["memory_config_id"]:
|
if data["memory_config_id"]:
|
||||||
memory_config = config_id_to_config.get(data["memory_config_id"])
|
memory_config = config_id_to_config.get(data["memory_config_id"])
|
||||||
|
|
||||||
# 回退到工作空间默认配置
|
# 回退到工作空间默认配置
|
||||||
if not memory_config:
|
if not memory_config:
|
||||||
workspace_id = app_to_workspace.get(data["app_id"])
|
workspace_id = app_to_workspace.get(data["app_id"])
|
||||||
@@ -1486,4 +1484,4 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
|||||||
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
|
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
|
||||||
|
|
||||||
logger.info(f"Successfully retrieved {len(result)} connected configs")
|
logger.info(f"Successfully retrieved {len(result)} connected configs")
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -107,6 +107,40 @@ def _validate_config_id(config_id, db: Session = None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 专门场景的内置 key 集合,直接从 SceneConfigRegistry 派生,避免重复维护
|
||||||
|
# 使用懒加载函数避免模块级循环导入
|
||||||
|
def _get_builtin_pruning_scenes() -> set:
|
||||||
|
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import SceneConfigRegistry
|
||||||
|
return set(SceneConfigRegistry.get_all_scenes())
|
||||||
|
|
||||||
|
|
||||||
|
def _load_ontology_classes(db: Session, scene_id, pruning_scene: Optional[str]) -> Optional[list]:
|
||||||
|
"""当 pruning_scene 不是内置场景时,从 ontology_class 表加载类型名称列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
scene_id: 本体场景 UUID
|
||||||
|
pruning_scene: 语义剪枝场景名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
class_name 字符串列表,或 None(内置场景 / 无数据时)
|
||||||
|
"""
|
||||||
|
if not scene_id:
|
||||||
|
return None
|
||||||
|
# 内置场景走 SceneConfigRegistry,不需要注入类型列表
|
||||||
|
if pruning_scene in _get_builtin_pruning_scenes():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||||
|
repo = OntologyClassRepository(db)
|
||||||
|
classes = repo.get_classes_by_scene(scene_id)
|
||||||
|
names = [c.class_name for c in classes if c.class_name]
|
||||||
|
return names if names else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load ontology classes for scene_id={scene_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class MemoryConfigService:
|
class MemoryConfigService:
|
||||||
"""
|
"""
|
||||||
Centralized service for memory configuration loading and validation.
|
Centralized service for memory configuration loading and validation.
|
||||||
@@ -359,6 +393,7 @@ class MemoryConfigService:
|
|||||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||||
# Ontology scene association
|
# Ontology scene association
|
||||||
scene_id=memory_config.scene_id,
|
scene_id=memory_config.scene_id,
|
||||||
|
ontology_classes=_load_ontology_classes(self.db, memory_config.scene_id, memory_config.pruning_scene),
|
||||||
)
|
)
|
||||||
|
|
||||||
elapsed_ms = (time.time() - start_time) * 1000
|
elapsed_ms = (time.time() - start_time) * 1000
|
||||||
|
|||||||
@@ -1,45 +1,42 @@
|
|||||||
# 修改 memory_konwledges_server.py 文件
|
# 修改 memory_konwledges_server.py 文件
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from fastapi import HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.rag.models.chunk import DocumentChunk
|
from app.core.rag.models.chunk import DocumentChunk
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db_context
|
||||||
from app.schemas import file_schema, document_schema
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
|
||||||
from app.models.document_model import Document
|
from app.models.document_model import Document
|
||||||
import uuid
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from fastapi import HTTPException, status
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
|
from app.schemas import file_schema, document_schema
|
||||||
from app.schemas.file_schema import CustomTextFileCreate
|
from app.schemas.file_schema import CustomTextFileCreate
|
||||||
from app.services import document_service, file_service, knowledge_service
|
from app.services import document_service, file_service, knowledge_service
|
||||||
from app.celery_app import celery_app
|
|
||||||
from app.core.logging_config import get_api_logger
|
|
||||||
from app.schemas.file_schema import CustomTextFileCreate
|
|
||||||
from app.db import get_db
|
|
||||||
# 创建一个简单的用户类用于测试
|
# 创建一个简单的用户类用于测试
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
|
||||||
class ChunkCreate(BaseModel):
|
class ChunkCreate(BaseModel):
|
||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
class SimpleUser:
|
class SimpleUser:
|
||||||
def __init__(self, user_id: str):
|
def __init__(self, user_id: str):
|
||||||
# 确保ID是UUID类型
|
# 确保ID是UUID类型
|
||||||
self.id = user_id
|
self.id = user_id
|
||||||
self.username = user_id
|
self.username = user_id
|
||||||
|
|
||||||
'''解析'''
|
|
||||||
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
|
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
|
||||||
"""
|
"""
|
||||||
解析指定文档
|
解析指定文档
|
||||||
@@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user
|
|||||||
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
|
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
'''获取块ID'''
|
|
||||||
async def get_document_chunks(
|
async def get_document_chunks(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
document_id: uuid.UUID,
|
document_id: uuid.UUID,
|
||||||
@@ -198,7 +195,7 @@ async def get_document_chunks(
|
|||||||
|
|
||||||
return success(data=result, msg="文档块列表查询成功")
|
return success(data=result, msg="文档块列表查询成功")
|
||||||
|
|
||||||
'''查找文档ID'''
|
|
||||||
def find_document_id_by_kb_and_filename(
|
def find_document_id_by_kb_and_filename(
|
||||||
db: Session,
|
db: Session,
|
||||||
kb_id: str,
|
kb_id: str,
|
||||||
@@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
'''获取知识库ID'''
|
|
||||||
def find_documents_by_kb_id(
|
def find_documents_by_kb_id(
|
||||||
db: Session,
|
db: Session,
|
||||||
kb_id: str,
|
kb_id: str,
|
||||||
@@ -268,18 +265,14 @@ def find_documents_by_kb_id(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
''''上传文件'''
|
|
||||||
async def memory_konwledges_up(
|
async def memory_konwledges_up(
|
||||||
kb_id: str,
|
kb_id: str,
|
||||||
parent_id: str,
|
parent_id: str,
|
||||||
create_data: file_schema.CustomTextFileCreate,
|
create_data: file_schema.CustomTextFileCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session,
|
||||||
current_user: SimpleUser = None, # 修改为SimpleUser
|
current_user: SimpleUser,
|
||||||
):
|
):
|
||||||
# 如果没有提供current_user,则创建一个默认的
|
|
||||||
if current_user is None:
|
|
||||||
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
|
|
||||||
|
|
||||||
content_bytes = create_data.content.encode('utf-8')
|
content_bytes = create_data.content.encode('utf-8')
|
||||||
file_size = len(content_bytes)
|
file_size = len(content_bytes)
|
||||||
print(f"file size: {file_size} byte")
|
print(f"file size: {file_size} byte")
|
||||||
@@ -350,8 +343,6 @@ async def memory_konwledges_up(
|
|||||||
|
|
||||||
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
||||||
|
|
||||||
'''添加新块'''
|
|
||||||
|
|
||||||
|
|
||||||
async def create_document_chunk(
|
async def create_document_chunk(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
@@ -417,7 +408,7 @@ async def create_document_chunk(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"查询文档块失败: {error_msg}"
|
detail=f"查询文档块失败: {error_msg}"
|
||||||
)
|
)
|
||||||
|
|
||||||
sort_id = sort_id + 1
|
sort_id = sort_id + 1
|
||||||
|
|
||||||
# 5. 创建文档块
|
# 5. 创建文档块
|
||||||
@@ -450,6 +441,7 @@ async def create_document_chunk(
|
|||||||
|
|
||||||
return success(data=chunk, msg="文档块创建成功")
|
return success(data=chunk, msg="文档块创建成功")
|
||||||
|
|
||||||
|
|
||||||
async def write_rag(end_user_id, message, user_rag_memory_id):
|
async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||||
"""
|
"""
|
||||||
将消息写入 RAG 知识库
|
将消息写入 RAG 知识库
|
||||||
@@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
|||||||
detail=f"知识库ID格式无效: {user_rag_memory_id}"
|
detail=f"知识库ID格式无效: {user_rag_memory_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
db_gen = get_db()
|
with get_db_context() as db:
|
||||||
db = next(db_gen)
|
|
||||||
|
|
||||||
try:
|
|
||||||
create_data = CustomTextFileCreate(title=end_user_id, content=message)
|
create_data = CustomTextFileCreate(title=end_user_id, content=message)
|
||||||
current_user = SimpleUser(user_rag_memory_id)
|
current_user = SimpleUser(user_rag_memory_id)
|
||||||
# 检查文档是否已存在
|
# 检查文档是否已存在
|
||||||
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
||||||
print('======',document)
|
print('======', document)
|
||||||
api_logger.info(f"查找文档结果: document_id={document}")
|
api_logger.info(f"查找文档结果: document_id={document}")
|
||||||
if document is not None:
|
if document is not None:
|
||||||
# 文档已存在,直接添加新块
|
# 文档已存在,直接添加新块
|
||||||
@@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
|||||||
else:
|
else:
|
||||||
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
|
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
|
||||||
return result
|
return result
|
||||||
finally:
|
|
||||||
# 确保数据库会话被关闭
|
|
||||||
db.close()
|
|
||||||
@@ -115,6 +115,17 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
|
|
||||||
# --- Create ---
|
# --- Create ---
|
||||||
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
||||||
|
# 业务层检查同一工作空间下是否已存在同名配置
|
||||||
|
if params.workspace_id and params.config_name:
|
||||||
|
from app.models.memory_config_model import MemoryConfig
|
||||||
|
existing = (
|
||||||
|
self.db.query(MemoryConfig)
|
||||||
|
.filter_by(workspace_id=params.workspace_id, config_name=params.config_name)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise ValueError(f"DUPLICATE_CONFIG_NAME:{params.config_name}")
|
||||||
|
|
||||||
# 如果workspace_id存在且模型字段未全部指定,则自动获取
|
# 如果workspace_id存在且模型字段未全部指定,则自动获取
|
||||||
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
|
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
|
||||||
configs = self._get_workspace_configs(params.workspace_id)
|
configs = self._get_workspace_configs(params.workspace_id)
|
||||||
@@ -135,6 +146,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
if not params.emotion_model_id:
|
if not params.emotion_model_id:
|
||||||
params.emotion_model_id = params.llm_id
|
params.emotion_model_id = params.llm_id
|
||||||
|
|
||||||
|
# 根据关联的本体场景推导 pruning_scene(语义剪枝场景与本体工程场景保持一致)
|
||||||
|
if params.scene_id and not getattr(params, 'pruning_scene', None):
|
||||||
|
params.pruning_scene = self._resolve_pruning_scene_from_scene_id(params.scene_id)
|
||||||
|
|
||||||
config = MemoryConfigRepository.create(self.db, params)
|
config = MemoryConfigRepository.create(self.db, params)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
return {"affected": 1, "config_id": config.config_id}
|
return {"affected": 1, "config_id": config.config_id}
|
||||||
@@ -150,6 +165,23 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
finally:
|
finally:
|
||||||
db_session.close()
|
db_session.close()
|
||||||
|
|
||||||
|
def _resolve_pruning_scene_from_scene_id(self, scene_id) -> Optional[str]:
|
||||||
|
"""根据本体场景ID获取对应的 scene_name,作为语义剪枝场景值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene_id: 本体场景UUID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
scene_name 字符串,查询失败时返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.models.ontology_scene import OntologyScene
|
||||||
|
scene = self.db.query(OntologyScene).filter_by(scene_id=scene_id).first()
|
||||||
|
return scene.scene_name if scene else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"_resolve_pruning_scene_from_scene_id failed for scene_id={scene_id}: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
# --- Delete ---
|
# --- Delete ---
|
||||||
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID)
|
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID)
|
||||||
success = MemoryConfigRepository.delete(self.db, key.config_id)
|
success = MemoryConfigRepository.delete(self.db, key.config_id)
|
||||||
@@ -185,6 +217,19 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||||
results = MemoryConfigRepository.get_all(self.db, workspace_id)
|
results = MemoryConfigRepository.get_all(self.db, workspace_id)
|
||||||
|
|
||||||
|
# 检查并修正 pruning_scene 与 scene_name 不一致的记录
|
||||||
|
needs_commit = False
|
||||||
|
for config, scene_name in results:
|
||||||
|
if scene_name and config.pruning_scene != scene_name:
|
||||||
|
logger.info(
|
||||||
|
f"修正 pruning_scene: config_id={config.config_id} "
|
||||||
|
f"'{config.pruning_scene}' -> '{scene_name}'"
|
||||||
|
)
|
||||||
|
config.pruning_scene = scene_name
|
||||||
|
needs_commit = True
|
||||||
|
if needs_commit:
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
# 将 ORM 对象转换为字典列表
|
# 将 ORM 对象转换为字典列表
|
||||||
data_list = []
|
data_list = []
|
||||||
for config, scene_name in results:
|
for config, scene_name in results:
|
||||||
@@ -211,6 +256,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
"apply_id": config.apply_id,
|
"apply_id": config.apply_id,
|
||||||
"scene_id": str(config.scene_id) if config.scene_id else None,
|
"scene_id": str(config.scene_id) if config.scene_id else None,
|
||||||
"scene_name": scene_name, # 新增:场景名称
|
"scene_name": scene_name, # 新增:场景名称
|
||||||
|
"is_system_default": config.is_default, # 是否为系统默认配置
|
||||||
"llm_id": config.llm_id,
|
"llm_id": config.llm_id,
|
||||||
"embedding_id": config.embedding_id,
|
"embedding_id": config.embedding_id,
|
||||||
"rerank_id": config.rerank_id,
|
"rerank_id": config.rerank_id,
|
||||||
@@ -737,8 +783,37 @@ async def analytics_hot_memory_tags(
|
|||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
|
|
||||||
async def analytics_recent_activity_stats() -> Dict[str, Any]:
|
async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> Dict[str, Any]:
|
||||||
stats, _msg = get_recent_activity_stats()
|
"""获取最近记忆提取活动统计。
|
||||||
|
|
||||||
|
优先从 Redis 缓存读取(按 workspace_id),缓存不存在时降级到日志文件解析。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: 工作空间ID,用于从 Redis 读取对应缓存
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含 total、stats、latest_relative、source 的统计字典
|
||||||
|
"""
|
||||||
|
stats = None
|
||||||
|
source = "log"
|
||||||
|
|
||||||
|
# 优先从 Redis 读取
|
||||||
|
if workspace_id:
|
||||||
|
try:
|
||||||
|
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||||||
|
cached = await ActivityStatsCache.get_activity_stats(workspace_id)
|
||||||
|
if cached:
|
||||||
|
stats = cached.get("stats", {})
|
||||||
|
source = "redis"
|
||||||
|
logger.info(f"[ANALYTICS] 从 Redis 读取活动统计: workspace_id={workspace_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[ANALYTICS] 读取 Redis 活动统计失败,降级到日志: {e}")
|
||||||
|
|
||||||
|
# 降级:从日志文件解析
|
||||||
|
if stats is None:
|
||||||
|
stats, _msg = get_recent_activity_stats()
|
||||||
|
source = "log"
|
||||||
|
|
||||||
total = (
|
total = (
|
||||||
stats.get("chunk_count", 0)
|
stats.get("chunk_count", 0)
|
||||||
+ stats.get("statements_count", 0)
|
+ stats.get("statements_count", 0)
|
||||||
@@ -746,26 +821,29 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]:
|
|||||||
+ stats.get("triplet_relations_count", 0)
|
+ stats.get("triplet_relations_count", 0)
|
||||||
+ stats.get("temporal_count", 0)
|
+ stats.get("temporal_count", 0)
|
||||||
)
|
)
|
||||||
# 精简:仅提供“最新一次活动多久前”
|
|
||||||
latest_relative = None
|
|
||||||
try:
|
|
||||||
info = stats.get("log_path", "")
|
|
||||||
idx = info.rfind("最新:")
|
|
||||||
if idx != -1:
|
|
||||||
latest_path = info[idx + 3 :].strip()
|
|
||||||
if latest_path and os.path.exists(latest_path):
|
|
||||||
import time
|
|
||||||
diff = max(0.0, time.time() - os.path.getmtime(latest_path))
|
|
||||||
m = int(diff // 60)
|
|
||||||
if m < 1:
|
|
||||||
latest_relative = "刚刚"
|
|
||||||
elif m < 60:
|
|
||||||
latest_relative = "一会前"
|
|
||||||
else:
|
|
||||||
latest_relative = "较早前"
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
data = {"total": total, "stats": stats, "latest_relative": latest_relative}
|
# 计算"最新一次活动多久前"(仅日志来源时有效)
|
||||||
|
latest_relative = None
|
||||||
|
if source == "log":
|
||||||
|
try:
|
||||||
|
info = stats.get("log_path", "")
|
||||||
|
idx = info.rfind("最新:")
|
||||||
|
if idx != -1:
|
||||||
|
latest_path = info[idx + 3:].strip()
|
||||||
|
if latest_path and os.path.exists(latest_path):
|
||||||
|
import time
|
||||||
|
diff = max(0.0, time.time() - os.path.getmtime(latest_path))
|
||||||
|
m = int(diff // 60)
|
||||||
|
if m < 1:
|
||||||
|
latest_relative = "刚刚"
|
||||||
|
elif m < 60:
|
||||||
|
latest_relative = "一会前"
|
||||||
|
else:
|
||||||
|
latest_relative = "较早前"
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source}
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -116,27 +116,15 @@ class ModelConfigService:
|
|||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# dashscope 的 omni 模型需要使用 compatible-mode
|
model_config = RedBearModelConfig(
|
||||||
if provider.lower() == ModelProvider.DASHSCOPE and is_omni:
|
model_name=model_name,
|
||||||
if not api_base:
|
provider=provider,
|
||||||
api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
api_key=api_key,
|
||||||
model_config = RedBearModelConfig(
|
base_url=api_base,
|
||||||
model_name=model_name,
|
is_omni=is_omni,
|
||||||
provider=ModelProvider.OPENAI,
|
temperature=0.7,
|
||||||
api_key=api_key,
|
max_tokens=100
|
||||||
base_url=api_base,
|
)
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=100
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model_config = RedBearModelConfig(
|
|
||||||
model_name=model_name,
|
|
||||||
provider=provider,
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=api_base,
|
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=100
|
|
||||||
)
|
|
||||||
|
|
||||||
# 根据模型类型选择不同的验证方式
|
# 根据模型类型选择不同的验证方式
|
||||||
model_type_lower = model_type.lower()
|
model_type_lower = model_type.lower()
|
||||||
@@ -492,6 +480,9 @@ class ModelApiKeyService:
|
|||||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
data.is_omni = model_config.is_omni
|
||||||
|
data.capability = model_config.capability
|
||||||
|
|
||||||
# 从ModelBase获取model_name
|
# 从ModelBase获取model_name
|
||||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||||
@@ -550,8 +541,8 @@ class ModelApiKeyService:
|
|||||||
provider=data.provider,
|
provider=data.provider,
|
||||||
api_key=data.api_key,
|
api_key=data.api_key,
|
||||||
api_base=data.api_base,
|
api_base=data.api_base,
|
||||||
capability=data.capability if data.capability is not None else model_config.capability,
|
capability=data.capability,
|
||||||
is_omni=data.is_omni if data.is_omni is not None else model_config.is_omni,
|
is_omni=data.is_omni,
|
||||||
config=data.config,
|
config=data.config,
|
||||||
is_active=data.is_active,
|
is_active=data.is_active,
|
||||||
priority=data.priority
|
priority=data.priority
|
||||||
@@ -574,6 +565,10 @@ class ModelApiKeyService:
|
|||||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
if api_key_data.is_omni is None:
|
||||||
|
api_key_data.is_omni = model_config.is_omni
|
||||||
|
if api_key_data.capability is None:
|
||||||
|
api_key_data.capability = model_config.capability
|
||||||
|
|
||||||
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
||||||
existing_key = db.query(ModelApiKey).join(
|
existing_key = db.query(ModelApiKey).join(
|
||||||
@@ -616,7 +611,7 @@ class ModelApiKeyService:
|
|||||||
api_base=api_key_data.api_base,
|
api_base=api_key_data.api_base,
|
||||||
model_type=model_config.type,
|
model_type=model_config.type,
|
||||||
test_message="Hello",
|
test_message="Hello",
|
||||||
is_omni=model_config.is_omni
|
is_omni=api_key_data.is_omni
|
||||||
)
|
)
|
||||||
if not validation_result["valid"]:
|
if not validation_result["valid"]:
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
@@ -785,6 +780,7 @@ class ModelBaseService:
|
|||||||
"description": model_base.description,
|
"description": model_base.description,
|
||||||
"capability": model_base.capability,
|
"capability": model_base.capability,
|
||||||
"is_omni": model_base.is_omni,
|
"is_omni": model_base.is_omni,
|
||||||
|
"is_active": False,
|
||||||
"is_composite": False
|
"is_composite": False
|
||||||
}
|
}
|
||||||
model_config = ModelConfigRepository.create(db, model_config_data)
|
model_config = ModelConfigRepository.create(db, model_config_data)
|
||||||
|
|||||||
@@ -326,6 +326,25 @@ async def run_pilot_extraction(
|
|||||||
|
|
||||||
logger.info("Pilot run completed: Skipping Neo4j save")
|
logger.info("Pilot run completed: Skipping Neo4j save")
|
||||||
|
|
||||||
|
# 将提取统计写入 Redis,按 workspace_id 存储
|
||||||
|
try:
|
||||||
|
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||||||
|
|
||||||
|
stats_to_cache = {
|
||||||
|
"chunk_count": len(chunk_nodes) if chunk_nodes else 0,
|
||||||
|
"statements_count": len(statement_nodes) if statement_nodes else 0,
|
||||||
|
"triplet_entities_count": len(entity_nodes) if entity_nodes else 0,
|
||||||
|
"triplet_relations_count": len(entity_edges) if entity_edges else 0,
|
||||||
|
"temporal_count": 0, # temporal 数据在日志中,此处暂置0
|
||||||
|
}
|
||||||
|
await ActivityStatsCache.set_activity_stats(
|
||||||
|
workspace_id=str(memory_config.workspace_id),
|
||||||
|
stats=stats_to_cache,
|
||||||
|
)
|
||||||
|
logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
||||||
|
except Exception as cache_err:
|
||||||
|
logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Pilot run failed: {e}", exc_info=True)
|
logger.error(f"Pilot run failed: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from datetime import datetime
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.tools.mcp import MCPToolManager, SimpleMCPClient
|
from app.core.tools.mcp import MCPToolManager, SimpleMCPClient
|
||||||
from app.repositories.tool_repository import (
|
from app.repositories.tool_repository import (
|
||||||
ToolRepository, BuiltinToolRepository, CustomToolRepository,
|
ToolRepository, BuiltinToolRepository, CustomToolRepository,
|
||||||
@@ -79,6 +81,18 @@ class ToolService:
|
|||||||
config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
||||||
return self._config_to_info(config) if config else None
|
return self._config_to_info(config) if config else None
|
||||||
|
|
||||||
|
def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None):
|
||||||
|
"""检查工具名称是否重复"""
|
||||||
|
query = self.db.query(ToolConfig).filter(
|
||||||
|
ToolConfig.name == name,
|
||||||
|
ToolConfig.tool_type == tool_type,
|
||||||
|
ToolConfig.tenant_id == tenant_id
|
||||||
|
)
|
||||||
|
if exclude_id:
|
||||||
|
query = query.filter(ToolConfig.id != exclude_id)
|
||||||
|
if query.first():
|
||||||
|
raise BusinessException(f"工具名称 '{name}' 已存在", BizCode.DUPLICATE_NAME)
|
||||||
|
|
||||||
def create_tool(
|
def create_tool(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
@@ -92,6 +106,7 @@ class ToolService:
|
|||||||
"""创建工具"""
|
"""创建工具"""
|
||||||
if tool_type == ToolType.BUILTIN:
|
if tool_type == ToolType.BUILTIN:
|
||||||
raise ValueError("内置工具不允许创建")
|
raise ValueError("内置工具不允许创建")
|
||||||
|
self._check_name_duplicate(name, tool_type, tenant_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建基础配置
|
# 创建基础配置
|
||||||
@@ -141,6 +156,7 @@ class ToolService:
|
|||||||
raise ValueError("内置工具不允许修改名称、描述和图标")
|
raise ValueError("内置工具不允许修改名称、描述和图标")
|
||||||
try:
|
try:
|
||||||
if name:
|
if name:
|
||||||
|
self._check_name_duplicate(name, config_obj.tool_type, tenant_id, exclude_id=config_obj.id)
|
||||||
config_obj.name = name
|
config_obj.name = name
|
||||||
if description:
|
if description:
|
||||||
config_obj.description = description
|
config_obj.description = description
|
||||||
@@ -894,7 +910,11 @@ class ToolService:
|
|||||||
config_data.update({
|
config_data.update({
|
||||||
"last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None,
|
"last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None,
|
||||||
"health_status": mcp_config.health_status,
|
"health_status": mcp_config.health_status,
|
||||||
"available_tools": available_tools_display
|
"available_tools": available_tools_display,
|
||||||
|
"source_channel": mcp_config.source_channel,
|
||||||
|
"market_id": mcp_config.market_id,
|
||||||
|
"market_config_id": mcp_config.market_config_id,
|
||||||
|
"mcp_service_id": mcp_config.mcp_service_id
|
||||||
})
|
})
|
||||||
|
|
||||||
return ToolInfo(
|
return ToolInfo(
|
||||||
@@ -949,7 +969,11 @@ class ToolService:
|
|||||||
id=tool_config.id,
|
id=tool_config.id,
|
||||||
server_url=config.get("server_url"),
|
server_url=config.get("server_url"),
|
||||||
connection_config=config.get("connection_config", {}),
|
connection_config=config.get("connection_config", {}),
|
||||||
available_tools=config.get("available_tools", [])
|
available_tools=config.get("available_tools", []),
|
||||||
|
source_channel=config.get("source_channel", "self_hosted"),
|
||||||
|
market_id=config.get("market_id"),
|
||||||
|
market_config_id=config.get("market_config_id"),
|
||||||
|
mcp_service_id=config.get("mcp_service_id"),
|
||||||
)
|
)
|
||||||
self.db.add(mcp_config)
|
self.db.add(mcp_config)
|
||||||
|
|
||||||
@@ -1002,6 +1026,14 @@ class ToolService:
|
|||||||
mcp_config.server_url = config.get("server_url")
|
mcp_config.server_url = config.get("server_url")
|
||||||
mcp_config.connection_config = config.get("connection_config", {})
|
mcp_config.connection_config = config.get("connection_config", {})
|
||||||
mcp_config.available_tools = config.get("available_tools", [])
|
mcp_config.available_tools = config.get("available_tools", [])
|
||||||
|
if config.get("source_channel") is not None:
|
||||||
|
mcp_config.source_channel = config.get("source_channel")
|
||||||
|
if config.get("market_id") is not None:
|
||||||
|
mcp_config.market_id = config.get("market_id")
|
||||||
|
if config.get("market_config_id") is not None:
|
||||||
|
mcp_config.market_config_id = config.get("market_config_id")
|
||||||
|
if config.get("mcp_service_id") is not None:
|
||||||
|
mcp_config.mcp_service_id = config.get("mcp_service_id")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _determine_initial_status(tool_info: Dict[str, Any]) -> str:
|
def _determine_initial_status(tool_info: Dict[str, Any]) -> str:
|
||||||
|
|||||||
@@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository
|
|||||||
from app.repositories.neo4j.cypher_queries import Graph_Node_query
|
from app.repositories.neo4j.cypher_queries import Graph_Node_query
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
||||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
from app.services.memory_base_service import MemoryBaseService
|
||||||
from app.services.memory_base_service import MemoryBaseService, MemoryTransService
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||||
from app.services.memory_short_service import ShortService
|
from app.services.memory_short_service import ShortService
|
||||||
@@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
|||||||
|
|
||||||
from app.core.language_utils import validate_language
|
from app.core.language_utils import validate_language
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
||||||
from app.db import get_db
|
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
|
||||||
# 验证语言参数
|
# 验证语言参数
|
||||||
@@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
|||||||
if end_user_id:
|
if end_user_id:
|
||||||
try:
|
try:
|
||||||
# 获取数据库会话并查询用户信息
|
# 获取数据库会话并查询用户信息
|
||||||
db = next(get_db())
|
with get_db_context() as db:
|
||||||
try:
|
|
||||||
repo = EndUserRepository(db)
|
repo = EndUserRepository(db)
|
||||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||||
if end_user and end_user.other_name:
|
if end_user and end_user.other_name:
|
||||||
@@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
|||||||
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
|
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
|
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")
|
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class WorkflowImportService:
|
|||||||
success=False,
|
success=False,
|
||||||
temp_id=None,
|
temp_id=None,
|
||||||
workflow_id=None,
|
workflow_id=None,
|
||||||
errors=[InvalidConfiguration()]
|
errors=[InvalidConfiguration()] + adapter.errors
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_config = adapter.parse_workflow()
|
workflow_config = adapter.parse_workflow()
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from app.repositories.workflow_repository import (
|
|||||||
WorkflowExecutionRepository,
|
WorkflowExecutionRepository,
|
||||||
WorkflowNodeExecutionRepository
|
WorkflowNodeExecutionRepository
|
||||||
)
|
)
|
||||||
from app.schemas import DraftRunRequest, FileInput
|
from app.schemas import DraftRunRequest, FileInput, FileType
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.multi_agent_service import convert_uuids_to_str
|
from app.services.multi_agent_service import convert_uuids_to_str
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
@@ -496,6 +496,7 @@ class WorkflowService:
|
|||||||
"event": "start",
|
"event": "start",
|
||||||
"data": {
|
"data": {
|
||||||
"conversation_id": payload.get("conversation_id"),
|
"conversation_id": payload.get("conversation_id"),
|
||||||
|
"message_id": payload.get("message_id")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "workflow_end":
|
case "workflow_end":
|
||||||
@@ -600,6 +601,7 @@ class WorkflowService:
|
|||||||
try:
|
try:
|
||||||
files = await self._handle_file_input(payload.files)
|
files = await self._handle_file_input(payload.files)
|
||||||
input_data["files"] = files
|
input_data["files"] = files
|
||||||
|
message_id = uuid.uuid4()
|
||||||
# 更新状态为运行中
|
# 更新状态为运行中
|
||||||
self.update_execution_status(execution.execution_id, "running")
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
|
|
||||||
@@ -624,24 +626,45 @@ class WorkflowService:
|
|||||||
workspace_id=str(workspace_id),
|
workspace_id=str(workspace_id),
|
||||||
user_id=payload.user_id
|
user_id=payload.user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 更新执行结果
|
# 更新执行结果
|
||||||
if result.get("status") == "completed":
|
if result.get("status") == "completed":
|
||||||
token_usage = result.get("token_usage", {}) or {}
|
token_usage = result.get("token_usage", {}) or {}
|
||||||
|
|
||||||
|
final_messages = result.get("messages", [])[init_message_length:]
|
||||||
|
human_message = ""
|
||||||
|
assistant_message = ""
|
||||||
|
for message in final_messages:
|
||||||
|
if message["role"] == "user":
|
||||||
|
if isinstance(message["content"], str):
|
||||||
|
human_message += message["content"]
|
||||||
|
elif isinstance(message["content"], list):
|
||||||
|
for file in message["content"]:
|
||||||
|
if file.get("type") == FileType.IMAGE:
|
||||||
|
human_message += f"})"
|
||||||
|
else:
|
||||||
|
human_message += f"[{file.get('type')}]({file.get('url', '')})"
|
||||||
|
if message["role"] == "assistant":
|
||||||
|
assistant_message = message["content"]
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id_uuid,
|
||||||
|
role="user",
|
||||||
|
content=human_message,
|
||||||
|
meta_data=None
|
||||||
|
)
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
message_id=message_id,
|
||||||
|
conversation_id=conversation_id_uuid,
|
||||||
|
role="assistant",
|
||||||
|
content=assistant_message,
|
||||||
|
meta_data={"usage": token_usage}
|
||||||
|
)
|
||||||
self.update_execution_status(
|
self.update_execution_status(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
"completed",
|
"completed",
|
||||||
output_data=result,
|
output_data=result,
|
||||||
token_usage=token_usage.get("total_tokens", None)
|
token_usage=token_usage.get("total_tokens", None)
|
||||||
)
|
)
|
||||||
final_messages = result.get("messages", [])[init_message_length:]
|
|
||||||
for message in final_messages:
|
|
||||||
self.conversation_service.add_message(
|
|
||||||
conversation_id=conversation_id_uuid,
|
|
||||||
role=message["role"],
|
|
||||||
content=message["content"],
|
|
||||||
meta_data=None if message["role"] == "user" else {"usage": token_usage}
|
|
||||||
)
|
|
||||||
logger.info(f"Workflow Run Success, "
|
logger.info(f"Workflow Run Success, "
|
||||||
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||||
else:
|
else:
|
||||||
@@ -650,6 +673,8 @@ class WorkflowService:
|
|||||||
"failed",
|
"failed",
|
||||||
error_message=result.get("error")
|
error_message=result.get("error")
|
||||||
)
|
)
|
||||||
|
logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id},"
|
||||||
|
f" error: {result.get('error')}")
|
||||||
|
|
||||||
# 返回增强的响应结构
|
# 返回增强的响应结构
|
||||||
return {
|
return {
|
||||||
@@ -659,6 +684,7 @@ class WorkflowService:
|
|||||||
# "messages": result.get("messages"),
|
# "messages": result.get("messages"),
|
||||||
"output": result.get("output"), # 最终输出(字符串)
|
"output": result.get("output"), # 最终输出(字符串)
|
||||||
"message": result.get("output"), # 最终输出(字符串)
|
"message": result.get("output"), # 最终输出(字符串)
|
||||||
|
"message_id": str(message_id),
|
||||||
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||||
"error_message": result.get("error"),
|
"error_message": result.get("error"),
|
||||||
@@ -756,7 +782,7 @@ class WorkflowService:
|
|||||||
input_data["conv_messages"] = last_state.get("messages") or []
|
input_data["conv_messages"] = last_state.get("messages") or []
|
||||||
break
|
break
|
||||||
init_message_length = len(input_data.get("conv_messages", []))
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
|
message_id = uuid.uuid4()
|
||||||
async for event in execute_workflow_stream(
|
async for event in execute_workflow_stream(
|
||||||
workflow_config=workflow_config_dict,
|
workflow_config=workflow_config_dict,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
@@ -765,24 +791,43 @@ class WorkflowService:
|
|||||||
user_id=payload.user_id,
|
user_id=payload.user_id,
|
||||||
):
|
):
|
||||||
if event.get("event") == "workflow_end":
|
if event.get("event") == "workflow_end":
|
||||||
|
|
||||||
status = event.get("data", {}).get("status")
|
status = event.get("data", {}).get("status")
|
||||||
token_usage = event.get("data", {}).get("token_usage", {}) or {}
|
token_usage = event.get("data", {}).get("token_usage", {}) or {}
|
||||||
if status == "completed":
|
if status == "completed":
|
||||||
|
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
|
||||||
|
human_message = ""
|
||||||
|
assistant_message = ""
|
||||||
|
for message in final_messages:
|
||||||
|
if message["role"] == "user":
|
||||||
|
if isinstance(message["content"], str):
|
||||||
|
human_message += message["content"]
|
||||||
|
elif isinstance(message["content"], list):
|
||||||
|
for file in message["content"]:
|
||||||
|
if file.get("type") == FileType.IMAGE:
|
||||||
|
human_message += f"})"
|
||||||
|
else:
|
||||||
|
human_message += f"[{file.get('type')}]({file.get('url', '')})"
|
||||||
|
if message["role"] == "assistant":
|
||||||
|
assistant_message = message["content"]
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id_uuid,
|
||||||
|
role="user",
|
||||||
|
content=human_message,
|
||||||
|
meta_data=None
|
||||||
|
)
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
message_id=message_id,
|
||||||
|
conversation_id=conversation_id_uuid,
|
||||||
|
role="assistant",
|
||||||
|
content=assistant_message,
|
||||||
|
meta_data={"usage": token_usage}
|
||||||
|
)
|
||||||
self.update_execution_status(
|
self.update_execution_status(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
"completed",
|
"completed",
|
||||||
output_data=event.get("data"),
|
output_data=event.get("data"),
|
||||||
token_usage=token_usage.get("total_tokens", None)
|
token_usage=token_usage.get("total_tokens", None)
|
||||||
)
|
)
|
||||||
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
|
|
||||||
for message in final_messages:
|
|
||||||
self.conversation_service.add_message(
|
|
||||||
conversation_id=conversation_id_uuid,
|
|
||||||
role=message["role"],
|
|
||||||
content=message["content"],
|
|
||||||
meta_data=None if message["role"] == "user" else {"usage": token_usage}
|
|
||||||
)
|
|
||||||
logger.info(f"Workflow Run Success, "
|
logger.info(f"Workflow Run Success, "
|
||||||
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
|
||||||
elif status == "failed":
|
elif status == "failed":
|
||||||
@@ -793,6 +838,8 @@ class WorkflowService:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f"unexpect workflow run status, status: {status}")
|
logger.error(f"unexpect workflow run status, status: {status}")
|
||||||
|
elif event.get("event") == "workflow_start":
|
||||||
|
event["data"]["message_id"] = str(message_id)
|
||||||
event = self._emit(public, event)
|
event = self._emit(public, event)
|
||||||
if event:
|
if event:
|
||||||
yield event
|
yield event
|
||||||
|
|||||||
@@ -2,11 +2,11 @@ import datetime
|
|||||||
import hashlib
|
import hashlib
|
||||||
import secrets
|
import secrets
|
||||||
import uuid
|
import uuid
|
||||||
from os import getenv
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.config.default_ontology_initializer import DefaultOntologyInitializer
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException, PermissionDeniedException
|
from app.core.exceptions import BusinessException, PermissionDeniedException
|
||||||
@@ -30,17 +30,15 @@ from app.schemas.workspace_schema import (
|
|||||||
WorkspaceModelsUpdate,
|
WorkspaceModelsUpdate,
|
||||||
WorkspaceUpdate,
|
WorkspaceUpdate,
|
||||||
)
|
)
|
||||||
from app.config.default_ontology_initializer import DefaultOntologyInitializer
|
|
||||||
|
|
||||||
# 获取业务逻辑专用日志器
|
# 获取业务逻辑专用日志器
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
def switch_workspace(
|
def switch_workspace(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
user: User,
|
user: User,
|
||||||
):
|
):
|
||||||
"""切换工作空间"""
|
"""切换工作空间"""
|
||||||
business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}")
|
business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}")
|
||||||
@@ -60,31 +58,32 @@ def switch_workspace(
|
|||||||
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||||
|
|
||||||
|
|
||||||
def delete_workspace_member(
|
def delete_workspace_member(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
member_id: uuid.UUID,
|
member_id: uuid.UUID,
|
||||||
user: User,
|
user: User,
|
||||||
):
|
):
|
||||||
"""删除工作空间成员"""
|
"""删除工作空间成员"""
|
||||||
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
_check_workspace_admin_permission(db, workspace_id, user)
|
_check_workspace_admin_permission(db, workspace_id, user)
|
||||||
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
|
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
|
||||||
if not workspace_member:
|
if not workspace_member:
|
||||||
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND)
|
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND)
|
||||||
|
|
||||||
if workspace_member.workspace_id != workspace_id:
|
if workspace_member.workspace_id != workspace_id:
|
||||||
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_NOT_FOUND)
|
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}",
|
||||||
|
BizCode.WORKSPACE_NOT_FOUND)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
workspace_member.is_active = False
|
workspace_member.is_active = False
|
||||||
workspace_member.user.current_workspace_id = None
|
workspace_member.user.current_workspace_id = None
|
||||||
db.commit()
|
db.commit()
|
||||||
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
||||||
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||||
|
|
||||||
|
|
||||||
def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
|
def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
|
||||||
@@ -102,18 +101,19 @@ def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
|
|||||||
"""
|
"""
|
||||||
business_logger.debug(f"获取用户工作空间列表: {user.username} (ID: {user.id})")
|
business_logger.debug(f"获取用户工作空间列表: {user.username} (ID: {user.id})")
|
||||||
workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
|
workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
|
||||||
|
|
||||||
# Ensure each neo4j workspace has a default memory config
|
# Ensure each neo4j workspace has a default memory config
|
||||||
for workspace in workspaces:
|
for workspace in workspaces:
|
||||||
if workspace.storage_type == 'neo4j':
|
if workspace.storage_type == 'neo4j':
|
||||||
_ensure_default_memory_config(db, workspace)
|
_ensure_default_memory_config(db, workspace)
|
||||||
|
_ensure_default_ontology_scenes(db, workspace)
|
||||||
|
|
||||||
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
|
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
|
||||||
return workspaces
|
return workspaces
|
||||||
|
|
||||||
|
|
||||||
def _create_workspace_only(
|
def _create_workspace_only(
|
||||||
db: Session, workspace: WorkspaceCreate, owner: User
|
db: Session, workspace: WorkspaceCreate, owner: User
|
||||||
) -> Workspace:
|
) -> Workspace:
|
||||||
business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}")
|
business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}")
|
||||||
|
|
||||||
@@ -129,6 +129,7 @@ def _create_workspace_only(
|
|||||||
business_logger.error(f"创建工作空间失败: {workspace.name} - {str(e)}")
|
business_logger.error(f"创建工作空间失败: {workspace.name} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def create_workspace(
|
def create_workspace(
|
||||||
db: Session, workspace: WorkspaceCreate, user: User, language: str = "zh"
|
db: Session, workspace: WorkspaceCreate, user: User, language: str = "zh"
|
||||||
) -> Workspace:
|
) -> Workspace:
|
||||||
@@ -136,9 +137,14 @@ def create_workspace(
|
|||||||
f"创建工作空间: {workspace.name}, 创建者: {user.username}, "
|
f"创建工作空间: {workspace.name}, 创建者: {user.username}, "
|
||||||
f"storage_type: {workspace.storage_type}"
|
f"storage_type: {workspace.storage_type}"
|
||||||
)
|
)
|
||||||
llm=workspace.llm
|
if workspace_repository.get_workspaces_by_name(db=db, name=workspace.name, tenant_id=user.tenant_id):
|
||||||
embedding=workspace.embedding
|
raise BusinessException(
|
||||||
rerank=workspace.rerank
|
message="同名工作空间已存在",
|
||||||
|
code=BizCode.RESOURCE_ALREADY_EXISTS
|
||||||
|
)
|
||||||
|
llm = workspace.llm
|
||||||
|
embedding = workspace.embedding
|
||||||
|
rerank = workspace.rerank
|
||||||
try:
|
try:
|
||||||
# Create the workspace without adding any members
|
# Create the workspace without adding any members
|
||||||
business_logger.debug(f"创建工作空间: {workspace.name}")
|
business_logger.debug(f"创建工作空间: {workspace.name}")
|
||||||
@@ -151,33 +157,35 @@ def create_workspace(
|
|||||||
|
|
||||||
# Initialize default ontology scenes for the workspace (先创建本体场景)
|
# Initialize default ontology scenes for the workspace (先创建本体场景)
|
||||||
default_scene_id = None
|
default_scene_id = None
|
||||||
|
default_scene_name = None
|
||||||
try:
|
try:
|
||||||
initializer = DefaultOntologyInitializer(db)
|
initializer = DefaultOntologyInitializer(db)
|
||||||
success, error_msg = initializer.initialize_default_scenes(
|
success, error_msg = initializer.initialize_default_scenes(
|
||||||
db_workspace.id, language=language
|
db_workspace.id, language=language
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
business_logger.info(
|
business_logger.info(
|
||||||
f"为工作空间 {db_workspace.id} 创建默认本体场景成功 (language={language})"
|
f"为工作空间 {db_workspace.id} 创建默认本体场景成功 (language={language})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取默认场景ID,优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景
|
# 获取默认场景ID,优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景
|
||||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||||
from app.config.default_ontology_config import (
|
from app.config.default_ontology_config import (
|
||||||
ONLINE_EDUCATION_SCENE,
|
ONLINE_EDUCATION_SCENE,
|
||||||
EMOTIONAL_COMPANION_SCENE,
|
EMOTIONAL_COMPANION_SCENE,
|
||||||
get_scene_name
|
get_scene_name
|
||||||
)
|
)
|
||||||
|
|
||||||
scene_repo = OntologySceneRepository(db)
|
scene_repo = OntologySceneRepository(db)
|
||||||
|
|
||||||
# 优先尝试获取教育场景
|
# 优先尝试获取教育场景
|
||||||
education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language)
|
education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language)
|
||||||
education_scene = scene_repo.get_by_name(education_scene_name, db_workspace.id)
|
education_scene = scene_repo.get_by_name(education_scene_name, db_workspace.id)
|
||||||
|
|
||||||
if education_scene:
|
if education_scene:
|
||||||
default_scene_id = education_scene.scene_id
|
default_scene_id = education_scene.scene_id
|
||||||
|
default_scene_name = education_scene.scene_name
|
||||||
business_logger.info(
|
business_logger.info(
|
||||||
f"获取到教育场景ID用于默认记忆配置: {default_scene_id} (scene_name={education_scene_name})"
|
f"获取到教育场景ID用于默认记忆配置: {default_scene_id} (scene_name={education_scene_name})"
|
||||||
)
|
)
|
||||||
@@ -185,9 +193,10 @@ def create_workspace(
|
|||||||
# 如果教育场景不存在,尝试获取情感陪伴场景
|
# 如果教育场景不存在,尝试获取情感陪伴场景
|
||||||
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
|
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
|
||||||
companion_scene = scene_repo.get_by_name(companion_scene_name, db_workspace.id)
|
companion_scene = scene_repo.get_by_name(companion_scene_name, db_workspace.id)
|
||||||
|
|
||||||
if companion_scene:
|
if companion_scene:
|
||||||
default_scene_id = companion_scene.scene_id
|
default_scene_id = companion_scene.scene_id
|
||||||
|
default_scene_name = companion_scene.scene_name
|
||||||
business_logger.info(
|
business_logger.info(
|
||||||
f"教育场景不存在,使用情感陪伴场景ID用于默认记忆配置: {default_scene_id} (scene_name={companion_scene_name})"
|
f"教育场景不存在,使用情感陪伴场景ID用于默认记忆配置: {default_scene_id} (scene_name={companion_scene_name})"
|
||||||
)
|
)
|
||||||
@@ -218,6 +227,7 @@ def create_workspace(
|
|||||||
embedding_id=embedding,
|
embedding_id=embedding,
|
||||||
rerank_id=rerank,
|
rerank_id=rerank,
|
||||||
scene_id=default_scene_id, # 传入默认场景ID(优先教育场景,其次情感陪伴场景)
|
scene_id=default_scene_id, # 传入默认场景ID(优先教育场景,其次情感陪伴场景)
|
||||||
|
pruning_scene_name=default_scene_name, # 传入场景名称作为语义剪枝场景值
|
||||||
)
|
)
|
||||||
business_logger.info(
|
business_logger.info(
|
||||||
f"为工作空间 {db_workspace.id} 创建默认记忆配置成功 (scene_id={default_scene_id})"
|
f"为工作空间 {db_workspace.id} 创建默认记忆配置成功 (scene_id={default_scene_id})"
|
||||||
@@ -250,10 +260,10 @@ def create_workspace(
|
|||||||
avatar='',
|
avatar='',
|
||||||
type=KnowledgeType.General,
|
type=KnowledgeType.General,
|
||||||
permission_id=PermissionType.Memory,
|
permission_id=PermissionType.Memory,
|
||||||
embedding_id=uuid.UUID(getenv('KB_embedding_id')) if None else embedding,
|
embedding_id=embedding,
|
||||||
reranker_id=uuid.UUID(getenv('KB_reranker_id')) if None else rerank,
|
reranker_id=rerank,
|
||||||
llm_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
|
llm_id=llm,
|
||||||
image2text_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
|
image2text_id=llm,
|
||||||
parser_config={
|
parser_config={
|
||||||
"layout_recognize": "DeepDOC",
|
"layout_recognize": "DeepDOC",
|
||||||
"chunk_token_num": 256,
|
"chunk_token_num": 256,
|
||||||
@@ -288,7 +298,7 @@ def create_workspace(
|
|||||||
business_logger.info(
|
business_logger.info(
|
||||||
f"工作空间 {db_workspace.id} 及相关资源创建完成并已提交"
|
f"工作空间 {db_workspace.id} 及相关资源创建完成并已提交"
|
||||||
)
|
)
|
||||||
|
|
||||||
return db_workspace
|
return db_workspace
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -298,11 +308,11 @@ def create_workspace(
|
|||||||
|
|
||||||
|
|
||||||
def update_workspace(
|
def update_workspace(
|
||||||
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
|
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
|
||||||
) -> Workspace:
|
) -> Workspace:
|
||||||
business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}")
|
business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||||
|
|
||||||
db_workspace = _check_workspace_admin_permission(db,workspace_id,user)
|
db_workspace = _check_workspace_admin_permission(db, workspace_id, user)
|
||||||
try:
|
try:
|
||||||
# 更新工作空间
|
# 更新工作空间
|
||||||
business_logger.debug(f"执行工作空间更新: {db_workspace.name} (ID: {workspace_id})")
|
business_logger.debug(f"执行工作空间更新: {db_workspace.name} (ID: {workspace_id})")
|
||||||
@@ -322,7 +332,7 @@ def update_workspace(
|
|||||||
|
|
||||||
|
|
||||||
def get_workspace_members(
|
def get_workspace_members(
|
||||||
db: Session, workspace_id: uuid.UUID, user: User
|
db: Session, workspace_id: uuid.UUID, user: User
|
||||||
) -> List[WorkspaceMember]:
|
) -> List[WorkspaceMember]:
|
||||||
"""获取某工作空间的成员列表(关系序列化由模型关系支持)"""
|
"""获取某工作空间的成员列表(关系序列化由模型关系支持)"""
|
||||||
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
|
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||||
@@ -366,7 +376,6 @@ def get_workspace_members(
|
|||||||
return members
|
return members
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 邀请相关服务方法 ====================
|
# ==================== 邀请相关服务方法 ====================
|
||||||
|
|
||||||
def _generate_invite_token() -> tuple[str, str]:
|
def _generate_invite_token() -> tuple[str, str]:
|
||||||
@@ -459,13 +468,14 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
|
|||||||
|
|
||||||
|
|
||||||
def create_workspace_invite(
|
def create_workspace_invite(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
invite_data: WorkspaceInviteCreate,
|
invite_data: WorkspaceInviteCreate,
|
||||||
user: User
|
user: User
|
||||||
) -> WorkspaceInviteResponse:
|
) -> WorkspaceInviteResponse:
|
||||||
"""创建工作空间邀请"""
|
"""创建工作空间邀请"""
|
||||||
business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
|
business_logger.info(
|
||||||
|
f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 检查权限
|
# 检查权限
|
||||||
@@ -528,17 +538,18 @@ def create_workspace_invite(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
|
business_logger.error(
|
||||||
|
f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_workspace_invites(
|
def get_workspace_invites(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
user: User,
|
user: User,
|
||||||
status: Optional[InviteStatus] = None,
|
status: Optional[InviteStatus] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
offset: int = 0
|
offset: int = 0
|
||||||
) -> List[WorkspaceInviteResponse]:
|
) -> List[WorkspaceInviteResponse]:
|
||||||
"""获取工作空间邀请列表"""
|
"""获取工作空间邀请列表"""
|
||||||
business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}")
|
business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||||
@@ -599,9 +610,9 @@ def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
|
|||||||
|
|
||||||
|
|
||||||
def accept_workspace_invite(
|
def accept_workspace_invite(
|
||||||
db: Session,
|
db: Session,
|
||||||
accept_request: InviteAcceptRequest,
|
accept_request: InviteAcceptRequest,
|
||||||
user: User
|
user: User
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""接受工作空间邀请"""
|
"""接受工作空间邀请"""
|
||||||
business_logger.info(f"接受工作空间邀请: 用户 {user.username}")
|
business_logger.info(f"接受工作空间邀请: 用户 {user.username}")
|
||||||
@@ -689,7 +700,8 @@ def accept_workspace_invite(
|
|||||||
# 获取工作空间信息
|
# 获取工作空间信息
|
||||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
||||||
|
|
||||||
business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
|
business_logger.info(
|
||||||
|
f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"message": "Successfully joined the workspace",
|
"message": "Successfully joined the workspace",
|
||||||
@@ -704,13 +716,14 @@ def accept_workspace_invite(
|
|||||||
|
|
||||||
|
|
||||||
def revoke_workspace_invite(
|
def revoke_workspace_invite(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
invite_id: uuid.UUID,
|
invite_id: uuid.UUID,
|
||||||
user: User
|
user: User
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""撤销工作空间邀请"""
|
"""撤销工作空间邀请"""
|
||||||
business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
|
business_logger.info(
|
||||||
|
f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 检查权限
|
# 检查权限
|
||||||
@@ -739,13 +752,14 @@ def revoke_workspace_invite(
|
|||||||
|
|
||||||
|
|
||||||
def update_workspace_member_roles(
|
def update_workspace_member_roles(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
updates: List[WorkspaceMemberUpdate],
|
updates: List[WorkspaceMemberUpdate],
|
||||||
user: User,
|
user: User,
|
||||||
) -> List[WorkspaceMember]:
|
) -> List[WorkspaceMember]:
|
||||||
"""更新工作空间成员角色"""
|
"""更新工作空间成员角色"""
|
||||||
business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
|
business_logger.info(
|
||||||
|
f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
|
||||||
|
|
||||||
# 检查管理员权限
|
# 检查管理员权限
|
||||||
_check_workspace_admin_permission(db, workspace_id, user)
|
_check_workspace_admin_permission(db, workspace_id, user)
|
||||||
@@ -759,7 +773,8 @@ def update_workspace_member_roles(
|
|||||||
for upd in updates:
|
for upd in updates:
|
||||||
# 检查成员是否存在
|
# 检查成员是否存在
|
||||||
if upd.id not in member_map:
|
if upd.id not in member_map:
|
||||||
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}",
|
||||||
|
BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||||
|
|
||||||
member = member_map[upd.id]
|
member = member_map[upd.id]
|
||||||
|
|
||||||
@@ -911,10 +926,10 @@ def get_workspace_models_configs(
|
|||||||
|
|
||||||
|
|
||||||
def update_workspace_models_configs(
|
def update_workspace_models_configs(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
models_update: WorkspaceModelsUpdate,
|
models_update: WorkspaceModelsUpdate,
|
||||||
user: User,
|
user: User,
|
||||||
) -> Workspace:
|
) -> Workspace:
|
||||||
"""更新工作空间的模型配置(llm, embedding, rerank)
|
"""更新工作空间的模型配置(llm, embedding, rerank)
|
||||||
|
|
||||||
@@ -961,88 +976,9 @@ def update_workspace_models_configs(
|
|||||||
raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||||
|
|
||||||
|
|
||||||
def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
|
||||||
"""Ensure a workspace has a default memory config, creating one if missing.
|
|
||||||
|
|
||||||
Also fills empty model fields for all configs in this workspace.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
workspace: The workspace to check
|
|
||||||
"""
|
|
||||||
from app.models.memory_config_model import MemoryConfig
|
|
||||||
|
|
||||||
# Check if default config exists for this workspace
|
|
||||||
existing_default = db.query(MemoryConfig).filter(
|
|
||||||
MemoryConfig.workspace_id == workspace.id,
|
|
||||||
MemoryConfig.is_default == True
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not existing_default:
|
|
||||||
# No default config exists, create one
|
|
||||||
business_logger.info(
|
|
||||||
f"Workspace {workspace.id} missing default memory config, creating one"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 尝试获取默认场景ID,优先教育场景,其次情感陪伴场景
|
|
||||||
default_scene_id = None
|
|
||||||
try:
|
|
||||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
|
||||||
from app.config.default_ontology_config import (
|
|
||||||
ONLINE_EDUCATION_SCENE,
|
|
||||||
EMOTIONAL_COMPANION_SCENE,
|
|
||||||
get_scene_name
|
|
||||||
)
|
|
||||||
|
|
||||||
scene_repo = OntologySceneRepository(db)
|
|
||||||
# 尝试中文和英文场景名称
|
|
||||||
for language in ["zh", "en"]:
|
|
||||||
# 优先尝试教育场景
|
|
||||||
education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language)
|
|
||||||
education_scene = scene_repo.get_by_name(education_scene_name, workspace.id)
|
|
||||||
if education_scene:
|
|
||||||
default_scene_id = education_scene.scene_id
|
|
||||||
business_logger.info(
|
|
||||||
f"找到教育场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={education_scene_name}"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
# 如果教育场景不存在,尝试情感陪伴场景
|
|
||||||
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
|
|
||||||
companion_scene = scene_repo.get_by_name(companion_scene_name, workspace.id)
|
|
||||||
if companion_scene:
|
|
||||||
default_scene_id = companion_scene.scene_id
|
|
||||||
business_logger.info(
|
|
||||||
f"教育场景不存在,找到情感陪伴场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={companion_scene_name}"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except Exception as scene_error:
|
|
||||||
business_logger.warning(
|
|
||||||
f"获取默认场景失败,将创建不关联场景的记忆配置: {str(scene_error)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
_create_default_memory_config(
|
|
||||||
db=db,
|
|
||||||
workspace_id=workspace.id,
|
|
||||||
workspace_name=workspace.name,
|
|
||||||
llm_id=uuid.UUID(workspace.llm) if workspace.llm else None,
|
|
||||||
embedding_id=uuid.UUID(workspace.embedding) if workspace.embedding else None,
|
|
||||||
rerank_id=uuid.UUID(workspace.rerank) if workspace.rerank else None,
|
|
||||||
scene_id=default_scene_id, # 传入默认场景ID(优先教育场景,其次情感陪伴场景)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
business_logger.error(
|
|
||||||
f"Failed to create default memory config for workspace {workspace.id}: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fill empty model fields for ALL configs in this workspace
|
|
||||||
_fill_workspace_configs_model_defaults(db, workspace)
|
|
||||||
|
|
||||||
|
|
||||||
def _fill_workspace_configs_model_defaults(
|
def _fill_workspace_configs_model_defaults(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace: Workspace
|
workspace: Workspace
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Fill empty model fields for all memory configs in a workspace.
|
"""Fill empty model fields for all memory configs in a workspace.
|
||||||
|
|
||||||
@@ -1054,43 +990,43 @@ def _fill_workspace_configs_model_defaults(
|
|||||||
workspace: The workspace containing default model settings
|
workspace: The workspace containing default model settings
|
||||||
"""
|
"""
|
||||||
from app.models.memory_config_model import MemoryConfig
|
from app.models.memory_config_model import MemoryConfig
|
||||||
|
|
||||||
# Get all configs for this workspace
|
# Get all configs for this workspace
|
||||||
configs = db.query(MemoryConfig).filter(
|
configs = db.query(MemoryConfig).filter(
|
||||||
MemoryConfig.workspace_id == workspace.id
|
MemoryConfig.workspace_id == workspace.id
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
if not configs:
|
if not configs:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Map of memory_config field -> workspace field
|
# Map of memory_config field -> workspace field
|
||||||
model_field_mappings = [
|
model_field_mappings = [
|
||||||
("llm_id", "llm"),
|
("llm_id", "llm"),
|
||||||
("embedding_id", "embedding"),
|
("embedding_id", "embedding"),
|
||||||
("rerank_id", "rerank"),
|
("rerank_id", "rerank"),
|
||||||
("reflection_model_id", "llm"), # reflection uses LLM
|
("reflection_model_id", "llm"), # reflection uses LLM
|
||||||
("emotion_model_id", "llm"), # emotion uses LLM
|
("emotion_model_id", "llm"), # emotion uses LLM
|
||||||
]
|
]
|
||||||
|
|
||||||
configs_updated = 0
|
configs_updated = 0
|
||||||
|
|
||||||
for memory_config in configs:
|
for memory_config in configs:
|
||||||
updated_fields = []
|
updated_fields = []
|
||||||
|
|
||||||
for config_field, workspace_field in model_field_mappings:
|
for config_field, workspace_field in model_field_mappings:
|
||||||
config_value = getattr(memory_config, config_field, None)
|
config_value = getattr(memory_config, config_field, None)
|
||||||
workspace_value = getattr(workspace, workspace_field, None)
|
workspace_value = getattr(workspace, workspace_field, None)
|
||||||
|
|
||||||
if not config_value and workspace_value:
|
if not config_value and workspace_value:
|
||||||
setattr(memory_config, config_field, workspace_value)
|
setattr(memory_config, config_field, workspace_value)
|
||||||
updated_fields.append(config_field)
|
updated_fields.append(config_field)
|
||||||
|
|
||||||
if updated_fields:
|
if updated_fields:
|
||||||
configs_updated += 1
|
configs_updated += 1
|
||||||
business_logger.debug(
|
business_logger.debug(
|
||||||
f"Updated memory config {memory_config.config_id} fields: {updated_fields}"
|
f"Updated memory config {memory_config.config_id} fields: {updated_fields}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if configs_updated > 0:
|
if configs_updated > 0:
|
||||||
try:
|
try:
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -1105,13 +1041,14 @@ def _fill_workspace_configs_model_defaults(
|
|||||||
|
|
||||||
|
|
||||||
def _create_default_memory_config(
|
def _create_default_memory_config(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
workspace_name: str,
|
workspace_name: str,
|
||||||
llm_id: Optional[uuid.UUID] = None,
|
llm_id: Optional[uuid.UUID] = None,
|
||||||
embedding_id: Optional[uuid.UUID] = None,
|
embedding_id: Optional[uuid.UUID] = None,
|
||||||
rerank_id: Optional[uuid.UUID] = None,
|
rerank_id: Optional[uuid.UUID] = None,
|
||||||
scene_id: Optional[uuid.UUID] = None,
|
scene_id: Optional[uuid.UUID] = None,
|
||||||
|
pruning_scene_name: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a default memory config for a newly created workspace.
|
"""Create a default memory config for a newly created workspace.
|
||||||
|
|
||||||
@@ -1123,11 +1060,12 @@ def _create_default_memory_config(
|
|||||||
embedding_id: Optional embedding model ID
|
embedding_id: Optional embedding model ID
|
||||||
rerank_id: Optional rerank model ID
|
rerank_id: Optional rerank model ID
|
||||||
scene_id: Optional ontology scene ID (默认关联教育场景)
|
scene_id: Optional ontology scene ID (默认关联教育场景)
|
||||||
|
pruning_scene_name: Optional pruning scene name,取自 ontology_scene.scene_name
|
||||||
"""
|
"""
|
||||||
from app.models.memory_config_model import MemoryConfig
|
from app.models.memory_config_model import MemoryConfig
|
||||||
|
|
||||||
config_id = uuid.uuid4()
|
config_id = uuid.uuid4()
|
||||||
|
|
||||||
default_config = MemoryConfig(
|
default_config = MemoryConfig(
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
config_name=f"{workspace_name} 默认配置",
|
config_name=f"{workspace_name} 默认配置",
|
||||||
@@ -1136,14 +1074,15 @@ def _create_default_memory_config(
|
|||||||
llm_id=str(llm_id) if llm_id else None,
|
llm_id=str(llm_id) if llm_id else None,
|
||||||
embedding_id=str(embedding_id) if embedding_id else None,
|
embedding_id=str(embedding_id) if embedding_id else None,
|
||||||
rerank_id=str(rerank_id) if rerank_id else None,
|
rerank_id=str(rerank_id) if rerank_id else None,
|
||||||
scene_id=scene_id, # 关联本体场景ID
|
scene_id=scene_id, # 关联本体场景ID(默认为"在线教育"场景)
|
||||||
|
pruning_scene=pruning_scene_name, # 语义剪枝场景直接使用 scene_name
|
||||||
state=True, # Active by default
|
state=True, # Active by default
|
||||||
is_default=True, # Mark as workspace default
|
is_default=True, # Mark as workspace default
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(default_config)
|
db.add(default_config)
|
||||||
db.flush() # 使用 flush 而不是 commit,让调用者统一提交
|
db.flush() # 使用 flush 而不是 commit,让调用者统一提交
|
||||||
|
|
||||||
business_logger.info(
|
business_logger.info(
|
||||||
"Created default memory config for workspace",
|
"Created default memory config for workspace",
|
||||||
extra={
|
extra={
|
||||||
@@ -1153,3 +1092,130 @@ def _create_default_memory_config(
|
|||||||
"scene_id": str(scene_id) if scene_id else None,
|
"scene_id": str(scene_id) if scene_id else None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 检查配置相关服务 ====================
|
||||||
|
|
||||||
|
def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
|
||||||
|
"""Ensure a workspace has a default memory config, creating one if missing.
|
||||||
|
|
||||||
|
Also fills empty model fields for all configs in this workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
workspace: The workspace to check
|
||||||
|
"""
|
||||||
|
from app.models.memory_config_model import MemoryConfig
|
||||||
|
|
||||||
|
# Check if default config exists for this workspace
|
||||||
|
existing_default = db.query(MemoryConfig).filter(
|
||||||
|
MemoryConfig.workspace_id == workspace.id,
|
||||||
|
MemoryConfig.is_default == True
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not existing_default:
|
||||||
|
# No default config exists, create one
|
||||||
|
business_logger.info(
|
||||||
|
f"Workspace {workspace.id} missing default memory config, creating one"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 尝试获取默认场景ID,优先教育场景,其次情感陪伴场景
|
||||||
|
default_scene_id = None
|
||||||
|
try:
|
||||||
|
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||||
|
from app.config.default_ontology_config import (
|
||||||
|
ONLINE_EDUCATION_SCENE,
|
||||||
|
EMOTIONAL_COMPANION_SCENE,
|
||||||
|
get_scene_name
|
||||||
|
)
|
||||||
|
|
||||||
|
scene_repo = OntologySceneRepository(db)
|
||||||
|
# 尝试中文和英文场景名称
|
||||||
|
for language in ["zh", "en"]:
|
||||||
|
# 优先尝试教育场景
|
||||||
|
education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language)
|
||||||
|
education_scene = scene_repo.get_by_name(education_scene_name, workspace.id)
|
||||||
|
if education_scene:
|
||||||
|
default_scene_id = education_scene.scene_id
|
||||||
|
business_logger.info(
|
||||||
|
f"找到教育场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={education_scene_name}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果教育场景不存在,尝试情感陪伴场景
|
||||||
|
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
|
||||||
|
companion_scene = scene_repo.get_by_name(companion_scene_name, workspace.id)
|
||||||
|
if companion_scene:
|
||||||
|
default_scene_id = companion_scene.scene_id
|
||||||
|
business_logger.info(
|
||||||
|
f"教育场景不存在,找到情感陪伴场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={companion_scene_name}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except Exception as scene_error:
|
||||||
|
business_logger.warning(
|
||||||
|
f"获取默认场景失败,将创建不关联场景的记忆配置: {str(scene_error)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_create_default_memory_config(
|
||||||
|
db=db,
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
workspace_name=workspace.name,
|
||||||
|
llm_id=uuid.UUID(workspace.llm) if workspace.llm else None,
|
||||||
|
embedding_id=uuid.UUID(workspace.embedding) if workspace.embedding else None,
|
||||||
|
rerank_id=uuid.UUID(workspace.rerank) if workspace.rerank else None,
|
||||||
|
scene_id=default_scene_id, # 传入默认场景ID(优先教育场景,其次情感陪伴场景)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.error(
|
||||||
|
f"Failed to create default memory config for workspace {workspace.id}: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fill empty model fields for ALL configs in this workspace
|
||||||
|
_fill_workspace_configs_model_defaults(db, workspace)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
|
||||||
|
"""Ensure a workspace has default ontology scenes, creating them if missing.
|
||||||
|
|
||||||
|
Checks whether any is_system_default scene exists for the workspace.
|
||||||
|
If not, runs the DefaultOntologyInitializer to create them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
workspace: The workspace to check
|
||||||
|
"""
|
||||||
|
from app.models.ontology_scene import OntologyScene
|
||||||
|
|
||||||
|
# 幂等检查:是否已存在系统默认场景
|
||||||
|
existing = db.query(OntologyScene).filter(
|
||||||
|
OntologyScene.workspace_id == workspace.id,
|
||||||
|
OntologyScene.is_system_default.is_(True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
return
|
||||||
|
|
||||||
|
business_logger.info(
|
||||||
|
f"Workspace {workspace.id} missing default ontology scenes, creating them"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
initializer = DefaultOntologyInitializer(db)
|
||||||
|
success, error_msg = initializer.initialize_default_scenes(
|
||||||
|
workspace.id, language="zh"
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
db.commit()
|
||||||
|
business_logger.info(
|
||||||
|
f"为工作空间 {workspace.id} 补建默认本体场景成功"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
business_logger.warning(
|
||||||
|
f"为工作空间 {workspace.id} 补建默认本体场景失败: {error_msg}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
business_logger.error(
|
||||||
|
f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}"
|
||||||
|
)
|
||||||
|
|||||||
376
api/app/tasks.py
376
api/app/tasks.py
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
@@ -14,6 +15,62 @@ from uuid import UUID
|
|||||||
|
|
||||||
import redis
|
import redis
|
||||||
import requests
|
import requests
|
||||||
|
from redis.exceptions import RedisError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 模块级同步 Redis 连接池,供 Celery 任务共享使用
|
||||||
|
# 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致
|
||||||
|
# 使用连接池而非单例客户端,提供更好的并发性能和自动重连
|
||||||
|
_sync_redis_pool: redis.ConnectionPool = None
|
||||||
|
|
||||||
|
def _get_or_create_redis_pool() -> redis.ConnectionPool:
|
||||||
|
"""获取或创建 Redis 连接池(懒初始化)"""
|
||||||
|
global _sync_redis_pool
|
||||||
|
if _sync_redis_pool is None:
|
||||||
|
try:
|
||||||
|
_sync_redis_pool = redis.ConnectionPool(
|
||||||
|
host=settings.REDIS_HOST,
|
||||||
|
port=settings.REDIS_PORT,
|
||||||
|
db=settings.REDIS_DB_CELERY_BACKEND,
|
||||||
|
password=settings.REDIS_PASSWORD,
|
||||||
|
decode_responses=True,
|
||||||
|
max_connections=10,
|
||||||
|
socket_connect_timeout=5,
|
||||||
|
socket_timeout=5,
|
||||||
|
retry_on_timeout=True,
|
||||||
|
health_check_interval=30,
|
||||||
|
)
|
||||||
|
logger.info("Redis connection pool created for Celery tasks")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create Redis connection pool: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
return _sync_redis_pool
|
||||||
|
|
||||||
|
def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||||
|
"""获取同步 Redis 客户端(使用连接池)
|
||||||
|
|
||||||
|
使用连接池提供的客户端,支持自动重连和健康检查。
|
||||||
|
如果 Redis 不可用,返回 None,调用方应优雅降级。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
redis.StrictRedis: Redis 客户端实例,如果连接失败则返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
pool = _get_or_create_redis_pool()
|
||||||
|
if pool is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
client = redis.StrictRedis(connection_pool=pool)
|
||||||
|
# 验证连接可用性
|
||||||
|
client.ping()
|
||||||
|
return client
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(f"Redis connection failed: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
# Import a unified Celery instance
|
# Import a unified Celery instance
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
@@ -1090,6 +1147,22 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||||
|
|
||||||
|
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
|
||||||
|
try:
|
||||||
|
_r = get_sync_redis_client()
|
||||||
|
if _r is not None:
|
||||||
|
from datetime import timedelta as _td
|
||||||
|
from datetime import timezone as _tz
|
||||||
|
_CST = _tz(_td(hours=8))
|
||||||
|
_now_cst = datetime.now(_CST).replace(tzinfo=None).isoformat()
|
||||||
|
_r.set(
|
||||||
|
f"write_message:last_done:{end_user_id}",
|
||||||
|
_now_cst,
|
||||||
|
ex=86400 * 30,
|
||||||
|
)
|
||||||
|
except Exception as _e:
|
||||||
|
logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"result": result,
|
"result": result,
|
||||||
@@ -2149,12 +2222,16 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
async def _run() -> Dict[str, Any]:
|
async def _run() -> Dict[str, Any]:
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
|
||||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||||
from sqlalchemy import select, func
|
from app.repositories.implicit_emotions_storage_repository import (
|
||||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
ImplicitEmotionsStorageRepository,
|
||||||
|
TimeFilterUnavailableError,
|
||||||
|
)
|
||||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||||
|
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
|
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
|
||||||
@@ -2167,18 +2244,27 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
|||||||
|
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
try:
|
try:
|
||||||
# 获取所有已存储数据的用户ID(分批次处理)
|
|
||||||
repo = ImplicitEmotionsStorageRepository(db)
|
repo = ImplicitEmotionsStorageRepository(db)
|
||||||
|
|
||||||
# 先统计总数用于日志
|
# 先统计总数用于日志
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
total_users = db.execute(
|
total_users = db.execute(
|
||||||
select(func.count()).select_from(ImplicitEmotionsStorage)
|
select(func.count()).select_from(ImplicitEmotionsStorage)
|
||||||
).scalar() or 0
|
).scalar() or 0
|
||||||
logger.info(f"找到 {total_users} 个需要更新的用户")
|
logger.info(f"表中存量用户总数: {total_users},开始时间轴筛选")
|
||||||
|
|
||||||
# 遍历每个用户并更新数据(分批次,避免一次性加载所有ID)
|
# 构建 Redis 同步客户端,用于时间轴筛选
|
||||||
for end_user_id in repo.get_all_user_ids(batch_size=100):
|
_redis_client = get_sync_redis_client()
|
||||||
|
|
||||||
|
# 只处理 last_done > updated_at 的用户(有新记忆写入的用户)
|
||||||
|
# Redis 不可用时回退到全量处理
|
||||||
|
try:
|
||||||
|
refresh_iter = repo.get_users_needing_refresh(_redis_client, batch_size=100)
|
||||||
|
except TimeFilterUnavailableError as e:
|
||||||
|
logger.warning(f"时间轴筛选不可用,回退到全量刷新: {e}")
|
||||||
|
refresh_iter = repo.get_all_user_ids(batch_size=100)
|
||||||
|
|
||||||
|
for end_user_id in refresh_iter:
|
||||||
logger.info(f"开始处理用户: {end_user_id}")
|
logger.info(f"开始处理用户: {end_user_id}")
|
||||||
user_start_time = time.time()
|
user_start_time = time.time()
|
||||||
|
|
||||||
@@ -2264,10 +2350,10 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
|||||||
user_results.append(error_info)
|
user_results.append(error_info)
|
||||||
logger.error(f"处理用户 {end_user_id} 时出错: {str(e)}")
|
logger.error(f"处理用户 {end_user_id} 时出错: {str(e)}")
|
||||||
|
|
||||||
# ---- 处理增量用户(当天新增、尚未初始化的用户)----
|
# ---- 当天新增用户兜底初始化 ----
|
||||||
new_users_initialized = 0
|
new_users_initialized = 0
|
||||||
new_users_failed = 0
|
new_users_failed = 0
|
||||||
logger.info("开始处理当天新增的增量用户初始化")
|
logger.info("开始处理当天新增用户的兜底初始化")
|
||||||
|
|
||||||
for end_user_id in repo.get_new_user_ids_today(batch_size=100):
|
for end_user_id in repo.get_new_user_ids_today(batch_size=100):
|
||||||
logger.info(f"开始初始化新用户: {end_user_id}")
|
logger.info(f"开始初始化新用户: {end_user_id}")
|
||||||
@@ -2281,35 +2367,27 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
|||||||
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||||
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
|
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
|
||||||
await implicit_service.save_profile_cache(
|
await implicit_service.save_profile_cache(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id, profile_data=profile_data, db=db
|
||||||
profile_data=profile_data,
|
|
||||||
db=db
|
|
||||||
)
|
)
|
||||||
implicit_success = True
|
implicit_success = True
|
||||||
logger.info(f"成功初始化新用户 {end_user_id} 的隐性记忆画像")
|
logger.info(f"成功初始化新用户 {end_user_id} 的隐性记忆画像")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"隐性记忆初始化失败: {str(e)}"
|
errors.append(f"隐性记忆初始化失败: {str(e)}")
|
||||||
errors.append(error_msg)
|
logger.error(f"新用户 {end_user_id} 隐性记忆初始化失败: {e}")
|
||||||
logger.error(f"新用户 {end_user_id} {error_msg}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
emotion_service = EmotionAnalyticsService()
|
emotion_service = EmotionAnalyticsService()
|
||||||
suggestions_data = await emotion_service.generate_emotion_suggestions(
|
suggestions_data = await emotion_service.generate_emotion_suggestions(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id, db=db, language="zh"
|
||||||
db=db,
|
|
||||||
language="zh"
|
|
||||||
)
|
)
|
||||||
await emotion_service.save_suggestions_cache(
|
await emotion_service.save_suggestions_cache(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id, suggestions_data=suggestions_data, db=db
|
||||||
suggestions_data=suggestions_data,
|
|
||||||
db=db
|
|
||||||
)
|
)
|
||||||
emotion_success = True
|
emotion_success = True
|
||||||
logger.info(f"成功初始化新用户 {end_user_id} 的情绪建议")
|
logger.info(f"成功初始化新用户 {end_user_id} 的情绪建议")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"情绪建议初始化失败: {str(e)}"
|
errors.append(f"情绪建议初始化失败: {str(e)}")
|
||||||
errors.append(error_msg)
|
logger.error(f"新用户 {end_user_id} 情绪建议初始化失败: {e}")
|
||||||
logger.error(f"新用户 {end_user_id} {error_msg}")
|
|
||||||
|
|
||||||
if implicit_success or emotion_success:
|
if implicit_success or emotion_success:
|
||||||
new_users_initialized += 1
|
new_users_initialized += 1
|
||||||
@@ -2319,7 +2397,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
|||||||
user_elapsed = time.time() - user_start_time
|
user_elapsed = time.time() - user_start_time
|
||||||
user_results.append({
|
user_results.append({
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"type": "init",
|
"type": "new_user_init",
|
||||||
"implicit_success": implicit_success,
|
"implicit_success": implicit_success,
|
||||||
"emotion_success": emotion_success,
|
"emotion_success": emotion_success,
|
||||||
"errors": errors,
|
"errors": errors,
|
||||||
@@ -2331,7 +2409,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
|||||||
user_elapsed = time.time() - user_start_time
|
user_elapsed = time.time() - user_start_time
|
||||||
user_results.append({
|
user_results.append({
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"type": "init",
|
"type": "new_user_init",
|
||||||
"implicit_success": False,
|
"implicit_success": False,
|
||||||
"emotion_success": False,
|
"emotion_success": False,
|
||||||
"errors": [str(e)],
|
"errors": [str(e)],
|
||||||
@@ -2339,27 +2417,24 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
|||||||
})
|
})
|
||||||
logger.error(f"初始化新用户 {end_user_id} 时出错: {str(e)}")
|
logger.error(f"初始化新用户 {end_user_id} 时出错: {str(e)}")
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"当天新增用户兜底初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}")
|
||||||
f"增量用户初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}"
|
# ---- 新增用户兜底初始化结束 ----
|
||||||
)
|
|
||||||
# ---- 增量用户处理结束 ----
|
|
||||||
|
|
||||||
# 记录总体统计信息
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"隐性记忆和情绪数据更新定时任务完成: "
|
f"隐性记忆和情绪数据更新定时任务完成: "
|
||||||
f"存量用户总数={total_users}, "
|
f"存量用户总数={total_users}, "
|
||||||
f"隐性记忆成功={successful_implicit}, "
|
f"隐性记忆成功={successful_implicit}, "
|
||||||
f"情绪建议成功={successful_emotion}, "
|
f"情绪建议成功={successful_emotion}, "
|
||||||
f"存量失败={failed}, "
|
f"存量失败={failed}, "
|
||||||
f"增量初始化成功={new_users_initialized}, "
|
f"新增用户初始化成功={new_users_initialized}, "
|
||||||
f"增量初始化失败={new_users_failed}"
|
f"新增用户初始化失败={new_users_failed}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"message": (
|
"message": (
|
||||||
f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;"
|
f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;"
|
||||||
f"增量新用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败"
|
f"当天新增用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败"
|
||||||
),
|
),
|
||||||
"total_users": total_users,
|
"total_users": total_users,
|
||||||
"successful_implicit": successful_implicit,
|
"successful_implicit": successful_implicit,
|
||||||
@@ -2367,7 +2442,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
|||||||
"failed": failed,
|
"failed": failed,
|
||||||
"new_users_initialized": new_users_initialized,
|
"new_users_initialized": new_users_initialized,
|
||||||
"new_users_failed": new_users_failed,
|
"new_users_failed": new_users_failed,
|
||||||
"user_results": user_results[:50] # 只保留前50个用户的详细结果
|
"user_results": user_results[:50]
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -2416,3 +2491,232 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
|||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"task_id": self.request.id
|
"task_id": self.request.id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@celery_app.task(
|
||||||
|
name="app.tasks.init_implicit_emotions_for_users",
|
||||||
|
bind=True,
|
||||||
|
ignore_result=True,
|
||||||
|
max_retries=0,
|
||||||
|
acks_late=False,
|
||||||
|
time_limit=3600,
|
||||||
|
soft_time_limit=3300,
|
||||||
|
# 触发型任务标识,区别于 periodic_tasks 队列中的定时任务
|
||||||
|
triggered=True,
|
||||||
|
)
|
||||||
|
def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
|
||||||
|
"""事件触发任务:对指定用户列表做存在性检查,无记录则执行首次初始化。
|
||||||
|
|
||||||
|
由 /dashboard/end_users 接口触发,已有数据的用户直接跳过。
|
||||||
|
存量用户的数据刷新由定时任务 update_implicit_emotions_storage 负责。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_ids: 需要检查的用户ID列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含任务执行结果的字典
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
async def _run() -> Dict[str, Any]:
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
|
from app.repositories.implicit_emotions_storage_repository import (
|
||||||
|
ImplicitEmotionsStorageRepository,
|
||||||
|
)
|
||||||
|
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||||
|
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")
|
||||||
|
|
||||||
|
initialized = 0
|
||||||
|
failed = 0
|
||||||
|
skipped = 0
|
||||||
|
|
||||||
|
with get_db_context() as db:
|
||||||
|
repo = ImplicitEmotionsStorageRepository(db)
|
||||||
|
|
||||||
|
for end_user_id in end_user_ids:
|
||||||
|
existing = repo.get_by_end_user_id(end_user_id)
|
||||||
|
if existing is not None:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"用户 {end_user_id} 无记录,开始初始化")
|
||||||
|
implicit_ok = False
|
||||||
|
emotion_ok = False
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||||
|
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
|
||||||
|
await implicit_service.save_profile_cache(
|
||||||
|
end_user_id=end_user_id, profile_data=profile_data, db=db
|
||||||
|
)
|
||||||
|
implicit_ok = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"用户 {end_user_id} 隐性记忆初始化失败: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
emotion_service = EmotionAnalyticsService()
|
||||||
|
suggestions_data = await emotion_service.generate_emotion_suggestions(
|
||||||
|
end_user_id=end_user_id, db=db, language="zh"
|
||||||
|
)
|
||||||
|
await emotion_service.save_suggestions_cache(
|
||||||
|
end_user_id=end_user_id, suggestions_data=suggestions_data, db=db
|
||||||
|
)
|
||||||
|
emotion_ok = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"用户 {end_user_id} 情绪建议初始化失败: {e}")
|
||||||
|
|
||||||
|
if implicit_ok or emotion_ok:
|
||||||
|
initialized += 1
|
||||||
|
else:
|
||||||
|
failed += 1
|
||||||
|
except Exception as e:
|
||||||
|
failed += 1
|
||||||
|
logger.error(f"用户 {end_user_id} 初始化异常: {e}")
|
||||||
|
|
||||||
|
logger.info(f"按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
|
||||||
|
return {
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"initialized": initialized,
|
||||||
|
"skipped": skipped,
|
||||||
|
"failed": failed,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
import nest_asyncio
|
||||||
|
nest_asyncio.apply()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if loop.is_closed():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
result = loop.run_until_complete(_run())
|
||||||
|
result["elapsed_time"] = time.time() - start_time
|
||||||
|
result["task_id"] = self.request.id
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "FAILURE",
|
||||||
|
"error": str(e),
|
||||||
|
"elapsed_time": time.time() - start_time,
|
||||||
|
"task_id": self.request.id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@celery_app.task(
|
||||||
|
name="app.tasks.init_interest_distribution_for_users",
|
||||||
|
bind=True,
|
||||||
|
ignore_result=True,
|
||||||
|
max_retries=0,
|
||||||
|
acks_late=False,
|
||||||
|
time_limit=3600,
|
||||||
|
soft_time_limit=3300,
|
||||||
|
)
|
||||||
|
def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
|
||||||
|
"""事件触发任务:检查指定用户列表的兴趣分布缓存,无缓存则生成并写入 Redis。
|
||||||
|
|
||||||
|
由 /dashboard/end_users 接口触发,已有缓存的用户直接跳过。
|
||||||
|
默认生成中文(zh)兴趣分布数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_ids: 需要检查的用户ID列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含任务执行结果的字典
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
async def _run() -> Dict[str, Any]:
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
|
from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
|
||||||
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")
|
||||||
|
|
||||||
|
initialized = 0
|
||||||
|
failed = 0
|
||||||
|
skipped = 0
|
||||||
|
language = "zh"
|
||||||
|
|
||||||
|
service = MemoryAgentService()
|
||||||
|
|
||||||
|
with get_db_context() as db:
|
||||||
|
for end_user_id in end_user_ids:
|
||||||
|
# 存在性检查:缓存有数据则跳过
|
||||||
|
cached = await InterestMemoryCache.get_interest_distribution(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
if cached is not None:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成")
|
||||||
|
try:
|
||||||
|
result = await service.get_interest_distribution_by_user(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=5,
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
await InterestMemoryCache.set_interest_distribution(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
language=language,
|
||||||
|
data=result,
|
||||||
|
expire=INTEREST_CACHE_EXPIRE,
|
||||||
|
)
|
||||||
|
initialized += 1
|
||||||
|
logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功")
|
||||||
|
except Exception as e:
|
||||||
|
failed += 1
|
||||||
|
logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}")
|
||||||
|
|
||||||
|
logger.info(f"兴趣分布按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
|
||||||
|
return {
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"initialized": initialized,
|
||||||
|
"skipped": skipped,
|
||||||
|
"failed": failed,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
import nest_asyncio
|
||||||
|
nest_asyncio.apply()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if loop.is_closed():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
result = loop.run_until_complete(_run())
|
||||||
|
result["elapsed_time"] = time.time() - start_time
|
||||||
|
result["task_id"] = self.request.id
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "FAILURE",
|
||||||
|
"error": str(e),
|
||||||
|
"elapsed_time": time.time() - start_time,
|
||||||
|
"task_id": self.request.id,
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,36 @@
|
|||||||
{
|
{
|
||||||
|
"v0.2.6": {
|
||||||
|
"introduction": {
|
||||||
|
"codeName": "听剑",
|
||||||
|
"releaseDate": "2026-3-6",
|
||||||
|
"upgradePosition": "🐻 多模态交互全面升级,记忆剪枝与工作流迁移双线并进,锋芒初露,兼收并蓄",
|
||||||
|
"coreUpgrades": [
|
||||||
|
"1. 工作流与应用框架<br>* 工作流导入适配(Dify):支持 Dify 工作流定义无缝迁移<br>* 字段字数限制与校验规则:可配置字符限制与产品级校验<br>* 应用复制(Agent、工作流、集群):一键复制完整应用配置<br>* 对话变量(调试+分享):支持有状态多轮交互<br>* Chat 接口流式输出 message_id:流式响应包含消息追踪标识",
|
||||||
|
"2. 多模态与交互 💬<br>* 音频输入与输出:应用支持音频模态<br>* 文件类型输入支持:扩展支持语音、文件、视频上传",
|
||||||
|
"3. 模型与智能 🧠<br>* 模型视觉与 Omni 区分:精确区分视觉与 Omni 模型能力<br>* 教育记忆与陪伴玩具场景预设:垂直领域本体配置开箱即用<br>* 本体配置默认标识:支持基线配置标记<br>* 记忆配置默认标识:自动应用默认记忆设置",
|
||||||
|
"4. 记忆智能 🔬<br>* 记忆剪枝模块:智能裁剪冗余低价值记忆<br>* RAG 快速检索集成记忆:深度思考与正常回复双模式检索",
|
||||||
|
"5. 稳健性与缺陷修复 🔧<br>* 模型管理:修复自定义模型 API Key 批量配置错误<br>* 知识库管理:修复非源文档下载原始内容接口错误,更新分享停用提示文案<br>* 用户记忆:优化档案提取准确性(姓名、职业、兴趣分布)<br>* 长期记忆:修复情景记忆卡片重复和用户归属错误<br>* 工作空间首页:修复知识库数量、应用数量、总记忆容量、API 调用次数、知识库类型分布等数据不一致问题<br>* 基础设施:修正 Celery 环境变量配置,修复数据库连接池 idle-in-transaction 泄漏",
|
||||||
|
"<br>",
|
||||||
|
"v0.2.6 标志着 MemoryBear 在多模态交互、跨平台工作流迁移和智能记忆管理方面的重要突破。下一版本将聚焦 A2A 协议支持实现多智能体协作、多模态记忆能力扩展至语音与视觉领域,以及应用导入导出功能支持跨环境便携部署。",
|
||||||
|
"MemoryBear,让记忆有熊力 🐻✨"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"introduction_en": {
|
||||||
|
"codeName": "TingJian",
|
||||||
|
"releaseDate": "2026-3-6",
|
||||||
|
"upgradePosition": "🐻 Full multimodal interaction upgrade with memory pruning and workflow migration — sharpened edge, broader reach",
|
||||||
|
"coreUpgrades": [
|
||||||
|
"1. Workflow & Application Framework<br>* Workflow Import Adaptation (Dify): Seamless Dify workflow migration<br>* Field Character Limits & Validation: Configurable limits with product-defined rules<br>* Application Cloning (Agent, Workflow, Cluster): One-click full config duplication<br>* Conversation Variables (Debug + Share): Stateful multi-turn interactions<br>* Streaming message_id in Chat API: Message tracking in streaming responses",
|
||||||
|
"2. Multimodal & Interaction 💬<br>* Audio Input & Output: Audio modality support for applications<br>* File Type Input Support: Voice, file, and video upload support",
|
||||||
|
"3. Model & Intelligence 🧠<br>* Model Vision & Omni Differentiation: Precise capability routing<br>* Education Memory & Companion Toy Presets: Domain-specific ontology configs<br>* Ontology Default Identifier: Baseline configuration flagging<br>* Memory Configuration Default Identifier: Auto-apply default settings",
|
||||||
|
"4. Memory Intelligence 🔬<br>* Memory Pruning Module: Intelligent trimming of redundant memories<br>* RAG Quick Retrieval with Memory: Deep think and normal reply dual-mode retrieval",
|
||||||
|
"5. Robustness & Bug Fixes 🔧<br>* Model Management: Fixed custom model API key batch configuration error<br>* Knowledge Base: Fixed download original content API error for non-source documents, updated share disable prompt text<br>* User Memory: Improved profile extraction accuracy (name, occupation, interests)<br>* Long-Term Memory: Fixed duplicate episodic memory cards and wrong user attribution<br>* Dashboard: Fixed data inconsistencies in knowledge count, app count, memory capacity, API calls, and knowledge type distribution<br>* Infrastructure: Corrected Celery environment variables, fixed database connection pool idle-in-transaction leak",
|
||||||
|
"<br>",
|
||||||
|
"v0.2.6 marks a significant milestone for MemoryBear in multimodal interaction, cross-platform workflow migration, and intelligent memory management. The next release will focus on A2A protocol support for multi-agent collaboration, multimodal memory extending extraction to voice and visual domains, and application import/export for portable cross-environment deployment.",
|
||||||
|
"MemoryBear, Memory with Bear Power 🐻✨"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
"v0.2.5": {
|
"v0.2.5": {
|
||||||
"introduction": {
|
"introduction": {
|
||||||
"codeName": "行云",
|
"codeName": "行云",
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- celery
|
- celery
|
||||||
|
|
||||||
# Periodic worker - Scheduled/beat tasks (prefork, low concurrency)
|
# Periodic worker - Scheduled/beat tasks + API-triggered tasks (prefork, low concurrency)
|
||||||
worker-periodic:
|
worker-periodic:
|
||||||
image: redbear-mem-open:latest
|
image: redbear-mem-open:latest
|
||||||
container_name: worker-periodic
|
container_name: worker-periodic
|
||||||
|
|||||||
@@ -29,10 +29,10 @@ REDIS_DB=
|
|||||||
REDIS_PASSWORD=password
|
REDIS_PASSWORD=password
|
||||||
|
|
||||||
#celery
|
#celery
|
||||||
BROKER_URL=
|
# NOTE: 不要使用 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||||
RESULT_BACKEND=
|
# 这些名称会被 Celery CLI 劫持,详见 docs/celery-env-bug-report.md
|
||||||
CELERY_BROKER=
|
REDIS_DB_CELERY_BROKER=
|
||||||
CELERY_BACKEND=
|
REDIS_DB_CELERY_BACKEND=
|
||||||
|
|
||||||
# Memory Cache Regeneration Configuration
|
# Memory Cache Regeneration Configuration
|
||||||
# Interval in hours for regenerating memory insight and user summary cache
|
# Interval in hours for regenerating memory insight and user summary cache
|
||||||
|
|||||||
36
api/migrations/versions/1ac07dc7366f_202603061644.py
Normal file
36
api/migrations/versions/1ac07dc7366f_202603061644.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""202603061644
|
||||||
|
|
||||||
|
Revision ID: 1ac07dc7366f
|
||||||
|
Revises: 6a4641cf192b
|
||||||
|
Create Date: 2026-03-06 16:51:10.152305
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '1ac07dc7366f'
|
||||||
|
down_revision: Union[str, None] = '6a4641cf192b'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('mcp_tool_configs', sa.Column('source_channel', sa.String(length=50), server_default=sa.text("'self_hosted'"), nullable=False, comment='来源渠道'))
|
||||||
|
op.add_column('mcp_tool_configs', sa.Column('market_id', sa.UUID(), nullable=True, comment='渠道市场id'))
|
||||||
|
op.add_column('mcp_tool_configs', sa.Column('market_config_id', sa.UUID(), nullable=True, comment='渠道市场配置id'))
|
||||||
|
op.add_column('mcp_tool_configs', sa.Column('mcp_service_id', sa.String(length=255), nullable=True, comment='mcp服务id'))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column('mcp_tool_configs', 'mcp_service_id')
|
||||||
|
op.drop_column('mcp_tool_configs', 'market_config_id')
|
||||||
|
op.drop_column('mcp_tool_configs', 'market_id')
|
||||||
|
op.drop_column('mcp_tool_configs', 'source_channel')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import { request } from '@/utils/request'
|
import { request } from '@/utils/request'
|
||||||
import type { Query, CustomToolItem, ExecuteData, MCPToolItem, InnerToolItem } from '@/views/ToolManagement/types'
|
import type { Query, MarketQuery, CustomToolItem, ExecuteData, MCPToolItem, InnerToolItem } from '@/views/ToolManagement/types'
|
||||||
|
|
||||||
// 工具列表
|
// 工具列表
|
||||||
export const getTools = (data: Query) => {
|
export const getTools = (data: Query) => {
|
||||||
@@ -33,4 +33,44 @@ export const getToolDetail = (tool_id: string) => {
|
|||||||
}
|
}
|
||||||
export const getToolMethods = (tool_id: string) => {
|
export const getToolMethods = (tool_id: string) => {
|
||||||
return request.get(`/tools/${tool_id}/methods`)
|
return request.get(`/tools/${tool_id}/methods`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCP市场列表
|
||||||
|
export const getMarketTools = (data: Query) => {
|
||||||
|
return request.get('/mcp_markets/mcp_markets', data)
|
||||||
|
}
|
||||||
|
// 市场配置创建
|
||||||
|
export const createMarketConfig = (values: {
|
||||||
|
mcp_market_id: string;
|
||||||
|
token: string;
|
||||||
|
status: number;
|
||||||
|
}) => {
|
||||||
|
return request.post('/mcp_market_configs/mcp_market_config', values)
|
||||||
|
}
|
||||||
|
// 市场配置更新
|
||||||
|
export const updateMarketConfig = (values: {
|
||||||
|
mcp_market_config_id: string;
|
||||||
|
token: string;
|
||||||
|
status: number;
|
||||||
|
}) => {
|
||||||
|
return request.put(`/mcp_market_configs/${values.mcp_market_config_id}`, values)
|
||||||
|
}
|
||||||
|
// 市场根据id获取配置
|
||||||
|
export const getMarketConfig = (mcp_market_id: string) => {
|
||||||
|
return request.get(`/mcp_market_configs/mcp_market_id/${mcp_market_id}`)
|
||||||
|
}
|
||||||
|
// 市场MCP列表
|
||||||
|
export const getMarketMCPs = (data: MarketQuery) => {
|
||||||
|
return request.get('/mcp_market_configs/mcp_servers', data)
|
||||||
|
}
|
||||||
|
// 根据配置ID serverId 获取MCP服务详情
|
||||||
|
export const getMarketMCPDetail = (data:{
|
||||||
|
mcp_market_config_id: string;
|
||||||
|
server_id: string;
|
||||||
|
}) => {
|
||||||
|
return request.get(`/mcp_market_configs/mcp_server`,data)
|
||||||
|
}
|
||||||
|
// 市场已激活MCP列表
|
||||||
|
export const getMarketMCPsActivated = (data: MarketQuery) => {
|
||||||
|
return request.get('/mcp_market_configs/operational_mcp_servers', data)
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2025-12-10 16:46:14
|
* @Date: 2025-12-10 16:46:14
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-04 18:42:49
|
* @Last Modified time: 2026-03-06 13:36:20
|
||||||
*/
|
*/
|
||||||
import { type FC, useEffect, useMemo } from 'react'
|
import { type FC, useEffect, useMemo } from 'react'
|
||||||
import { Flex, Input, Form } from 'antd'
|
import { Flex, Input, Form } from 'antd'
|
||||||
@@ -50,13 +50,17 @@ const ChatInput: FC<ChatInputProps> = ({
|
|||||||
|
|
||||||
|
|
||||||
const handleDelete = (file: any) => {
|
const handleDelete = (file: any) => {
|
||||||
fileChange?.(fileList?.filter(item => item.uid !== file.uid) || [])
|
fileChange?.(fileList?.filter(item => {
|
||||||
|
return item.thumbUrl && file.thumbUrl ? item.thumbUrl !== file.thumbUrl
|
||||||
|
: item.url && file.url ? item.url !== file.url
|
||||||
|
: item.uid !== file.uid
|
||||||
|
}) || [])
|
||||||
}
|
}
|
||||||
// Convert file object to preview URL
|
// Convert file object to preview URL
|
||||||
const previewFileList = useMemo(() => {
|
const previewFileList = useMemo(() => {
|
||||||
return fileList?.map(file => ({
|
return fileList?.map(file => ({
|
||||||
...file,
|
...file,
|
||||||
url: file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : file.thumbUrl)
|
url: file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
|
||||||
})) || []
|
})) || []
|
||||||
}, [fileList])
|
}, [fileList])
|
||||||
|
|
||||||
@@ -72,7 +76,7 @@ const ChatInput: FC<ChatInputProps> = ({
|
|||||||
{previewFileList.map((file) => {
|
{previewFileList.map((file) => {
|
||||||
if (file.type.includes('image')) {
|
if (file.type.includes('image')) {
|
||||||
return (
|
return (
|
||||||
<div key={file.uid} className="rb:inline-block rb:group rb:relative rb:rounded-lg">
|
<div key={file.url || file.uid} className="rb:inline-block rb:group rb:relative rb:rounded-lg">
|
||||||
<img src={file.url} alt={file.name} className="rb:size-12! rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
<img src={file.url} alt={file.name} className="rb:size-12! rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
||||||
<div
|
<div
|
||||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||||
@@ -83,7 +87,7 @@ const ChatInput: FC<ChatInputProps> = ({
|
|||||||
}
|
}
|
||||||
if (file.type.includes('video')) {
|
if (file.type.includes('video')) {
|
||||||
return (
|
return (
|
||||||
<div key={file.uid} className="rb:w-45 rb:h-16 rb:inline-block rb:group rb:relative rb:rounded-lg">
|
<div key={file.url || file.uid} className="rb:w-45 rb:h-16 rb:inline-block rb:group rb:relative rb:rounded-lg">
|
||||||
<video src={file.url} controls className="rb:w-45 rb:h-16 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
<video src={file.url} controls className="rb:w-45 rb:h-16 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
||||||
<div
|
<div
|
||||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||||
@@ -94,7 +98,7 @@ const ChatInput: FC<ChatInputProps> = ({
|
|||||||
}
|
}
|
||||||
if (file.type.includes('audio')) {
|
if (file.type.includes('audio')) {
|
||||||
return (
|
return (
|
||||||
<div key={file.uid} className="rb:w-45 rb:h-16 rb:inline-flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5 rb:gap-2">
|
<div key={file.url || file.uid} className="rb:w-45 rb:h-16 rb:inline-flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5 rb:gap-2">
|
||||||
<audio src={file.url} controls className="rb:w-45 rb:h-16" />
|
<audio src={file.url} controls className="rb:w-45 rb:h-16" />
|
||||||
<div
|
<div
|
||||||
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
|
||||||
@@ -104,7 +108,7 @@ const ChatInput: FC<ChatInputProps> = ({
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
<div key={file.uid} className="rb:w-45 rb:text-[12px] rb:gap-2.5 rb:flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5">
|
<div key={file.url || file.uid} className="rb:w-45 rb:text-[12px] rb:gap-2.5 rb:flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5">
|
||||||
{(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div
|
{(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div
|
||||||
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/word.svg')]"
|
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/word.svg')]"
|
||||||
></div>}
|
></div>}
|
||||||
|
|||||||
@@ -440,7 +440,6 @@ export const en = {
|
|||||||
logoutApiCannotRefreshToken: 'Logout API cannot refresh token',
|
logoutApiCannotRefreshToken: 'Logout API cannot refresh token',
|
||||||
publicApiCannotRefreshToken: 'Public API cannot refresh token',
|
publicApiCannotRefreshToken: 'Public API cannot refresh token',
|
||||||
refreshTokenNotExist: 'Refresh token does not exist',
|
refreshTokenNotExist: 'Refresh token does not exist',
|
||||||
SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: 'This is a system preset scene and cannot be deleted',
|
|
||||||
reset: 'Reset',
|
reset: 'Reset',
|
||||||
refresh: 'Refresh',
|
refresh: 'Refresh',
|
||||||
return: 'Return',
|
return: 'Return',
|
||||||
@@ -1362,6 +1361,7 @@ export const en = {
|
|||||||
complex: 'Compatibility Analysis',
|
complex: 'Compatibility Analysis',
|
||||||
sureInfo: 'Information Confirmation',
|
sureInfo: 'Information Confirmation',
|
||||||
completed: 'Import Completed',
|
completed: 'Import Completed',
|
||||||
|
baseInfo: 'Basic Information',
|
||||||
workflowName: 'Workflow Name',
|
workflowName: 'Workflow Name',
|
||||||
fileName: 'File Name',
|
fileName: 'File Name',
|
||||||
fileSize: 'File Size',
|
fileSize: 'File Size',
|
||||||
@@ -1573,7 +1573,7 @@ export const en = {
|
|||||||
intelligentSemanticPruningFunction: 'Intelligent Semantic Pruning Function',
|
intelligentSemanticPruningFunction: 'Intelligent Semantic Pruning Function',
|
||||||
intelligentSemanticPruningFunctionDesc: 'Whether to activate intelligent semantic pruning (true/false).',
|
intelligentSemanticPruningFunctionDesc: 'Whether to activate intelligent semantic pruning (true/false).',
|
||||||
intelligentSemanticPruningScene: 'Intelligent Semantic Pruning Scene',
|
intelligentSemanticPruningScene: 'Intelligent Semantic Pruning Scene',
|
||||||
intelligentSemanticPruningSceneDesc: 'Select intelligent semantic pruning scene (education, online_service, outbound).',
|
intelligentSemanticPruningSceneDesc: 'Semantic pruning scenarios are consistent with ontology engineering scenarios',
|
||||||
intelligentSemanticPruningThreshold: 'Intelligent Semantic Pruning Threshold',
|
intelligentSemanticPruningThreshold: 'Intelligent Semantic Pruning Threshold',
|
||||||
intelligentSemanticPruningThresholdDesc: 'Set intelligent semantic pruning threshold (0-0.9).',
|
intelligentSemanticPruningThresholdDesc: 'Set intelligent semantic pruning threshold (0-0.9).',
|
||||||
reflectionEngine: 'Self-Reflexion Engine',
|
reflectionEngine: 'Self-Reflexion Engine',
|
||||||
@@ -1807,6 +1807,25 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
|||||||
error_desc: 'API is configured but connection error',
|
error_desc: 'API is configured but connection error',
|
||||||
|
|
||||||
testConnectionSuccess: 'Test Connection Successful',
|
testConnectionSuccess: 'Test Connection Successful',
|
||||||
|
refreshSuccess: 'Refresh Successful',
|
||||||
|
refreshFailed: 'Refresh Failed',
|
||||||
|
|
||||||
|
// Market related
|
||||||
|
marketSelectTitle: 'Select an MCP Market',
|
||||||
|
marketSelectDesc: 'Choose a market source from the left, configure the connection to browse MCP services',
|
||||||
|
marketRefreshSuccess: 'List refreshed',
|
||||||
|
marketActivated: 'Activated',
|
||||||
|
marketInDatabase: 'In Database',
|
||||||
|
marketAdd: 'Add',
|
||||||
|
marketRefresh: 'Refresh',
|
||||||
|
marketConfig: 'Configure',
|
||||||
|
marketConfigConnection: 'Configure Connection',
|
||||||
|
marketNoServices: 'No MCP Services Available',
|
||||||
|
marketNotConnected: 'Not Connected to This Market',
|
||||||
|
marketNoServicesDesc: 'This market currently has no available services',
|
||||||
|
marketNotConnectedDesc: 'Click the "Configure" button in the upper right corner to set connection information',
|
||||||
|
marketSearchPlaceholder: 'Search services...',
|
||||||
|
marketVisit: 'Visit Market',
|
||||||
serviceEndpoint: 'Service Endpoint URL',
|
serviceEndpoint: 'Service Endpoint URL',
|
||||||
serviceEndpointPlaceholder: 'URL of the service endpoint',
|
serviceEndpointPlaceholder: 'URL of the service endpoint',
|
||||||
serviceEndpointExtra: 'Complete access address of the MCP service',
|
serviceEndpointExtra: 'Complete access address of the MCP service',
|
||||||
@@ -1960,6 +1979,19 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
|||||||
viewDetail: 'View Details',
|
viewDetail: 'View Details',
|
||||||
textLink: 'Test Connection',
|
textLink: 'Test Connection',
|
||||||
noResult: 'Processing results will be displayed here',
|
noResult: 'Processing results will be displayed here',
|
||||||
|
|
||||||
|
marketConfig: 'Configure {{name}}',
|
||||||
|
marketSaveAndConnect: 'Save & Connect',
|
||||||
|
marketUrl: 'Market URL',
|
||||||
|
marketUrlPlaceholder: 'Market URL',
|
||||||
|
marketCopy: 'Copy',
|
||||||
|
marketApiKeyOptional: 'Optional',
|
||||||
|
marketApiKeyExtra: 'Some markets require an API Key to access the full service list',
|
||||||
|
marketApiKeyPlaceholder: 'Enter API Key to access more services',
|
||||||
|
marketConnectionStatus: 'Connection Status',
|
||||||
|
marketConnected: '● Connected',
|
||||||
|
marketDisconnected: '○ Disconnected',
|
||||||
|
marketConnecting: 'Connecting to {{name}}...',
|
||||||
serverUrlInvalid: 'Must start with http:// or https://, and cannot have leading or trailing spaces',
|
serverUrlInvalid: 'Must start with http:// or https://, and cannot have leading or trailing spaces',
|
||||||
requestHeaderKeyInvalid: 'Only English letters, numbers, hyphens (-), and underscores (_) are allowed, and cannot start or end with a hyphen or underscore',
|
requestHeaderKeyInvalid: 'Only English letters, numbers, hyphens (-), and underscores (_) are allowed, and cannot start or end with a hyphen or underscore',
|
||||||
},
|
},
|
||||||
@@ -2008,6 +2040,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
|||||||
self_optimization: 'Self Optimization',
|
self_optimization: 'Self Optimization',
|
||||||
process_evolution: 'Process Evolution',
|
process_evolution: 'Process Evolution',
|
||||||
unknown: 'Unknown Node',
|
unknown: 'Unknown Node',
|
||||||
|
notes: 'Sticky Note',
|
||||||
|
|
||||||
clickToConfigure: 'Click to configure node parameters',
|
clickToConfigure: 'Click to configure node parameters',
|
||||||
nodeProperties: 'Node Properties',
|
nodeProperties: 'Node Properties',
|
||||||
@@ -2195,6 +2228,12 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
|||||||
output_variables: 'Output Variables',
|
output_variables: 'Output Variables',
|
||||||
refreshTip: 'Sync function signature to code',
|
refreshTip: 'Sync function signature to code',
|
||||||
},
|
},
|
||||||
|
notes: {
|
||||||
|
showAuth: 'Show Author',
|
||||||
|
enterLink: 'Enter Link URL',
|
||||||
|
placeholder: 'Enter note...',
|
||||||
|
removeLink: 'Remove Link',
|
||||||
|
},
|
||||||
name: 'Key',
|
name: 'Key',
|
||||||
type: 'Type',
|
type: 'Type',
|
||||||
value: 'Value',
|
value: 'Value',
|
||||||
@@ -2617,6 +2656,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
|||||||
updated_at: 'Updated At',
|
updated_at: 'Updated At',
|
||||||
entityTypes: 'Entity Types',
|
entityTypes: 'Entity Types',
|
||||||
|
|
||||||
|
classSearchPlaceholder: 'Search types',
|
||||||
addClass: 'Add Type',
|
addClass: 'Add Type',
|
||||||
class_name: 'Type Name',
|
class_name: 'Type Name',
|
||||||
class_description: 'Type Definition',
|
class_description: 'Type Definition',
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ export const zh = {
|
|||||||
createMemorySummary: '创建记忆摘要',
|
createMemorySummary: '创建记忆摘要',
|
||||||
memoryManagement: '记忆管理',
|
memoryManagement: '记忆管理',
|
||||||
spaceManagement: '空间管理',
|
spaceManagement: '空间管理',
|
||||||
memoryExtractionEngine: '记忆提取引擎',
|
memoryExtractionEngine: '记忆萃取引擎',
|
||||||
forgettingEngine: '遗忘引擎',
|
forgettingEngine: '遗忘引擎',
|
||||||
apiKeyManagement: 'API KEY管理',
|
apiKeyManagement: 'API KEY管理',
|
||||||
knowledgePrivate: '详情',
|
knowledgePrivate: '详情',
|
||||||
@@ -1020,7 +1020,6 @@ export const zh = {
|
|||||||
logoutApiCannotRefreshToken: '退出登录接口不能刷新token',
|
logoutApiCannotRefreshToken: '退出登录接口不能刷新token',
|
||||||
publicApiCannotRefreshToken: '公共接口不能刷新token',
|
publicApiCannotRefreshToken: '公共接口不能刷新token',
|
||||||
refreshTokenNotExist: '刷新token不存在',
|
refreshTokenNotExist: '刷新token不存在',
|
||||||
SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: '该场景为系统预设场景,不允许删除',
|
|
||||||
reset: '重置',
|
reset: '重置',
|
||||||
refresh: '刷新',
|
refresh: '刷新',
|
||||||
return: '返回',
|
return: '返回',
|
||||||
@@ -1284,7 +1283,7 @@ export const zh = {
|
|||||||
createConfiguration: '创建配置',
|
createConfiguration: '创建配置',
|
||||||
editConfiguration: '编辑配置',
|
editConfiguration: '编辑配置',
|
||||||
desc: '描述',
|
desc: '描述',
|
||||||
memoryExtractionEngine: '记忆提取引擎',
|
memoryExtractionEngine: '记忆萃取引擎',
|
||||||
forgottenEngine: '遗忘引擎',
|
forgottenEngine: '遗忘引擎',
|
||||||
active: '活跃',
|
active: '活跃',
|
||||||
inactive: '不活跃',
|
inactive: '不活跃',
|
||||||
@@ -1572,7 +1571,7 @@ export const zh = {
|
|||||||
intelligentSemanticPruningFunction: '智能语义修剪功能',
|
intelligentSemanticPruningFunction: '智能语义修剪功能',
|
||||||
intelligentSemanticPruningFunctionDesc: '是否激活智能语义修剪(true/false)。',
|
intelligentSemanticPruningFunctionDesc: '是否激活智能语义修剪(true/false)。',
|
||||||
intelligentSemanticPruningScene: '智能语义修剪场景',
|
intelligentSemanticPruningScene: '智能语义修剪场景',
|
||||||
intelligentSemanticPruningSceneDesc: '选择智能语义修剪场景(education、online_service、outbound)。',
|
intelligentSemanticPruningSceneDesc: '语义剪枝场景与本体工程场景一致',
|
||||||
intelligentSemanticPruningThreshold: '智能语义修剪阈值',
|
intelligentSemanticPruningThreshold: '智能语义修剪阈值',
|
||||||
intelligentSemanticPruningThresholdDesc: '设置智能语义修剪阈值(0-0.9)。',
|
intelligentSemanticPruningThresholdDesc: '设置智能语义修剪阈值(0-0.9)。',
|
||||||
reflectionEngine: '自我反思引擎',
|
reflectionEngine: '自我反思引擎',
|
||||||
@@ -1804,6 +1803,25 @@ export const zh = {
|
|||||||
error_desc: 'API 已配置但链接异常',
|
error_desc: 'API 已配置但链接异常',
|
||||||
|
|
||||||
testConnectionSuccess: '测试连接成功',
|
testConnectionSuccess: '测试连接成功',
|
||||||
|
refreshSuccess: '刷新成功',
|
||||||
|
refreshFailed: '刷新失败',
|
||||||
|
|
||||||
|
// Market 相关
|
||||||
|
marketSelectTitle: '选择一个 MCP 市场',
|
||||||
|
marketSelectDesc: '从左侧选择一个市场源,配置连接后即可浏览该市场的 MCP 服务',
|
||||||
|
marketRefreshSuccess: '列表已刷新',
|
||||||
|
marketActivated: '已激活',
|
||||||
|
marketInDatabase: '已入库',
|
||||||
|
marketAdd: '添加',
|
||||||
|
marketRefresh: '刷新',
|
||||||
|
marketConfig: '配置',
|
||||||
|
marketConfigConnection: '配置连接',
|
||||||
|
marketNoServices: '暂无可用的 MCP 服务',
|
||||||
|
marketNotConnected: '尚未连接此市场',
|
||||||
|
marketNoServicesDesc: '该市场暂时没有可用的服务',
|
||||||
|
marketNotConnectedDesc: '点击右上角"配置"按钮设置连接信息',
|
||||||
|
marketSearchPlaceholder: '搜索服务...',
|
||||||
|
marketVisit: '前往市场',
|
||||||
serviceEndpoint: '服务端点 URL',
|
serviceEndpoint: '服务端点 URL',
|
||||||
serviceEndpointPlaceholder: '服务端点的 URL',
|
serviceEndpointPlaceholder: '服务端点的 URL',
|
||||||
serviceEndpointExtra: 'MCP服务的完整访问地址',
|
serviceEndpointExtra: 'MCP服务的完整访问地址',
|
||||||
@@ -1957,6 +1975,19 @@ export const zh = {
|
|||||||
viewDetail: '查看详情',
|
viewDetail: '查看详情',
|
||||||
textLink: '测试连接',
|
textLink: '测试连接',
|
||||||
noResult: '处理结果将显示在这里',
|
noResult: '处理结果将显示在这里',
|
||||||
|
|
||||||
|
marketConfig: '配置 {{name}}',
|
||||||
|
marketSaveAndConnect: '保存并连接',
|
||||||
|
marketUrl: '市场地址',
|
||||||
|
marketUrlPlaceholder: '市场地址',
|
||||||
|
marketCopy: '复制',
|
||||||
|
marketApiKeyOptional: '可选',
|
||||||
|
marketApiKeyExtra: '部分市场需要 API Key 才能获取完整的服务列表',
|
||||||
|
marketApiKeyPlaceholder: '输入 API Key 以获取更多服务',
|
||||||
|
marketConnectionStatus: '连接状态',
|
||||||
|
marketConnected: '● 已连接',
|
||||||
|
marketDisconnected: '○ 未连接',
|
||||||
|
marketConnecting: '正在连接 {{name}}...',
|
||||||
serverUrlInvalid: '必须以 http:// 或 https:// 开头,且不能有前后空格',
|
serverUrlInvalid: '必须以 http:// 或 https:// 开头,且不能有前后空格',
|
||||||
requestHeaderKeyInvalid: '只支持英文、数字、连字符(-)、下划线(_),不能以连字符或下划线开头结尾',
|
requestHeaderKeyInvalid: '只支持英文、数字、连字符(-)、下划线(_),不能以连字符或下划线开头结尾',
|
||||||
},
|
},
|
||||||
@@ -2005,6 +2036,7 @@ export const zh = {
|
|||||||
self_optimization: '自我优化',
|
self_optimization: '自我优化',
|
||||||
process_evolution: '流程演化',
|
process_evolution: '流程演化',
|
||||||
unknown: '未知节点',
|
unknown: '未知节点',
|
||||||
|
notes: '便签',
|
||||||
|
|
||||||
clickToConfigure: '点击配置节点参数',
|
clickToConfigure: '点击配置节点参数',
|
||||||
nodeProperties: '节点属性',
|
nodeProperties: '节点属性',
|
||||||
@@ -2195,6 +2227,12 @@ export const zh = {
|
|||||||
unknown: {
|
unknown: {
|
||||||
replaceNodeType: '替换节点'
|
replaceNodeType: '替换节点'
|
||||||
},
|
},
|
||||||
|
notes: {
|
||||||
|
showAuth: '显示作者',
|
||||||
|
enterLink: '输入链接 URL',
|
||||||
|
placeholder: '输入注释...',
|
||||||
|
removeLink: '取消链接',
|
||||||
|
},
|
||||||
name: '键',
|
name: '键',
|
||||||
type: '类型',
|
type: '类型',
|
||||||
value: '值',
|
value: '值',
|
||||||
@@ -2618,6 +2656,7 @@ export const zh = {
|
|||||||
updated_at: '更新时间',
|
updated_at: '更新时间',
|
||||||
entityTypes: '实体类型',
|
entityTypes: '实体类型',
|
||||||
|
|
||||||
|
classSearchPlaceholder: '搜索类型',
|
||||||
addClass: '添加类型',
|
addClass: '添加类型',
|
||||||
class_name: '类型名称',
|
class_name: '类型名称',
|
||||||
class_description: '类型定义',
|
class_description: '类型定义',
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
/*
|
/*
|
||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-02 16:35:15
|
* @Date: 2026-02-02 16:35:15
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-02 16:35:15
|
* @Last Modified time: 2026-03-06 10:39:00
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* HTTP Request Utility Module
|
* HTTP Request Utility Module
|
||||||
@@ -183,7 +183,7 @@ service.interceptors.response.use(
|
|||||||
msg = msg || i18n.t('common.serverError');
|
msg = msg || i18n.t('common.serverError');
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
if (msg === 'SYSTEM_DEFAULT_SCENE_CANNOT_DELETE') {
|
if (['SYSTEM_DEFAULT_SCENE_CANNOT_DELETE', 'SYSTEM_DEFAULT_CLASS_CANNOT_DELETE', 'SYSTEM_DEFAULT_SCENE_CANNOT_UPDATE'].includes(msg)) {
|
||||||
msg = i18n.t(`common.${msg}`)
|
msg = i18n.t(`common.${msg}`)
|
||||||
} else if (!msg && Array.isArray(error.response?.data?.detail)) {
|
} else if (!msg && Array.isArray(error.response?.data?.detail)) {
|
||||||
msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';')
|
msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';')
|
||||||
@@ -356,12 +356,11 @@ export const request = {
|
|||||||
* Get parent domain for cookie setting
|
* Get parent domain for cookie setting
|
||||||
* @returns Parent domain or IP address
|
* @returns Parent domain or IP address
|
||||||
*/
|
*/
|
||||||
|
const isIp = (hostname: string) => /^\d+\.\d+\.\d+\.\d+$/.test(hostname)
|
||||||
|
|
||||||
const getParentDomain = () => {
|
const getParentDomain = () => {
|
||||||
const hostname = window.location.hostname
|
const hostname = window.location.hostname
|
||||||
// Check if it's an IP address
|
if (isIp(hostname)) return hostname
|
||||||
if (/^\d+\.\d+\.\d+\.\d+$/.test(hostname)) {
|
|
||||||
return hostname
|
|
||||||
}
|
|
||||||
const parts = hostname.split('.')
|
const parts = hostname.split('.')
|
||||||
return parts.length > 2 ? `.${parts.slice(-2).join('.')}` : hostname
|
return parts.length > 2 ? `.${parts.slice(-2).join('.')}` : hostname
|
||||||
}
|
}
|
||||||
@@ -371,7 +370,10 @@ const getParentDomain = () => {
|
|||||||
*/
|
*/
|
||||||
export const cookieUtils = {
|
export const cookieUtils = {
|
||||||
set: (name: string, value: string, domain = getParentDomain()) => {
|
set: (name: string, value: string, domain = getParentDomain()) => {
|
||||||
document.cookie = `${name}=${value}; domain=${domain}; path=/; secure; samesite=strict`
|
const ip = isIp(window.location.hostname)
|
||||||
|
const domainPart = ip ? '' : `; domain=${domain}`
|
||||||
|
const securePart = window.location.protocol === 'https:' ? '; secure' : ''
|
||||||
|
document.cookie = `${name}=${value}${domainPart}; path=/${securePart}; samesite=strict`
|
||||||
},
|
},
|
||||||
get: (name: string) => {
|
get: (name: string) => {
|
||||||
const value = `; ${document.cookie}`
|
const value = `; ${document.cookie}`
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 16:27:39
|
* @Date: 2026-02-03 16:27:39
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-04 18:51:20
|
* @Last Modified time: 2026-03-05 17:03:46
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Chat debugging component for application testing
|
* Chat debugging component for application testing
|
||||||
@@ -171,6 +171,29 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
|
|||||||
.then(() => {
|
.then(() => {
|
||||||
const message = msg
|
const message = msg
|
||||||
if (!message?.trim()) return
|
if (!message?.trim()) return
|
||||||
|
// Validate required variables before sending
|
||||||
|
let isCanSend = true
|
||||||
|
const params: Record<string, any> = {}
|
||||||
|
if (chatVariables && chatVariables.length > 0) {
|
||||||
|
const needRequired: string[] = []
|
||||||
|
chatVariables.forEach(vo => {
|
||||||
|
params[vo.name] = vo.value
|
||||||
|
|
||||||
|
if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
|
||||||
|
isCanSend = false
|
||||||
|
needRequired.push(vo.name)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if (needRequired.length) {
|
||||||
|
messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!isCanSend) {
|
||||||
|
setLoading(false)
|
||||||
|
setCompareLoading(false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
addUserMessage(message, fileList)
|
addUserMessage(message, fileList)
|
||||||
setMessage(message)
|
setMessage(message)
|
||||||
@@ -198,29 +221,6 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
|
|||||||
};
|
};
|
||||||
|
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
// Validate required variables before sending
|
|
||||||
let isCanSend = true
|
|
||||||
const params: Record<string, any> = {}
|
|
||||||
if (chatVariables && chatVariables.length > 0) {
|
|
||||||
const needRequired: string[] = []
|
|
||||||
chatVariables.forEach(vo => {
|
|
||||||
params[vo.name] = vo.value
|
|
||||||
|
|
||||||
if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
|
|
||||||
isCanSend = false
|
|
||||||
needRequired.push(vo.name)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if (needRequired.length) {
|
|
||||||
messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!isCanSend) {
|
|
||||||
setLoading(false)
|
|
||||||
setCompareLoading(false)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
runCompare(data.app_id, {
|
runCompare(data.app_id, {
|
||||||
message,
|
message,
|
||||||
files: fileList.map(file => {
|
files: fileList.map(file => {
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-28 14:08:14
|
* @Date: 2026-02-28 14:08:14
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-02 17:39:49
|
* @Last Modified time: 2026-03-06 12:05:46
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* UploadWorkflowModal Component
|
* UploadWorkflowModal Component
|
||||||
@@ -101,6 +101,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
|||||||
formData.append('platform', values.platform);
|
formData.append('platform', values.platform);
|
||||||
formData.append('file', values.file[0]);
|
formData.append('file', values.file[0]);
|
||||||
|
|
||||||
|
setLoading(true)
|
||||||
// Call import workflow API
|
// Call import workflow API
|
||||||
importWorkflow(formData)
|
importWorkflow(formData)
|
||||||
.then(res => {
|
.then(res => {
|
||||||
@@ -114,21 +115,24 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
|||||||
} else {
|
} else {
|
||||||
setCurrent(2);
|
setCurrent(2);
|
||||||
// Pre-fill form with file information
|
// Pre-fill form with file information
|
||||||
|
const fileNameSplit = values.file[0].name.split('.')
|
||||||
form.setFieldsValue({
|
form.setFieldsValue({
|
||||||
name: values.file[0].name.split('.')[0],
|
name: fileNameSplit.slice(0, fileNameSplit.length - 1).join('.'),
|
||||||
platform: values.platform,
|
platform: values.platform,
|
||||||
fileName: values.file[0].name,
|
fileName: values.file[0].name,
|
||||||
fileSize: values.file[0].size,
|
fileSize: values.file[0].size,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
});
|
})
|
||||||
|
.finally(() => setLoading(false));
|
||||||
break;
|
break;
|
||||||
case 1: // Step 2: Error/warning display
|
case 1: // Step 2: Error/warning display
|
||||||
if (firstFormData) {
|
if (firstFormData) {
|
||||||
const { file, platform } = firstFormData;
|
const { file, platform } = firstFormData;
|
||||||
|
const fileNameSplit = firstFormData.file[0].name.split('.')
|
||||||
// Pre-fill form with file information
|
// Pre-fill form with file information
|
||||||
form.setFieldsValue({
|
form.setFieldsValue({
|
||||||
name: file[0].name.split('.')[0],
|
name: fileNameSplit.slice(0, fileNameSplit.length - 1).join('.'),
|
||||||
platform: platform,
|
platform: platform,
|
||||||
fileName: file[0].name,
|
fileName: file[0].name,
|
||||||
fileSize: file[0].size,
|
fileSize: file[0].size,
|
||||||
@@ -138,6 +142,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
|||||||
break;
|
break;
|
||||||
case 2: // Step 3: Confirm information
|
case 2: // Step 3: Confirm information
|
||||||
if (data) {
|
if (data) {
|
||||||
|
setLoading(true);
|
||||||
// Complete import workflow
|
// Complete import workflow
|
||||||
completeImportWorkflow({
|
completeImportWorkflow({
|
||||||
temp_id: data.temp_id,
|
temp_id: data.temp_id,
|
||||||
@@ -148,7 +153,8 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
|||||||
const response = res as { id: string };
|
const response = res as { id: string };
|
||||||
setCurrent(3);
|
setCurrent(3);
|
||||||
setAppId(response.id);
|
setAppId(response.id);
|
||||||
});
|
})
|
||||||
|
.finally(() => setLoading(false));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
@@ -175,7 +181,9 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reset form if not going back to error/warning step
|
// Reset form if not going back to error/warning step
|
||||||
if (newStep !== 1) {
|
if (newStep === 0) {
|
||||||
|
form.setFieldsValue(firstFormData || {})
|
||||||
|
} else if (newStep !== 1) {
|
||||||
form.resetFields();
|
form.resetFields();
|
||||||
}
|
}
|
||||||
setCurrent(newStep);
|
setCurrent(newStep);
|
||||||
@@ -186,14 +194,16 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
|||||||
* @param {string} type - Navigation type ('detail' or 'list')
|
* @param {string} type - Navigation type ('detail' or 'list')
|
||||||
*/
|
*/
|
||||||
const handleJump = (type: string) => {
|
const handleJump = (type: string) => {
|
||||||
switch(type) {
|
|
||||||
case 'detail':
|
|
||||||
// Open application detail page in new tab
|
|
||||||
window.open(`/#/application/config/${appId}`, '_blank');
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
refresh();
|
|
||||||
handleClose();
|
handleClose();
|
||||||
|
refresh();
|
||||||
|
setTimeout(() => {
|
||||||
|
switch (type) {
|
||||||
|
case 'detail':
|
||||||
|
// Open application detail page in new tab
|
||||||
|
window.open(`/#/application/config/${appId}`, '_blank');
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}, 100)
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -235,7 +245,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
|||||||
</Button>
|
</Button>
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
}, [current]);
|
}, [current, loading]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<RbModal
|
<RbModal
|
||||||
@@ -350,7 +360,7 @@ const UploadWorkflowModal = forwardRef<UploadWorkflowModalRef, UploadWorkflowMod
|
|||||||
title={t('application.importSuccess')}
|
title={t('application.importSuccess')}
|
||||||
subTitle={t('application.importSuccessDesc')}
|
subTitle={t('application.importSuccessDesc')}
|
||||||
extra={[
|
extra={[
|
||||||
<Button key="back" onClick={() => handleJump('list')}>
|
<Button key="back" onClick={() => handleJump('list')}>
|
||||||
{t('application.gotoList')}
|
{t('application.gotoList')}
|
||||||
</Button>,
|
</Button>,
|
||||||
<Button
|
<Button
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-06 21:09:42
|
* @Date: 2026-02-06 21:09:42
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-05 15:09:22
|
* @Last Modified time: 2026-03-06 12:20:43
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* File Upload Component
|
* File Upload Component
|
||||||
@@ -208,6 +208,7 @@ const UploadFiles = forwardRef<UploadFilesRef, UploadFilesProps>(({
|
|||||||
newFileList.map(file => {
|
newFileList.map(file => {
|
||||||
const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document'
|
const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document'
|
||||||
file.type = type
|
file.type = type
|
||||||
|
file.thumbUrl = file.thumbUrl || URL.createObjectURL(file.originFileObj as Blob)
|
||||||
})
|
})
|
||||||
setFileList(newFileList);
|
setFileList(newFileList);
|
||||||
if (onChange) {
|
if (onChange) {
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ const CreateDataset = () => {
|
|||||||
const [form] = Form.useForm<ContentFormData>();
|
const [form] = Form.useForm<ContentFormData>();
|
||||||
const [data, setData] = useState<KnowledgeBaseDocumentData[]>([]);
|
const [data, setData] = useState<KnowledgeBaseDocumentData[]>([]);
|
||||||
const [rechunkFileIds, setRechunkFileIds] = useState<string[]>(initialFileIds);
|
const [rechunkFileIds, setRechunkFileIds] = useState<string[]>(initialFileIds);
|
||||||
|
const [textFormValid, setTextFormValid] = useState<boolean>(false);
|
||||||
|
|
||||||
const [pollingLoading, setPollingLoading] = useState<boolean>(false);
|
const [pollingLoading, setPollingLoading] = useState<boolean>(false);
|
||||||
const pollingTimerRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
const pollingTimerRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||||
@@ -624,7 +625,16 @@ const CreateDataset = () => {
|
|||||||
)}
|
)}
|
||||||
{source && source === 'text' && (
|
{source && source === 'text' && (
|
||||||
<div className='rb:flex rb:w-full rb:flex-col rb:mt-10 rb:px-40'>
|
<div className='rb:flex rb:w-full rb:flex-col rb:mt-10 rb:px-40'>
|
||||||
<Form form={form} layout="vertical">
|
<Form
|
||||||
|
form={form}
|
||||||
|
layout="vertical"
|
||||||
|
onValuesChange={() => {
|
||||||
|
// 检查表单字段是否都已填写
|
||||||
|
const values = form.getFieldsValue();
|
||||||
|
const isValid = !!(values.title?.trim() && values.content?.trim());
|
||||||
|
setTextFormValid(isValid);
|
||||||
|
}}
|
||||||
|
>
|
||||||
<Form.Item
|
<Form.Item
|
||||||
name="title"
|
name="title"
|
||||||
label={t('knowledgeBase.title')}
|
label={t('knowledgeBase.title')}
|
||||||
@@ -845,7 +855,11 @@ const CreateDataset = () => {
|
|||||||
<Button
|
<Button
|
||||||
type='primary'
|
type='primary'
|
||||||
onClick={current === 2 ? handleStartUpload : handleNext}
|
onClick={current === 2 ? handleStartUpload : handleNext}
|
||||||
disabled={pollingLoading || (current === 0 && rechunkFileIds.length === 0)}
|
disabled={
|
||||||
|
pollingLoading ||
|
||||||
|
(current === 0 && source === 'local' && rechunkFileIds.length === 0) ||
|
||||||
|
(current === 0 && source === 'text' && !textFormValid)
|
||||||
|
}
|
||||||
>
|
>
|
||||||
{current === 2 ? t('knowledgeBase.startUploading') || 'Start Upload' : t('common.next') || 'Next'}
|
{current === 2 ? t('knowledgeBase.startUploading') || 'Start Upload' : t('common.next') || 'Next'}
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@@ -672,9 +672,17 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
|
|||||||
{currentType !== 'Folder' && dynamicTypeList.map((tp) => {
|
{currentType !== 'Folder' && dynamicTypeList.map((tp) => {
|
||||||
const fieldKey = typeToFieldKey(tp);
|
const fieldKey = typeToFieldKey(tp);
|
||||||
// When tp is 'llm', merge llm and chat options
|
// When tp is 'llm', merge llm and chat options
|
||||||
const options = tp.toLowerCase() === 'llm' || tp.toLowerCase() === 'image2text'
|
let options = tp.toLowerCase() === 'llm' || tp.toLowerCase() === 'image2text'
|
||||||
? [...(modelOptionsByType['llm'] || []), ...(modelOptionsByType['chat'] || [])]
|
? [...(modelOptionsByType['llm'] || []), ...(modelOptionsByType['chat'] || [])]
|
||||||
: modelOptionsByType[tp] || [];
|
: modelOptionsByType[tp] || [];
|
||||||
|
|
||||||
|
// When tp is 'image2text', filter to only include models with 'vision' capability
|
||||||
|
if (tp.toLowerCase() === 'image2text') {
|
||||||
|
options = options.filter((opt: any) => {
|
||||||
|
const model = models?.items?.find((m: any) => m.id === opt.value);
|
||||||
|
return model?.capability?.includes('vision');
|
||||||
|
});
|
||||||
|
}
|
||||||
return (
|
return (
|
||||||
<Form.Item
|
<Form.Item
|
||||||
key={tp}
|
key={tp}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
* @Author: yujiangping
|
* @Author: yujiangping
|
||||||
* @Date: 2025-11-10 18:52:55
|
* @Date: 2025-11-10 18:52:55
|
||||||
* @LastEditors: yujiangping
|
* @LastEditors: yujiangping
|
||||||
* @LastEditTime: 2026-03-03 14:46:08
|
* @LastEditTime: 2026-03-09 16:39:07
|
||||||
*/
|
*/
|
||||||
import { forwardRef, useImperativeHandle, useState, useRef } from 'react';
|
import { forwardRef, useImperativeHandle, useState, useRef } from 'react';
|
||||||
import { Switch } from 'antd';
|
import { Switch } from 'antd';
|
||||||
@@ -58,16 +58,21 @@ const ShareModal = forwardRef<ShareModalRef,ShareModalRefProps>(({ handleShare:
|
|||||||
}
|
}
|
||||||
|
|
||||||
const handleShare = async() => {
|
const handleShare = async() => {
|
||||||
const workspaceIds = spaceList
|
setLoading(true);
|
||||||
.map(item => item.target_kb?.workspace_id)
|
try {
|
||||||
.filter(Boolean)
|
const workspaceIds = spaceList
|
||||||
.join(',');
|
.map(item => item.target_kb?.workspace_id)
|
||||||
|
.filter(Boolean)
|
||||||
console.log('Workspace IDs:', workspaceIds);
|
.join(',');
|
||||||
shareSpaceModalRef?.current?.handleOpen(kbId,knowledgeBase,workspaceIds);
|
|
||||||
|
console.log('Workspace IDs:', workspaceIds);
|
||||||
// Close modal after sharing
|
shareSpaceModalRef?.current?.handleOpen(kbId,knowledgeBase,workspaceIds);
|
||||||
handleClose();
|
|
||||||
|
// Close modal after sharing
|
||||||
|
handleClose();
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
const handleChange = (checked: boolean, item: any) => {
|
const handleChange = (checked: boolean, item: any) => {
|
||||||
// Toggle shared knowledge base status
|
// Toggle shared knowledge base status
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
* @Author: yujiangping
|
* @Author: yujiangping
|
||||||
* @Date: 2025-11-10 18:52:55
|
* @Date: 2025-11-10 18:52:55
|
||||||
* @LastEditors: yujiangping
|
* @LastEditors: yujiangping
|
||||||
* @LastEditTime: 2025-12-03 18:44:58
|
* @LastEditTime: 2026-03-09 16:34:51
|
||||||
*/
|
*/
|
||||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||||
import { Switch } from 'antd';
|
import { Switch } from 'antd';
|
||||||
@@ -50,34 +50,38 @@ const ShareModal = forwardRef<ShareModalRef,ShareModalRefProps>(({ handleShare:
|
|||||||
setSpaceList(filteredItems as SpaceItem[]);
|
setSpaceList(filteredItems as SpaceItem[]);
|
||||||
}
|
}
|
||||||
const handleShare = async() => {
|
const handleShare = async() => {
|
||||||
|
|
||||||
// Get all data with checked = true
|
// Get all data with checked = true
|
||||||
const checkedItems = spaceList.filter(item => item.is_active);
|
const checkedItems = spaceList.filter(item => item.is_active);
|
||||||
debugger
|
|
||||||
// Get currently selected item (corresponding to curIndex)
|
// Get currently selected item (corresponding to curIndex)
|
||||||
const selectedItem = curIndex !== -1 ? spaceList[curIndex] : null;
|
const selectedItem = curIndex !== -1 ? spaceList[curIndex] : null;
|
||||||
if(!selectedItem){
|
if(!selectedItem){
|
||||||
messageApi.error(t('knowledgeBase.selectSpace'));
|
messageApi.error(t('knowledgeBase.selectSpace'));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const payload = {
|
|
||||||
source_kb_id: kbId ?? '',
|
|
||||||
target_workspace_id: selectedItem?.id ?? '',
|
|
||||||
}
|
|
||||||
const respose = await shareKnowledgeBase(payload)
|
|
||||||
if(respose){
|
|
||||||
messageApi.success(t('knowledgeBase.shareSuccess'));
|
|
||||||
}else{
|
|
||||||
messageApi.error(t('knowledgeBase.shareFailed'));
|
|
||||||
}
|
|
||||||
// Call parent component's callback function with selected data
|
|
||||||
onShare?.({
|
|
||||||
checkedItems,
|
|
||||||
selectedItem
|
|
||||||
});
|
|
||||||
|
|
||||||
// Close modal after sharing
|
setLoading(true);
|
||||||
handleClose();
|
try {
|
||||||
|
const payload = {
|
||||||
|
source_kb_id: kbId ?? '',
|
||||||
|
target_workspace_id: selectedItem?.id ?? '',
|
||||||
|
}
|
||||||
|
const respose = await shareKnowledgeBase(payload)
|
||||||
|
if(respose){
|
||||||
|
messageApi.success(t('knowledgeBase.shareSuccess'));
|
||||||
|
}else{
|
||||||
|
messageApi.error(t('knowledgeBase.shareFailed'));
|
||||||
|
}
|
||||||
|
// Call parent component's callback function with selected data
|
||||||
|
onShare?.({
|
||||||
|
checkedItems,
|
||||||
|
selectedItem
|
||||||
|
});
|
||||||
|
|
||||||
|
// Close modal after sharing
|
||||||
|
handleClose();
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
const handleClick = (index: number, checked: boolean) => {
|
const handleClick = (index: number, checked: boolean) => {
|
||||||
if (!checked) return;
|
if (!checked) return;
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 17:30:06
|
* @Date: 2026-02-03 17:30:06
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-04 10:09:45
|
* @Last Modified time: 2026-03-06 13:49:00
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Memory Extraction Engine Configuration Constants
|
* Memory Extraction Engine Configuration Constants
|
||||||
@@ -140,13 +140,8 @@ export const configList: ConfigVo[] = [
|
|||||||
{
|
{
|
||||||
label: 'intelligentSemanticPruningScene',
|
label: 'intelligentSemanticPruningScene',
|
||||||
variableName: 'pruning_scene',
|
variableName: 'pruning_scene',
|
||||||
control: 'select',
|
control: 'text',
|
||||||
type: 'enum',
|
type: 'enum',
|
||||||
options: [
|
|
||||||
{ label: 'education', value: 'education' },
|
|
||||||
{ label: 'online_service', value: 'online_service' },
|
|
||||||
{ label: 'outbound', value: 'outbound' },
|
|
||||||
],
|
|
||||||
meaning: 'intelligentSemanticPruningSceneDesc',
|
meaning: 'intelligentSemanticPruningSceneDesc',
|
||||||
},
|
},
|
||||||
// Intelligent semantic pruning阈值
|
// Intelligent semantic pruning阈值
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
/*
|
/*
|
||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 17:30:02
|
* @Date: 2026-02-03 17:30:02
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-03 17:30:02
|
* @Last Modified time: 2026-03-06 13:50:05
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Memory Extraction Engine Configuration Page
|
* Memory Extraction Engine Configuration Page
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
import { type FC, useState, useEffect } from 'react'
|
import { type FC, useState, useEffect } from 'react'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { useParams } from 'react-router-dom'
|
import { useParams } from 'react-router-dom'
|
||||||
import { Row, Col, Space, Select, InputNumber, Slider, App, Form } from 'antd'
|
import { Row, Col, Space, Select, InputNumber, Slider, App, Form, Input } from 'antd'
|
||||||
import clsx from 'clsx'
|
import clsx from 'clsx'
|
||||||
|
|
||||||
import Card from './components/Card'
|
import Card from './components/Card'
|
||||||
@@ -35,15 +35,15 @@ const keys = [
|
|||||||
/**
|
/**
|
||||||
* Configuration description component
|
* Configuration description component
|
||||||
*/
|
*/
|
||||||
const ConfigDesc: FC<{ config: Variable, className?: string }> = ({config, className}) => {
|
const ConfigDesc: FC<{ config: Variable, className?: string; onlyMeaning?: boolean; }> = ({ config, className, onlyMeaning = false}) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
return (
|
return (
|
||||||
<div className={className}>
|
<div className={className}>
|
||||||
<Space size={8} className={clsx("rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 ")}>
|
{!onlyMeaning && <Space size={8} className={clsx("rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 ")}>
|
||||||
{config.variableName && <span className="rb:font-regular">{t('memoryExtractionEngine.variableName')}: {config.variableName}</span>}
|
{config.variableName && <span className="rb:font-regular">{t('memoryExtractionEngine.variableName')}: {config.variableName}</span>}
|
||||||
{config.control && <span className="rb:font-regular">{t('memoryExtractionEngine.control')}: {t(`memoryExtractionEngine.${config.control}`)}</span>}
|
{config.control && <span className="rb:font-regular">{t('memoryExtractionEngine.control')}: {t(`memoryExtractionEngine.${config.control}`)}</span>}
|
||||||
{config.type && <span className="rb:font-regular">{t('memoryExtractionEngine.type')}: {config.type}</span>}
|
{config.type && <span className="rb:font-regular">{t('memoryExtractionEngine.type')}: {config.type}</span>}
|
||||||
</Space>
|
</Space>}
|
||||||
{config.meaning && <div className={clsx("rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 ")}>{t('memoryExtractionEngine.Meaning')}: {t(`memoryExtractionEngine.${config.meaning}`)}</div>}
|
{config.meaning && <div className={clsx("rb:mt-1 rb:text-[12px] rb:text-[#5B6167] rb:font-regular rb:leading-4 ")}>{t('memoryExtractionEngine.Meaning')}: {t(`memoryExtractionEngine.${config.meaning}`)}</div>}
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
@@ -253,6 +253,21 @@ const MemoryExtractionEngine: FC = () => {
|
|||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
}
|
}
|
||||||
|
{config.control === 'text' &&
|
||||||
|
<>
|
||||||
|
<div className="rb:text-[14px] rb:font-medium rb:leading-5 rb:mt-6 rb:mb-2">
|
||||||
|
-{t(`memoryExtractionEngine.${config.label}`)}
|
||||||
|
</div>
|
||||||
|
<div className="rb:pl-2">
|
||||||
|
<Form.Item
|
||||||
|
name={config.variableName}
|
||||||
|
>
|
||||||
|
<Input placeholder={t('common.pleaseEnter')} disabled />
|
||||||
|
</Form.Item>
|
||||||
|
<ConfigDesc config={config} onlyMeaning={true} className="rb:-mt-4!" />
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
}
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 17:33:15
|
* @Date: 2026-02-03 17:33:15
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-05 16:28:58
|
* @Last Modified time: 2026-03-06 13:53:53
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Memory Management Page
|
* Memory Management Page
|
||||||
@@ -154,10 +154,10 @@ const MemoryManagement: React.FC = () => {
|
|||||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]"
|
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]"
|
||||||
onClick={() => handleEdit(item)}
|
onClick={() => handleEdit(item)}
|
||||||
></div>
|
></div>
|
||||||
<div
|
{!item.is_system_default && <div
|
||||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
||||||
onClick={() => handleDelete(item)}
|
onClick={() => handleDelete(item)}
|
||||||
></div>
|
></div>}
|
||||||
</Space>
|
</Space>
|
||||||
</div>
|
</div>
|
||||||
</RbCard>
|
</RbCard>
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 16:49:45
|
* @Date: 2026-02-03 16:49:45
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-04 11:50:47
|
* @Last Modified time: 2026-03-06 12:26:12
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Model List Detail Drawer
|
* Model List Detail Drawer
|
||||||
@@ -144,7 +144,7 @@ const ModelListDetail = forwardRef<ModelListDetailRef, ModelListDetailProps>(({
|
|||||||
{item.name[0]}
|
{item.name[0]}
|
||||||
</div>
|
</div>
|
||||||
}
|
}
|
||||||
extra={<Switch defaultChecked={item.is_active} disabled={loading} onChange={() => handleChange(item)} />}
|
extra={<Switch checked={item.is_active} disabled={loading} onChange={() => handleChange(item)} />}
|
||||||
bodyClassName="rb:relative rb:pb-[64px]! rb:h-[calc(100%-64px)]!"
|
bodyClassName="rb:relative rb:pb-[64px]! rb:h-[calc(100%-64px)]!"
|
||||||
>
|
>
|
||||||
<Tooltip title={item.description}>
|
<Tooltip title={item.description}>
|
||||||
@@ -153,7 +153,7 @@ const ModelListDetail = forwardRef<ModelListDetailRef, ModelListDetailProps>(({
|
|||||||
<div className="rb:absolute rb:bottom-4 rb:left-6 rb:right-6">
|
<div className="rb:absolute rb:bottom-4 rb:left-6 rb:right-6">
|
||||||
<Row gutter={12}>
|
<Row gutter={12}>
|
||||||
<Col span={12}>
|
<Col span={12}>
|
||||||
<Button block onClick={() => handleEdit(item)}>{t('modelNew.modelConfiguration')}</Button>
|
{!item.model_id && <Button block onClick={() => handleEdit(item)}>{t('modelNew.modelConfiguration')}</Button>}
|
||||||
</Col>
|
</Col>
|
||||||
<Col span={12}>
|
<Col span={12}>
|
||||||
<Button type="primary" ghost block onClick={() => handleKeyConfig(item)}>{t('modelNew.keyConfig')}</Button>
|
<Button type="primary" ghost block onClick={() => handleKeyConfig(item)}>{t('modelNew.keyConfig')}</Button>
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 16:50:18
|
* @Date: 2026-02-03 16:50:18
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-04 11:39:20
|
* @Last Modified time: 2026-03-06 12:26:11
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Type definitions for Model Management
|
* Type definitions for Model Management
|
||||||
@@ -121,6 +121,7 @@ export interface ModelApiKey {
|
|||||||
* Model list item data structure
|
* Model list item data structure
|
||||||
*/
|
*/
|
||||||
export interface ModelListItem {
|
export interface ModelListItem {
|
||||||
|
model_id?: string;
|
||||||
/** Model name */
|
/** Model name */
|
||||||
model_name?: string;
|
model_name?: string;
|
||||||
/** Associated model config IDs */
|
/** Associated model config IDs */
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 14:10:24
|
* @Date: 2026-02-03 14:10:24
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-09 18:02:13
|
* @Last Modified time: 2026-03-06 11:25:59
|
||||||
*/
|
*/
|
||||||
import { type FC, type ReactNode } from 'react';
|
import { type FC, type ReactNode } from 'react';
|
||||||
import { useNavigate } from 'react-router-dom';
|
import { useNavigate } from 'react-router-dom';
|
||||||
@@ -17,7 +17,7 @@ const { Header } = Layout;
|
|||||||
*/
|
*/
|
||||||
interface ConfigHeaderProps {
|
interface ConfigHeaderProps {
|
||||||
/** Page title/name */
|
/** Page title/name */
|
||||||
name?: string;
|
name?: string | ReactNode;
|
||||||
/** Subtitle content displayed below the title */
|
/** Subtitle content displayed below the title */
|
||||||
subTitle?: ReactNode | string;
|
subTitle?: ReactNode | string;
|
||||||
/** Extra content displayed on the right side */
|
/** Extra content displayed on the right side */
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 14:10:15
|
* @Date: 2026-02-03 14:10:15
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-05 16:28:53
|
* @Last Modified time: 2026-03-06 10:56:44
|
||||||
*/
|
*/
|
||||||
import { type FC, useState, useRef, type MouseEvent } from 'react';
|
import { type FC, useState, useRef, type MouseEvent } from 'react';
|
||||||
import { useNavigate } from 'react-router-dom';
|
import { useNavigate } from 'react-router-dom';
|
||||||
@@ -181,8 +181,8 @@ const Ontology: FC = () => {
|
|||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<div className="rb:mt-4 rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167] rb:flex rb:items-center rb:justify-end">
|
<div className="rb:mt-4 rb:h-5 rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167] rb:flex rb:items-center rb:justify-end">
|
||||||
<Space size={16}>
|
{!item.is_system_default && <Space size={16}>
|
||||||
<div
|
<div
|
||||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]"
|
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]"
|
||||||
onClick={(e) => handleEdit(item, e)}
|
onClick={(e) => handleEdit(item, e)}
|
||||||
@@ -191,7 +191,7 @@ const Ontology: FC = () => {
|
|||||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
||||||
onClick={(e) => handleDelete(item, e)}
|
onClick={(e) => handleDelete(item, e)}
|
||||||
></div>
|
></div>
|
||||||
</Space>
|
</Space>}
|
||||||
</div>
|
</div>
|
||||||
</RbCard>
|
</RbCard>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 14:10:20
|
* @Date: 2026-02-03 14:10:20
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-09 17:56:35
|
* @Last Modified time: 2026-03-06 11:26:49
|
||||||
*/
|
*/
|
||||||
import { type FC, useEffect, useState, useRef } from 'react'
|
import { type FC, useEffect, useState, useRef } from 'react'
|
||||||
import { useParams } from 'react-router-dom';
|
import { useParams } from 'react-router-dom';
|
||||||
@@ -17,6 +17,7 @@ import OntologyClassModal from '../components/OntologyClassModal'
|
|||||||
import SearchInput from '@/components/SearchInput';
|
import SearchInput from '@/components/SearchInput';
|
||||||
import OntologyClassExtractModal from '../components/OntologyClassExtractModal'
|
import OntologyClassExtractModal from '../components/OntologyClassExtractModal'
|
||||||
import BodyWrapper from '@/components/Empty/BodyWrapper'
|
import BodyWrapper from '@/components/Empty/BodyWrapper'
|
||||||
|
import Tag from '@/components/Tag'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Ontology detail page component
|
* Ontology detail page component
|
||||||
@@ -99,19 +100,22 @@ const Detail: FC = () => {
|
|||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<PageHeader
|
<PageHeader
|
||||||
name={data.scene_name}
|
name={<Space>
|
||||||
|
{data.scene_name}
|
||||||
|
{data.is_system_default ? <Tag color="warning">{t('common.default')}</Tag> : undefined}
|
||||||
|
</Space>}
|
||||||
subTitle={<Tooltip title={data.scene_description}><div className="rb:h-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{data.scene_description}</div></Tooltip>}
|
subTitle={<Tooltip title={data.scene_description}><div className="rb:h-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{data.scene_description}</div></Tooltip>}
|
||||||
extra={<Space>
|
extra={data.is_system_default ? undefined : (<Space>
|
||||||
<Button type="primary" ghost className="rb:h-6! rb:px-2! rb:leading-5.5!" onClick={handleAdd}>+ {t('ontology.addClass')}</Button>
|
<Button type="primary" ghost className="rb:h-6! rb:px-2! rb:leading-5.5!" onClick={handleAdd}>+ {t('ontology.addClass')}</Button>
|
||||||
<Button className="rb:h-6! rb:px-2! rb:leading-5.5!" type="primary" onClick={handleExtract}>+ {t('ontology.extract')}</Button>
|
<Button className="rb:h-6! rb:px-2! rb:leading-5.5!" type="primary" onClick={handleExtract}>+ {t('ontology.extract')}</Button>
|
||||||
</Space>}
|
</Space>)}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<div className="rb:h-[calc(100vh-64px)] rb:overflow-y-auto rb:py-3 rb:px-4">
|
<div className="rb:h-[calc(100vh-64px)] rb:overflow-y-auto rb:py-3 rb:px-4">
|
||||||
<Row gutter={16} className="rb:mb-4">
|
<Row gutter={16} className="rb:mb-4">
|
||||||
<Col span={6} offset={18}>
|
<Col span={6} offset={18}>
|
||||||
<SearchInput
|
<SearchInput
|
||||||
placeholder={t('ontology.searchPlaceholder')}
|
placeholder={t('ontology.classSearchPlaceholder')}
|
||||||
onSearch={(value) => setQuery({ class_name: value })}
|
onSearch={(value) => setQuery({ class_name: value })}
|
||||||
className="rb:w-full!"
|
className="rb:w-full!"
|
||||||
/>
|
/>
|
||||||
@@ -123,10 +127,10 @@ const Detail: FC = () => {
|
|||||||
<Col key={item.class_id} span={6}>
|
<Col key={item.class_id} span={6}>
|
||||||
<RbCard
|
<RbCard
|
||||||
title={item.class_name}
|
title={item.class_name}
|
||||||
extra={<div
|
extra={data.is_system_default ? undefined : (<div
|
||||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
||||||
onClick={() => handleDelete(item)}
|
onClick={() => handleDelete(item)}
|
||||||
></div>}
|
></div>)}
|
||||||
className="rb:bg-transparent!"
|
className="rb:bg-transparent!"
|
||||||
>
|
>
|
||||||
<Tooltip title={item.class_description}>
|
<Tooltip title={item.class_description}>
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 14:10:10
|
* @Date: 2026-02-03 14:10:10
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-05 16:18:56
|
* @Last Modified time: 2026-03-06 10:55:23
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Query parameters for ontology list pagination and filtering
|
* Query parameters for ontology list pagination and filtering
|
||||||
@@ -94,6 +94,7 @@ export interface OntologyClassData {
|
|||||||
scene_description: string;
|
scene_description: string;
|
||||||
/** Array of class items */
|
/** Array of class items */
|
||||||
items: OntologyClassItem[];
|
items: OntologyClassItem[];
|
||||||
|
is_system_default: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -1,131 +1,296 @@
|
|||||||
import React, { useState, useRef, type ReactNode } from 'react';
|
import React, { useState, useRef, useEffect, useCallback, type ReactNode } from 'react';
|
||||||
import { Input, Button, Spin, App } from 'antd';
|
import { Input, Button, App, Card, Space, Skeleton, Tag } from 'antd';
|
||||||
import { SearchOutlined, SettingOutlined, GlobalOutlined, SyncOutlined } from '@ant-design/icons';
|
import { SearchOutlined, SettingOutlined, GlobalOutlined, SyncOutlined } from '@ant-design/icons';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import InfiniteScroll from 'react-infinite-scroll-component';
|
||||||
import MarketConfigModal, { type MarketConfigModalRef } from './components/MarketConfigModal';
|
import MarketConfigModal, { type MarketConfigModalRef } from './components/MarketConfigModal';
|
||||||
|
import McpServiceModal from './components/McpServiceModal';
|
||||||
|
import type { McpServiceModalRef } from './types';
|
||||||
|
import pageEmptyIcon from '@/assets/images/empty/pageEmpty.png'
|
||||||
|
import Empty from '@/components/Empty/index'
|
||||||
|
import { getMarketTools, getMarketConfig, getMarketMCPs, getMarketMCPDetail, getMarketMCPsActivated, getTools } from '@/api/tools';
|
||||||
|
import BodyWrapper from '@/components/Empty/BodyWrapper';
|
||||||
interface MarketSource {
|
interface MarketSource {
|
||||||
id: string;
|
id: string;
|
||||||
name: string;
|
name: string;
|
||||||
category: string;
|
category: string;
|
||||||
icon: string;
|
logo_url: string;
|
||||||
url: string;
|
url: string;
|
||||||
desc: string;
|
description: string;
|
||||||
apiKey: string;
|
api_key?: string;
|
||||||
connected: boolean;
|
connected: boolean;
|
||||||
mcpCount: number;
|
mcp_count: number;
|
||||||
|
created_at?: number;
|
||||||
|
created_by?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface MarketMcp {
|
interface MarketMcp {
|
||||||
id: string;
|
id: string;
|
||||||
name: string;
|
name: string;
|
||||||
provider: string;
|
chinese_name?: string;
|
||||||
type: string;
|
description: string;
|
||||||
desc: string;
|
logo_url: string;
|
||||||
downloads?: string;
|
publisher: string;
|
||||||
stars?: string;
|
categories?: string[];
|
||||||
icon: string;
|
tags?: string[];
|
||||||
configTemplate: any;
|
view_count?: number;
|
||||||
|
activated?: boolean;
|
||||||
|
inDatabase?: boolean;
|
||||||
|
locales?: {
|
||||||
|
[lang: string]: {
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
};
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
interface MarketCategory {
|
interface MarketCategory {
|
||||||
id: string;
|
id: string;
|
||||||
name: string;
|
name: string;
|
||||||
icon: string;
|
}
|
||||||
|
|
||||||
|
interface MarketApiResponse {
|
||||||
|
items: MarketSource[];
|
||||||
}
|
}
|
||||||
|
|
||||||
const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () => {
|
const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () => {
|
||||||
const { t } = useTranslation();
|
const { t, i18n } = useTranslation();
|
||||||
const { message } = App.useApp();
|
const { message } = App.useApp();
|
||||||
|
|
||||||
|
const getLocaleField = (mcp: MarketMcp, field: 'name' | 'description') => {
|
||||||
|
const lang = i18n.language?.startsWith('zh') ? 'zh' : 'en';
|
||||||
|
return mcp.locales?.[lang]?.[field] || mcp[field] || '';
|
||||||
|
};
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
const [selectedSource, setSelectedSource] = useState<string | null>(null);
|
const [selectedSource, setSelectedSource] = useState<string | null>(null);
|
||||||
const marketConfigModalRef = useRef<MarketConfigModalRef>(null);
|
const marketConfigModalRef = useRef<MarketConfigModalRef>(null);
|
||||||
const [marketSources, setMarketSources] = useState<MarketSource[]>([
|
const mcpServiceModalRef = useRef<McpServiceModalRef>(null);
|
||||||
{ id: 'smithery', name: 'Smithery', category: 'official', icon: '🔧', url: 'https://mcp.smithery.ai', desc: '官方 MCP 服务市场,提供丰富的 MCP 服务', apiKey: '', connected: false, mcpCount: 2847 },
|
const [marketSources, setMarketSources] = useState<MarketSource[]>([]);
|
||||||
{ id: 'mcpmarket', name: 'MCP Market', category: 'official', icon: '🏪', url: 'https://mcpmarket.com', desc: '综合性 MCP 市场平台', apiKey: '', connected: false, mcpCount: 1523 },
|
const [categories, setCategories] = useState<MarketCategory[]>([]);
|
||||||
{ id: 'glama', name: 'Glama.ai MCP', category: 'official', icon: '✨', url: 'https://glama.ai/mcp', desc: 'Glama AI 提供的 MCP 服务集合', apiKey: '', connected: false, mcpCount: 892 },
|
const [mcpCache, setMcpCache] = useState<Record<string, MarketMcp[]>>({});
|
||||||
{ id: 'github-mcp', name: 'modelcontextprotocol/servers', category: 'official', icon: '🐙', url: 'https://github.com/modelcontextprotocol/servers', desc: 'GitHub 官方 MCP 服务器仓库', apiKey: '', connected: true, mcpCount: 156 },
|
const [mcpTotal, setMcpTotal] = useState(0);
|
||||||
{ id: 'aliyun-bailian', name: '阿里云百炼 MCP', category: 'china-cloud', icon: '☁️', url: 'https://bailian.console.aliyun.com/mcp', desc: '阿里云百炼平台 MCP 市场', apiKey: '', connected: false, mcpCount: 423 },
|
|
||||||
{ id: 'modelscope', name: '魔搭社区 MCP', category: 'china-cloud', icon: '🎭', url: 'https://modelscope.cn/mcp', desc: '阿里达摩院魔搭社区 MCP 市场', apiKey: '', connected: false, mcpCount: 312 },
|
|
||||||
]);
|
|
||||||
|
|
||||||
const [categories] = useState<MarketCategory[]>([
|
|
||||||
{ id: 'official', name: '官方/综合', icon: '🌐' },
|
|
||||||
{ id: 'china-cloud', name: '国内云', icon: '☁️' },
|
|
||||||
{ id: 'community', name: '社区/垂直', icon: '👥' }
|
|
||||||
]);
|
|
||||||
|
|
||||||
const [mcpCache, setMcpCache] = useState<Record<string, MarketMcp[]>>({
|
|
||||||
'github-mcp': [
|
|
||||||
{ id: 'gh-1', name: 'Fetch', provider: 'modelcontextprotocol', type: 'Hosted', desc: '使用浏览器模拟大型语言模型检索和处理网页内容', downloads: '203.7m', stars: '308.2k', icon: '🌐', configTemplate: {} },
|
|
||||||
{ id: 'gh-2', name: 'Filesystem', provider: 'modelcontextprotocol', type: 'Local', desc: '安全的文件系统操作,支持读写文件和目录管理', downloads: '156.2m', stars: '245.1k', icon: '📁', configTemplate: {} },
|
|
||||||
{ id: 'gh-3', name: 'GitHub', provider: 'modelcontextprotocol', type: 'Hosted', desc: 'GitHub API 集成,支持仓库、Issue、PR 等操作', downloads: '89.4m', stars: '178.3k', icon: '🐙', configTemplate: {} },
|
|
||||||
]
|
|
||||||
});
|
|
||||||
|
|
||||||
const [searchKeyword, setSearchKeyword] = useState('');
|
const [searchKeyword, setSearchKeyword] = useState('');
|
||||||
|
const [configIdMap, setConfigIdMap] = useState<Record<string, string>>({});
|
||||||
|
const [hasMore, setHasMore] = useState(false);
|
||||||
|
const [activatedMcps, setActivatedMcps] = useState<string[]>([]);
|
||||||
|
const [currentPage, setCurrentPage] = useState(1);
|
||||||
|
const pageSize = 20;
|
||||||
|
|
||||||
const handleSelectSource = (sourceId: string) => {
|
// 获取市场数据
|
||||||
setSelectedSource(sourceId);
|
useEffect(() => {
|
||||||
};
|
const fetchMarketData = async () => {
|
||||||
|
setLoading(true);
|
||||||
const handleRefresh = (sourceId: string) => {
|
try {
|
||||||
setLoading(true);
|
const response = await getMarketTools({}) as MarketApiResponse;
|
||||||
setTimeout(() => {
|
if (response?.items && Array.isArray(response.items)) {
|
||||||
// 模拟刷新数据
|
setMarketSources(response.items);
|
||||||
const source = marketSources.find(s => s.id === sourceId);
|
|
||||||
if (source) {
|
// 根据 category 字段分组
|
||||||
message.success(`${source.name} 列表已刷新`);
|
const categoryMap = new Map<string, MarketCategory>();
|
||||||
|
response.items.forEach(item => {
|
||||||
|
if (item.category && !categoryMap.has(item.category)) {
|
||||||
|
categoryMap.set(item.category, {
|
||||||
|
id: item.category,
|
||||||
|
name: item.category
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
setCategories(Array.from(categoryMap.values()));
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('获取市场数据失败:', error);
|
||||||
|
message.error('获取市场数据失败');
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fetchMarketData();
|
||||||
|
}, [message]);
|
||||||
|
|
||||||
|
const fetchMcpList = async (sourceId: string, page = 1, append = false) => {
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
let configId = configIdMap[sourceId];
|
||||||
|
|
||||||
|
// 如果没有缓存 configId,先获取配置
|
||||||
|
if (!configId) {
|
||||||
|
const config: any = await getMarketConfig(sourceId);
|
||||||
|
if (config?.id) {
|
||||||
|
configId = config.id;
|
||||||
|
setConfigIdMap(prev => ({ ...prev, [sourceId]: configId }));
|
||||||
|
} else {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第一次加载时获取已激活列表
|
||||||
|
let activatedIds: string[] = activatedMcps;
|
||||||
|
if (page === 1 && !append) {
|
||||||
|
const activatedRes: any = await getMarketMCPsActivated({ mcp_market_config_id: configId });
|
||||||
|
if (activatedRes && Array.isArray(activatedRes)) {
|
||||||
|
activatedIds = activatedRes.map((item: any) => item.id);
|
||||||
|
setActivatedMcps(activatedIds);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取全量工具列表,用于标记已入库的 MCP
|
||||||
|
const allTools: any = await getTools({ tool_type: 'mcp' });
|
||||||
|
const toolsList = Array.isArray(allTools) ? allTools : [];
|
||||||
|
|
||||||
|
const res: any = await getMarketMCPs({ mcp_market_config_id: configId, page, pagesize: pageSize });
|
||||||
|
if (res?.items && Array.isArray(res.items)) {
|
||||||
|
// 标记已激活和已入库的 MCP
|
||||||
|
const mcpsWithActivated = res.items.map((item: MarketMcp) => {
|
||||||
|
// 检查是否已入库:market_id = sourceId, market_config_id = configId, mcp_service_id = item.id
|
||||||
|
const isInDatabase = toolsList.some((tool: any) =>
|
||||||
|
tool.config_data?.market_id === sourceId &&
|
||||||
|
tool.config_data?.market_config_id === configId &&
|
||||||
|
tool.config_data?.mcp_service_id === item.id
|
||||||
|
);
|
||||||
|
|
||||||
|
return {
|
||||||
|
...item,
|
||||||
|
activated: activatedIds.includes(item.id),
|
||||||
|
inDatabase: isInDatabase
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
setMcpCache(prev => ({
|
||||||
|
...prev,
|
||||||
|
[sourceId]: append ? [...(prev[sourceId] || []), ...mcpsWithActivated] : mcpsWithActivated
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
if (res?.page) {
|
||||||
|
setMcpTotal(res.page.total || 0);
|
||||||
|
setHasMore(!!res.page.has_next);
|
||||||
|
setCurrentPage(res.page.page || page);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('获取 MCP 列表失败:', error);
|
||||||
|
} finally {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}, 600);
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleOpenConfig = (sourceId: string) => {
|
const loadMore = useCallback(() => {
|
||||||
|
if (!selectedSource || loading) return;
|
||||||
|
fetchMcpList(selectedSource, currentPage + 1, true);
|
||||||
|
}, [selectedSource, currentPage, loading]);
|
||||||
|
|
||||||
|
const handleSelectSource = async (sourceId: string) => {
|
||||||
|
setSelectedSource(sourceId);
|
||||||
|
setSearchKeyword('');
|
||||||
|
setCurrentPage(1);
|
||||||
|
setHasMore(false);
|
||||||
|
setMcpTotal(0);
|
||||||
|
|
||||||
|
// 如果缓存中已有数据,直接使用
|
||||||
|
if (mcpCache[sourceId]) return;
|
||||||
|
|
||||||
|
await fetchMcpList(sourceId, 1);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleRefresh = async (sourceId: string) => {
|
||||||
|
// 清除缓存,重新从第一页加载
|
||||||
|
setMcpCache(prev => {
|
||||||
|
const next = { ...prev };
|
||||||
|
delete next[sourceId];
|
||||||
|
return next;
|
||||||
|
});
|
||||||
|
setCurrentPage(1);
|
||||||
|
await fetchMcpList(sourceId, 1);
|
||||||
const source = marketSources.find(s => s.id === sourceId);
|
const source = marketSources.find(s => s.id === sourceId);
|
||||||
if (source) {
|
if (source) {
|
||||||
|
message.success(`${source.name} ${t('tool.marketRefreshSuccess')}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleOpenConfig = async (sourceId: string) => {
|
||||||
|
const source = marketSources.find(s => s.id === sourceId);
|
||||||
|
if (!source) return;
|
||||||
|
try {
|
||||||
|
const config: any = await getMarketConfig(sourceId);
|
||||||
|
marketConfigModalRef.current?.handleOpen({
|
||||||
|
...source,
|
||||||
|
connected: config?.status === 1,
|
||||||
|
token: config?.token || '',
|
||||||
|
configId: config?.id || '',
|
||||||
|
});
|
||||||
|
} catch {
|
||||||
marketConfigModalRef.current?.handleOpen(source);
|
marketConfigModalRef.current?.handleOpen(source);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleConnect = (sourceId: string, apiKey: string) => {
|
const handleOpenMcpServiceModal = async (mcp: MarketMcp) => {
|
||||||
// 更新市场源状态
|
if (!selectedSource || !configIdMap[selectedSource]) return;
|
||||||
|
try {
|
||||||
|
const detail: any = await getMarketMCPDetail({
|
||||||
|
mcp_market_config_id: configIdMap[selectedSource],
|
||||||
|
server_id: mcp.id,
|
||||||
|
});
|
||||||
|
const source = marketSources.find(s => s.id === selectedSource);
|
||||||
|
const toolItem = {
|
||||||
|
name: detail.name,
|
||||||
|
description: detail.description,
|
||||||
|
source_channel: source?.name || '',
|
||||||
|
market_id: selectedSource,
|
||||||
|
market_config_id: configIdMap[selectedSource],
|
||||||
|
mcp_service_id: mcp.id,
|
||||||
|
config_data: {
|
||||||
|
server_url: detail.servers?.[0]?.url || '',
|
||||||
|
connection_config: {
|
||||||
|
auth_type: 'none',
|
||||||
|
timeout: 30,
|
||||||
|
headers: {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
mcpServiceModalRef.current?.handleOpen(toolItem as any);
|
||||||
|
} catch (error) {
|
||||||
|
console.error('获取 MCP 服务详情失败:', error);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleConnect = async (sourceId: string, configId: string) => {
|
||||||
|
// 更新市场源状态,缓存 configId
|
||||||
setMarketSources(prev => prev.map(source => {
|
setMarketSources(prev => prev.map(source => {
|
||||||
if (source.id === sourceId) {
|
if (source.id === sourceId) {
|
||||||
return {
|
return { ...source, connected: true };
|
||||||
...source,
|
|
||||||
apiKey,
|
|
||||||
connected: true
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
return source;
|
return source;
|
||||||
}));
|
}));
|
||||||
|
setConfigIdMap(prev => ({ ...prev, [sourceId]: configId }));
|
||||||
|
|
||||||
// 模拟获取MCP列表
|
// 用 configId 获取第一页 MCP 列表
|
||||||
setTimeout(() => {
|
try {
|
||||||
const source = marketSources.find(s => s.id === sourceId);
|
const res: any = await getMarketMCPs({ mcp_market_config_id: configId, page: 1, pagesize: pageSize });
|
||||||
if (source && !mcpCache[sourceId]) {
|
if (res?.items && Array.isArray(res.items)) {
|
||||||
// 生成模拟数据
|
setMcpCache(prev => ({ ...prev, [sourceId]: res.items }));
|
||||||
const mockData: MarketMcp[] = [
|
|
||||||
{ id: `${sourceId}-1`, name: `${source.name} 服务 1`, provider: source.name, type: 'Hosted', desc: `来自 ${source.name} 的 MCP 服务`, downloads: '10.2m', stars: '23.4k', icon: '🔧', configTemplate: {} },
|
|
||||||
{ id: `${sourceId}-2`, name: `${source.name} 服务 2`, provider: source.name, type: 'Local', desc: `来自 ${source.name} 的本地 MCP 服务`, downloads: '8.5m', stars: '18.7k', icon: '⚙️', configTemplate: {} }
|
|
||||||
];
|
|
||||||
setMcpCache(prev => ({
|
|
||||||
...prev,
|
|
||||||
[sourceId]: mockData
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
message.success(`已连接 ${source?.name}`);
|
if (res?.page) {
|
||||||
}, 800);
|
setMcpTotal(res.page.total || 0);
|
||||||
|
setHasMore(!!res.page.has_next);
|
||||||
|
setCurrentPage(1);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('获取 MCP 列表失败:', error);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const renderSourceDetail = () => {
|
const renderSourceDetail = () => {
|
||||||
if (!selectedSource) {
|
if (!selectedSource) {
|
||||||
return (
|
return (
|
||||||
<div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:h-full rb:text-center">
|
<div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:h-full rb:text-center">
|
||||||
<div className="rb:text-6xl rb:mb-4">🏪</div>
|
<Empty
|
||||||
<h3 className="rb:text-lg rb:font-semibold rb:text-gray-900 rb:mb-2">选择一个 MCP 市场</h3>
|
url={pageEmptyIcon}
|
||||||
<p className="rb:text-sm rb:text-gray-600 rb:max-w-md">从左侧选择一个市场源,配置连接后即可浏览该市场的 MCP 服务</p>
|
title={t('tool.marketSelectTitle')}
|
||||||
|
subTitle={t('tool.marketSelectDesc')}
|
||||||
|
size={200}
|
||||||
|
className="rb:h-full"
|
||||||
|
/>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -134,170 +299,218 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
|
|||||||
if (!source) return null;
|
if (!source) return null;
|
||||||
|
|
||||||
const mcpList = mcpCache[selectedSource] || [];
|
const mcpList = mcpCache[selectedSource] || [];
|
||||||
const filteredList = mcpList.filter(mcp =>
|
const filteredList = mcpList.filter(mcp => {
|
||||||
mcp.name.toLowerCase().includes(searchKeyword.toLowerCase()) ||
|
const name = getLocaleField(mcp, 'name');
|
||||||
mcp.desc.toLowerCase().includes(searchKeyword.toLowerCase())
|
const desc = getLocaleField(mcp, 'description');
|
||||||
);
|
return name.toLowerCase().includes(searchKeyword.toLowerCase()) ||
|
||||||
|
desc.toLowerCase().includes(searchKeyword.toLowerCase());
|
||||||
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<div className="rb:flex rb:justify-between rb:items-start rb:pb-6 rb:border-b rb:border-gray-200 rb:mb-6">
|
<div className="rb:flex rb:justify-between rb:items-center rb:pb-0">
|
||||||
<div className="rb:flex rb:gap-4">
|
<div className="rb:flex rb:items-center rb:gap-4">
|
||||||
<div className="rb:text-5xl rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-xl rb:flex-shrink-0">
|
<div className="rb:w-10 rb:h-10 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-xl rb:flex-shrink-0 rb:overflow-hidden">
|
||||||
{source.icon}
|
{source.logo_url ? (
|
||||||
|
<img
|
||||||
|
src={source.logo_url}
|
||||||
|
alt={source.name}
|
||||||
|
className="rb:w-full rb:h-full rb:object-cover"
|
||||||
|
referrerPolicy="no-referrer"
|
||||||
|
onError={(e) => {
|
||||||
|
e.currentTarget.style.display = 'none';
|
||||||
|
const parent = e.currentTarget.parentElement;
|
||||||
|
if (parent) {
|
||||||
|
parent.innerHTML = '🏪';
|
||||||
|
parent.style.fontSize = '48px';
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<span className="rb:text-5xl">🏪</span>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div className="rb:flex-1">
|
<div className="rb:flex rb:items-center rb:flex-1">
|
||||||
<h2 className="rb:text-xl rb:font-semibold rb:text-gray-900 rb:mb-2">{source.name}</h2>
|
<h2 className="rb:text-xl rb:font-semibold rb:text-gray-900 rb:mb-2 rb:mr-2">{source.name}</h2>
|
||||||
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{source.desc}</p>
|
可用 MCP 服务 <span className="rb:text-gray-600 rb:font-normal">({mcpTotal})</span>
|
||||||
|
{/* <p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{source.description}</p> */}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="rb:flex rb:gap-3">
|
|
||||||
<Button icon={<SettingOutlined />} onClick={() => handleOpenConfig(selectedSource)}>
|
|
||||||
配置
|
|
||||||
</Button>
|
|
||||||
<Button type="primary" icon={<GlobalOutlined />} onClick={() => window.open(source.url, '_blank')}>
|
|
||||||
前往市场
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="rb:mt-6">
|
<div className="rb:flex rb:gap-3">
|
||||||
<div className="rb:flex rb:justify-between rb:items-center rb:mb-5">
|
|
||||||
<h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:m-0">
|
|
||||||
可用 MCP 服务 <span className="rb:text-gray-600 rb:font-normal">({mcpList.length})</span>
|
|
||||||
</h3>
|
|
||||||
<div className="rb:flex rb:gap-3 rb:items-center">
|
<div className="rb:flex rb:gap-3 rb:items-center">
|
||||||
{source.connected && (
|
{source.connected && (
|
||||||
<Button size="small" icon={<SyncOutlined />} onClick={() => handleRefresh(selectedSource)}>
|
<Button size="small" icon={<SyncOutlined />} onClick={() => handleRefresh(selectedSource)}>
|
||||||
刷新
|
{t('tool.marketRefresh')}
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
{mcpList.length > 0 && (
|
{mcpList.length > 0 && (
|
||||||
<Input
|
<Input
|
||||||
prefix={<SearchOutlined />}
|
prefix={<SearchOutlined />}
|
||||||
placeholder="搜索服务..."
|
placeholder={t('tool.marketSearchPlaceholder')}
|
||||||
value={searchKeyword}
|
value={searchKeyword}
|
||||||
onChange={(e) => setSearchKeyword(e.target.value)}
|
onChange={(e) => setSearchKeyword(e.target.value)}
|
||||||
style={{ width: 200 }}
|
style={{ width: 200 }}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
<Button icon={<SettingOutlined />} onClick={() => handleOpenConfig(selectedSource)}>
|
||||||
|
{t('tool.marketConfig')}
|
||||||
|
</Button>
|
||||||
|
<Button type="primary" icon={<GlobalOutlined />} onClick={() => window.open(source.url, '_blank')}>
|
||||||
|
{t('tool.marketVisit')}
|
||||||
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
{mcpList.length > 0 ? (
|
<div className="rb:mt-6">
|
||||||
<Spin spinning={loading}>
|
<BodyWrapper loading={loading} empty={mcpList.length === 0}>
|
||||||
<div className="rb:grid rb:grid-cols-1 md:rb:grid-cols-2 lg:rb:grid-cols-3 rb:gap-4">
|
<div id="mcpScrollableDiv" className="rb:overflow-y-auto rb:h-[calc(100vh-260px)]">
|
||||||
{filteredList.map(mcp => (
|
<InfiniteScroll
|
||||||
|
dataLength={filteredList.length}
|
||||||
|
next={loadMore}
|
||||||
|
hasMore={hasMore}
|
||||||
|
loader={<Skeleton active paragraph={{ rows: 2 }} className="rb:mt-4" />}
|
||||||
|
scrollableTarget="mcpScrollableDiv"
|
||||||
|
>
|
||||||
|
<div className="rb:grid rb:grid-cols-3 rb:gap-4">
|
||||||
|
{filteredList.map(mcp => (
|
||||||
<div
|
<div
|
||||||
key={mcp.id}
|
key={mcp.id}
|
||||||
className="rb:bg-white rb:border rb:border-gray-200 rb:rounded-lg rb:p-4 rb:transition-all rb:duration-200 hover:rb:shadow-lg hover:rb:border-gray-300"
|
className="rb:bg-white rb:border rb:border-gray-200 rb:rounded-lg rb:p-4 rb:pb-2 rb:transition-all rb:duration-200 hover:rb:shadow-lg hover:rb:border-gray-300"
|
||||||
>
|
>
|
||||||
<div className="rb:flex rb:justify-between rb:items-center rb:mb-3">
|
<div className="rb:flex rb:justify-between rb:items-center rb:mb-3">
|
||||||
<div className="rb:text-3xl rb:w-12 rb:h-12 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-lg">
|
<div className="rb:w-12 rb:h-12 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-lg rb:overflow-hidden">
|
||||||
{mcp.icon}
|
{mcp.logo_url ? (
|
||||||
|
<img
|
||||||
|
src={mcp.logo_url}
|
||||||
|
alt={getLocaleField(mcp, 'name')}
|
||||||
|
className="rb:w-full rb:h-full rb:object-cover"
|
||||||
|
referrerPolicy="no-referrer"
|
||||||
|
onError={(e) => {
|
||||||
|
e.currentTarget.style.display = 'none';
|
||||||
|
const parent = e.currentTarget.parentElement;
|
||||||
|
if (parent) {
|
||||||
|
parent.innerHTML = '🔧';
|
||||||
|
parent.style.fontSize = '24px';
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<span className="rb:text-3xl">🔧</span>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
<span className={`rb:px-2 rb:py-1 rb:rounded rb:text-xs rb:font-medium ${
|
{mcp.categories?.[0] && (
|
||||||
mcp.type === 'Hosted'
|
<span className="rb:px-2 rb:py-1 rb:rounded rb:text-xs rb:font-medium rb:bg-blue-50 rb:text-blue-700">
|
||||||
? 'rb:bg-blue-50 rb:text-blue-700'
|
{mcp.categories[0]}
|
||||||
: 'rb:bg-gray-100 rb:text-gray-600'
|
</span>
|
||||||
}`}>
|
)}
|
||||||
{mcp.type}
|
|
||||||
</span>
|
|
||||||
</div>
|
</div>
|
||||||
<h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-1">{mcp.name}</h3>
|
<h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-1">{getLocaleField(mcp, 'name')}</h3>
|
||||||
{mcp.provider && (
|
{mcp.publisher && (
|
||||||
<div className="rb:mb-2">
|
<div className="rb:mb-2">
|
||||||
<span className="rb:text-xs rb:text-gray-500">@ {mcp.provider}</span>
|
<span className="rb:text-xs rb:text-gray-500">{mcp.publisher.startsWith('@') ? mcp.publisher : `@${mcp.publisher}`}</span>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed rb:mb-3 rb:min-h-[42px]">{mcp.desc}</p>
|
<p className="rb:text-sm rb:text-gray-600 rb:line-clamp-2 rb:mb-3 rb:min-h-10">{getLocaleField(mcp, 'description')}</p>
|
||||||
<div className="rb:flex rb:gap-4 rb:mb-3 rb:pt-3 rb:border-t rb:border-gray-100">
|
<div className="rb:flex rb:gap-4 rb:mb-3 rb:pt-3 rb:border-t rb:border-gray-100">
|
||||||
{mcp.downloads && (
|
{mcp.view_count != null && (
|
||||||
<span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
|
<span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
|
||||||
<GlobalOutlined /> {mcp.downloads}
|
<GlobalOutlined /> {mcp.view_count.toLocaleString()}
|
||||||
</span>
|
|
||||||
)}
|
|
||||||
{mcp.stars && (
|
|
||||||
<span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
|
|
||||||
⭐ {mcp.stars}
|
|
||||||
</span>
|
</span>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div className="rb:flex rb:justify-end">
|
<div className={`rb:flex rb:items-center ${mcp.activated || mcp.inDatabase ? 'rb:justify-between' : 'rb:justify-end'}`}>
|
||||||
<Button type="primary" size="small">
|
<div className="rb:flex rb:gap-2">
|
||||||
+ 添加
|
{mcp.activated && <Tag color="success">{t('tool.marketActivated')}</Tag>}
|
||||||
|
{mcp.inDatabase && <Tag color="blue">{t('tool.marketInDatabase')}</Tag>}
|
||||||
|
</div>
|
||||||
|
<Button type="primary" size="small" onClick={() => handleOpenMcpServiceModal(mcp)}>
|
||||||
|
+ {t('tool.marketAdd')}
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
</Spin>
|
</InfiniteScroll>
|
||||||
) : (
|
|
||||||
<div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:py-16 rb:text-center">
|
|
||||||
<div className="rb:text-6xl rb:mb-4">{source.connected ? '📭' : '🔌'}</div>
|
|
||||||
<h4 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-2">
|
|
||||||
{source.connected ? '暂无可用的 MCP 服务' : '尚未连接此市场'}
|
|
||||||
</h4>
|
|
||||||
<p className="rb:text-sm rb:text-gray-600 rb:mb-4">
|
|
||||||
{source.connected ? '该市场暂时没有可用的服务' : '点击右上角"配置"按钮设置连接信息'}
|
|
||||||
</p>
|
|
||||||
{!source.connected && (
|
|
||||||
<Button type="primary" onClick={() => handleOpenConfig(selectedSource)}>
|
|
||||||
配置连接
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
)}
|
</BodyWrapper>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="rb:flex rb:gap-4 rb:h-[calc(100vh-178px)]">
|
<div className="rb:flex rb:gap-4 rb:h-[calc(100vh-138px)]">
|
||||||
{/* 左侧市场源列表 */}
|
{/* 左侧市场源列表 */}
|
||||||
<div className="rb:w-70 rb:bg-white rb:rounded-lg rb:border rb:border-gray-200 rb:overflow-y-auto rb:flex-shrink-0">
|
<div className="rb:w-80 rb:h-full rb:overflow-y-auto">
|
||||||
<div className="rb:p-4 rb:border-b rb:border-gray-200">
|
<Space size={12} direction="vertical" className="rb:w-full">
|
||||||
<span className="rb:text-base rb:font-semibold rb:text-gray-900">MCP 市场</span>
|
{categories.map(cat => (
|
||||||
</div>
|
<Card
|
||||||
{categories.map(cat => (
|
key={cat.id}
|
||||||
<div key={cat.id} className="rb:py-3 rb:border-b rb:border-gray-100 last:rb:border-b-0">
|
type="inner"
|
||||||
<div className="rb:flex rb:items-center rb:gap-2 rb:px-4 rb:py-2 rb:text-xs rb:font-medium rb:text-gray-500 rb:uppercase">
|
title={
|
||||||
<span className="rb:text-sm">{cat.icon}</span>
|
<div className="rb:flex rb:items-center rb:gap-2">
|
||||||
<span>{cat.name}</span>
|
<span>{cat.name}</span>
|
||||||
</div>
|
</div>
|
||||||
<div className="rb:px-2 rb:py-1">
|
}
|
||||||
{marketSources
|
classNames={{
|
||||||
.filter(s => s.category === cat.id)
|
body: "rb:p-[10px]!",
|
||||||
.map(source => (
|
header: "rb:bg-[#F6F8FC]!"
|
||||||
<div
|
}}
|
||||||
key={source.id}
|
>
|
||||||
className={`rb:flex rb:items-center rb:gap-2 rb:px-3 rb:py-2.5 rb:rounded-md rb:cursor-pointer rb:transition-all rb:relative ${
|
<Space size={8} direction="vertical" className="rb:w-full">
|
||||||
selectedSource === source.id
|
{marketSources
|
||||||
? 'rb:bg-blue-50 rb:text-blue-600'
|
.filter(s => s.category === cat.id)
|
||||||
: 'hover:rb:bg-gray-50'
|
.map(source => (
|
||||||
}`}
|
<div
|
||||||
onClick={() => handleSelectSource(source.id)}
|
key={source.id}
|
||||||
>
|
className={`rb:bg-white rb:rounded-lg rb:p-2 rb:border rb:cursor-pointer rb:flex rb:items-center rb:gap-2 rb:transition-all ${
|
||||||
<span className="rb:text-lg rb:flex-shrink-0">{source.icon}</span>
|
selectedSource === source.id
|
||||||
<span className="rb:flex-1 rb:text-sm rb:font-medium rb:overflow-hidden rb:text-ellipsis rb:whitespace-nowrap">
|
? 'rb:border-[#155EEF] rb:shadow-[0px_2px_4px_0px_rgba(33,35,50,0.15)]'
|
||||||
{source.name}
|
: 'rb:border-[#DFE4ED] rb:hover:border-[#155EEF] rb:hover:shadow-[0px_2px_4px_0px_rgba(33,35,50,0.15)]'
|
||||||
</span>
|
}`}
|
||||||
<span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full">
|
onClick={() => handleSelectSource(source.id)}
|
||||||
{source.mcpCount}
|
>
|
||||||
</span>
|
<div className="rb:w-5 rb:h-5 rb:flex-shrink-0 rb:flex rb:items-center rb:justify-center rb:overflow-hidden rb:rounded rb:bg-gray-100">
|
||||||
{source.connected && (
|
{source.logo_url ? (
|
||||||
<span className="rb:text-green-500 rb:text-[8px] rb:ml-1">●</span>
|
<img
|
||||||
)}
|
src={source.logo_url}
|
||||||
</div>
|
alt={source.name}
|
||||||
))}
|
className="rb:w-full rb:h-full rb:object-cover"
|
||||||
</div>
|
referrerPolicy="no-referrer"
|
||||||
</div>
|
onError={(e) => {
|
||||||
))}
|
e.currentTarget.style.display = 'none';
|
||||||
|
const parent = e.currentTarget.parentElement;
|
||||||
|
if (parent) {
|
||||||
|
parent.innerHTML = '🏪';
|
||||||
|
parent.style.fontSize = '16px';
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<span className="rb:text-base">🏪</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<span className="rb:flex-1 rb:font-medium rb:text-[12px] rb:overflow-hidden rb:text-ellipsis rb:whitespace-nowrap">
|
||||||
|
{source.name}
|
||||||
|
</span>
|
||||||
|
<span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full rb:flex-shrink-0">
|
||||||
|
{source.mcp_count}
|
||||||
|
</span>
|
||||||
|
{source.connected && (
|
||||||
|
<span className="rb:text-green-500 rb:text-[8px] rb:flex-shrink-0">●</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</Space>
|
||||||
|
</Card>
|
||||||
|
))}
|
||||||
|
</Space>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* 右侧内容区 */}
|
{/* 右侧内容区 */}
|
||||||
<div className="rb:flex-1 rb:bg-white rb:rounded-lg rb:border rb:border-gray-200 rb:overflow-hidden">
|
<div className="rb:flex-1 rb:border-l rb:border-gray-200 rb:overflow-hidden">
|
||||||
<div className="rb:h-full rb:overflow-y-auto rb:p-6">
|
<div className="rb:h-full rb:overflow-y-auto rb:p-6">
|
||||||
{renderSourceDetail()}
|
{renderSourceDetail()}
|
||||||
</div>
|
</div>
|
||||||
@@ -308,6 +521,10 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
|
|||||||
ref={marketConfigModalRef}
|
ref={marketConfigModalRef}
|
||||||
onConnect={handleConnect}
|
onConnect={handleConnect}
|
||||||
/>
|
/>
|
||||||
|
<McpServiceModal
|
||||||
|
ref={mcpServiceModalRef}
|
||||||
|
refresh={() => {}}
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -61,7 +61,6 @@ const Mcp: React.FC<{ getStatusTag: (status: string) => ReactNode }> = ({ getSta
|
|||||||
getData()
|
getData()
|
||||||
})
|
})
|
||||||
};
|
};
|
||||||
|
|
||||||
// 删除服务
|
// 删除服务
|
||||||
const handleDeleteService = (item: ToolItem) => {
|
const handleDeleteService = (item: ToolItem) => {
|
||||||
if (!item.id) {
|
if (!item.id) {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import { forwardRef, useImperativeHandle, useState } from 'react';
|
|||||||
import { Form, Input, Button, App, Space } from 'antd';
|
import { Form, Input, Button, App, Space } from 'antd';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { CopyOutlined, EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
|
import { CopyOutlined, EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
|
||||||
|
import { createMarketConfig,updateMarketConfig } from '@/api/tools';
|
||||||
import RbModal from '@/components/RbModal';
|
import RbModal from '@/components/RbModal';
|
||||||
|
|
||||||
const FormItem = Form.Item;
|
const FormItem = Form.Item;
|
||||||
@@ -9,15 +10,16 @@ const FormItem = Form.Item;
|
|||||||
interface MarketSource {
|
interface MarketSource {
|
||||||
id: string;
|
id: string;
|
||||||
name: string;
|
name: string;
|
||||||
icon: string;
|
logo_url: string;
|
||||||
url: string;
|
url: string;
|
||||||
desc: string;
|
description: string;
|
||||||
apiKey: string;
|
token?: string;
|
||||||
connected: boolean;
|
connected: boolean;
|
||||||
|
configId?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface MarketConfigModalProps {
|
interface MarketConfigModalProps {
|
||||||
onConnect: (sourceId: string, apiKey: string) => void;
|
onConnect: (sourceId: string, configId: string) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface MarketConfigModalRef {
|
export interface MarketConfigModalRef {
|
||||||
@@ -47,8 +49,7 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
|
|||||||
const handleOpen = (source: MarketSource) => {
|
const handleOpen = (source: MarketSource) => {
|
||||||
setCurrentSource(source);
|
setCurrentSource(source);
|
||||||
form.setFieldsValue({
|
form.setFieldsValue({
|
||||||
url: source.url,
|
token: source.token || '',
|
||||||
apiKey: source.apiKey,
|
|
||||||
});
|
});
|
||||||
setVisible(true);
|
setVisible(true);
|
||||||
};
|
};
|
||||||
@@ -56,18 +57,36 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
|
|||||||
const handleSave = () => {
|
const handleSave = () => {
|
||||||
form
|
form
|
||||||
.validateFields()
|
.validateFields()
|
||||||
.then((values) => {
|
.then(async (values) => {
|
||||||
if (!currentSource) return;
|
if (!currentSource) return;
|
||||||
|
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
|
try {
|
||||||
// 模拟连接延迟
|
let res: any;
|
||||||
setTimeout(() => {
|
if (currentSource.configId) {
|
||||||
onConnect(currentSource.id, values.apiKey || '');
|
// 更新配置
|
||||||
message.success(`正在连接 ${currentSource.name}...`);
|
res = await updateMarketConfig({
|
||||||
setLoading(false);
|
mcp_market_config_id: currentSource.configId,
|
||||||
|
token: values.token || '',
|
||||||
|
status: 1,
|
||||||
|
});
|
||||||
|
message.success(t('tool.marketConfigUpdated', { name: currentSource.name }));
|
||||||
|
} else {
|
||||||
|
// 创建配置
|
||||||
|
res = await createMarketConfig({
|
||||||
|
mcp_market_id: currentSource.id || '',
|
||||||
|
token: values.token || '',
|
||||||
|
status: 1,
|
||||||
|
});
|
||||||
|
message.success(t('tool.marketConnecting', { name: currentSource.name }));
|
||||||
|
}
|
||||||
|
onConnect(currentSource.id, res.id || currentSource.configId);
|
||||||
handleClose();
|
handleClose();
|
||||||
}, 500);
|
} catch (error) {
|
||||||
|
console.error('保存配置失败:', error);
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.catch((err) => {
|
.catch((err) => {
|
||||||
console.log('表单验证失败:', err);
|
console.log('表单验证失败:', err);
|
||||||
@@ -91,10 +110,10 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<RbModal
|
<RbModal
|
||||||
title={`配置 ${currentSource.name}`}
|
title={t('tool.marketConfig', { name: currentSource.name })}
|
||||||
open={visible}
|
open={visible}
|
||||||
onCancel={handleClose}
|
onCancel={handleClose}
|
||||||
okText="保存并连接"
|
okText={t('tool.marketSaveAndConnect')}
|
||||||
onOk={handleSave}
|
onOk={handleSave}
|
||||||
confirmLoading={loading}
|
confirmLoading={loading}
|
||||||
width={600}
|
width={600}
|
||||||
@@ -102,12 +121,28 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
|
|||||||
<div>
|
<div>
|
||||||
{/* 市场源信息头部 */}
|
{/* 市场源信息头部 */}
|
||||||
<div className="rb:flex rb:gap-4 rb:mb-6 rb:p-4 rb:bg-gray-50 rb:rounded-lg">
|
<div className="rb:flex rb:gap-4 rb:mb-6 rb:p-4 rb:bg-gray-50 rb:rounded-lg">
|
||||||
<div className="rb:text-4xl rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-white rb:rounded-lg rb:flex-shrink-0">
|
<div className="rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-white rb:rounded-lg rb:flex-shrink-0 rb:overflow-hidden">
|
||||||
{currentSource.icon}
|
{currentSource.logo_url ? (
|
||||||
|
<img
|
||||||
|
src={currentSource.logo_url}
|
||||||
|
alt={currentSource.name}
|
||||||
|
className="rb:w-full rb:h-full rb:object-cover"
|
||||||
|
onError={(e) => {
|
||||||
|
e.currentTarget.style.display = 'none';
|
||||||
|
const parent = e.currentTarget.parentElement;
|
||||||
|
if (parent) {
|
||||||
|
parent.innerHTML = '🏪';
|
||||||
|
parent.style.fontSize = '32px';
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<span className="rb:text-4xl">🏪</span>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div className="rb:flex-1">
|
<div className="rb:flex-1">
|
||||||
<h3 className="rb:text-base rb:font-semibold rb:mb-1 rb:text-gray-900">{currentSource.name}</h3>
|
<h3 className="rb:text-base rb:font-semibold rb:mb-1 rb:text-gray-900">{currentSource.name}</h3>
|
||||||
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{currentSource.desc}</p>
|
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{currentSource.description}</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -115,39 +150,34 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
|
|||||||
form={form}
|
form={form}
|
||||||
layout="vertical"
|
layout="vertical"
|
||||||
>
|
>
|
||||||
{/* 市场地址 */}
|
<FormItem label={t('tool.marketUrl')}>
|
||||||
<FormItem
|
|
||||||
name="url"
|
|
||||||
label="市场地址"
|
|
||||||
>
|
|
||||||
<Space.Compact style={{ width: '100%' }}>
|
<Space.Compact style={{ width: '100%' }}>
|
||||||
<Input
|
<Input
|
||||||
readOnly
|
readOnly
|
||||||
placeholder="市场地址"
|
value={currentSource.url}
|
||||||
/>
|
/>
|
||||||
<Button
|
<Button
|
||||||
icon={<CopyOutlined />}
|
icon={<CopyOutlined />}
|
||||||
onClick={handleCopyUrl}
|
onClick={handleCopyUrl}
|
||||||
>
|
>
|
||||||
复制
|
{t('tool.marketCopy')}
|
||||||
</Button>
|
</Button>
|
||||||
</Space.Compact>
|
</Space.Compact>
|
||||||
</FormItem>
|
</FormItem>
|
||||||
|
|
||||||
{/* API Key */}
|
|
||||||
<FormItem
|
<FormItem
|
||||||
name="apiKey"
|
name="token"
|
||||||
label={
|
label={
|
||||||
<span>
|
<span>
|
||||||
API Key <span className="rb:text-gray-400 rb:font-normal">(可选)</span>
|
API Key <span className="rb:text-gray-400 rb:font-normal">({t('tool.marketApiKeyOptional')})</span>
|
||||||
</span>
|
</span>
|
||||||
}
|
}
|
||||||
extra="部分市场需要 API Key 才能获取完整的服务列表"
|
extra={<span style={{ display: 'inline-block', marginTop: 8 }}>{t('tool.marketApiKeyExtra')}</span>}
|
||||||
>
|
>
|
||||||
<Space.Compact style={{ width: '100%' }}>
|
<Space.Compact style={{ width: '100%' }}>
|
||||||
<Input
|
<Input
|
||||||
type={showApiKey ? 'text' : 'password'}
|
type={showApiKey ? 'text' : 'password'}
|
||||||
placeholder="输入 API Key 以获取更多服务"
|
placeholder={t('tool.marketApiKeyPlaceholder')}
|
||||||
autoComplete="off"
|
autoComplete="off"
|
||||||
/>
|
/>
|
||||||
<Button
|
<Button
|
||||||
@@ -157,11 +187,10 @@ const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProp
|
|||||||
</Space.Compact>
|
</Space.Compact>
|
||||||
</FormItem>
|
</FormItem>
|
||||||
|
|
||||||
{/* 连接状态 */}
|
|
||||||
<div className="rb:flex rb:items-center rb:gap-2 rb:p-3 rb:bg-gray-50 rb:rounded rb:text-sm">
|
<div className="rb:flex rb:items-center rb:gap-2 rb:p-3 rb:bg-gray-50 rb:rounded rb:text-sm">
|
||||||
<span className="rb:text-gray-600">连接状态:</span>
|
<span className="rb:text-gray-600">{t('tool.marketConnectionStatus')}:</span>
|
||||||
<span className={`rb:font-medium ${currentSource.connected ? 'rb:text-green-600' : 'rb:text-gray-400'}`}>
|
<span className={`rb:font-medium ${currentSource.connected ? 'rb:text-green-600' : 'rb:text-gray-400'}`}>
|
||||||
{currentSource.connected ? '● 已连接' : '○ 未连接'}
|
{currentSource.connected ? t('tool.marketConnected') : t('tool.marketDisconnected')}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</Form>
|
</Form>
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user