Merge pull request #486 from SuanmoSuanyangTechnology/release/v0.2.6
Release/v0.2.6
This commit is contained in:
@@ -226,8 +226,8 @@ REDIS_PORT=6379
|
||||
REDIS_DB=1
|
||||
|
||||
# Celery (Using Redis as broker)
|
||||
BROKER_URL=redis://127.0.0.1:6379/0
|
||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
||||
REDIS_DB_CELERY_BROKER=1
|
||||
REDIS_DB_CELERY_BACKEND=2
|
||||
|
||||
# JWT Secret Key (Formation method: openssl rand -hex 32)
|
||||
SECRET_KEY=your-secret-key-here
|
||||
|
||||
@@ -201,8 +201,8 @@ REDIS_PORT=6379
|
||||
REDIS_DB=1
|
||||
|
||||
# Celery (使用Redis作为broker)
|
||||
BROKER_URL=redis://127.0.0.1:6379/0
|
||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
||||
REDIS_DB_CELERY_BROKER=1
|
||||
REDIS_DB_CELERY_BACKEND=2
|
||||
|
||||
# JWT密钥 (生成方式: openssl rand -hex 32)
|
||||
SECRET_KEY=your-secret-key-here
|
||||
|
||||
1
api/app/cache/__init__.py
vendored
1
api/app/cache/__init__.py
vendored
@@ -2,7 +2,6 @@
|
||||
Cache 缓存模块
|
||||
|
||||
提供各种缓存功能的统一入口
|
||||
注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存
|
||||
"""
|
||||
from .memory import InterestMemoryCache
|
||||
|
||||
|
||||
1
api/app/cache/memory/__init__.py
vendored
1
api/app/cache/memory/__init__.py
vendored
@@ -2,7 +2,6 @@
|
||||
Memory 缓存模块
|
||||
|
||||
提供记忆系统相关的缓存功能
|
||||
注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存
|
||||
"""
|
||||
from .interest_memory import InterestMemoryCache
|
||||
|
||||
|
||||
@@ -1,27 +1,54 @@
|
||||
import os
|
||||
import platform
|
||||
from datetime import timedelta
|
||||
from celery.schedules import crontab
|
||||
from urllib.parse import quote
|
||||
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
# backend: 结果存储(使用 Redis DB 10)
|
||||
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定)
|
||||
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定)
|
||||
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||
|
||||
# Build canonical broker/backend URLs and force them into os.environ so that
|
||||
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
|
||||
# cannot be overridden by stray env vars.
|
||||
# See: https://github.com/celery/celery/issues/4284
|
||||
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
|
||||
# integration and accidentally override our canonical URLs.
|
||||
os.environ.pop("BROKER_URL", None)
|
||||
os.environ.pop("RESULT_BACKEND", None)
|
||||
os.environ.pop("CELERY_BROKER", None)
|
||||
os.environ.pop("CELERY_BACKEND", None)
|
||||
|
||||
celery_app = Celery(
|
||||
"redbear_tasks",
|
||||
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
|
||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||
broker=_broker_url,
|
||||
backend=_backend_url,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Celery app initialized",
|
||||
extra={
|
||||
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||
},
|
||||
)
|
||||
# Default queue for unrouted tasks
|
||||
celery_app.conf.task_default_queue = 'memory_tasks'
|
||||
|
||||
|
||||
@@ -1,28 +1,29 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||
from app.models import ModelApiKey
|
||||
from app.models.user_model import User
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.repositories import knowledge_repository, WorkspaceRepository
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import task_service, workspace_service
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
@@ -55,7 +56,8 @@ async def get_health_status(
|
||||
|
||||
@router.get("/download_log")
|
||||
async def download_log(
|
||||
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||
log_type: str = Query("file", regex="^(file|transmission)$",
|
||||
description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
@@ -161,13 +163,15 @@ async def write_server(
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
api_logger.warning(
|
||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
api_logger.info(
|
||||
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
result = await memory_agent_service.write_memory(
|
||||
@@ -216,7 +220,8 @@ async def write_server_async(
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
api_logger.info(
|
||||
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
@@ -292,7 +297,8 @@ async def read_server(
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
api_logger.info(
|
||||
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await memory_agent_service.read_memory(
|
||||
user_input.end_user_id,
|
||||
@@ -306,7 +312,8 @@ async def read_server(
|
||||
)
|
||||
if str(user_input.search_switch) == "2":
|
||||
retrieve_info = result['answer']
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
|
||||
user_input.end_user_id)
|
||||
query = user_input.message
|
||||
|
||||
# 调用 memory_agent_service 的方法生成最终答案
|
||||
@@ -337,7 +344,8 @@ async def file_update(
|
||||
files: List[UploadFile] = File(..., description="要上传的文件"),
|
||||
model_id: str = Form(..., description="模型ID"),
|
||||
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
文件上传接口 - 支持图片识别
|
||||
@@ -350,9 +358,6 @@ async def file_update(
|
||||
Returns:
|
||||
文件处理结果
|
||||
"""
|
||||
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
db = next(db_gen)
|
||||
api_logger.info(f"File upload requested, file count: {len(files)}")
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
@@ -631,7 +636,8 @@ async def status_type(
|
||||
async def get_knowledge_type_stats_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
||||
@@ -640,14 +646,9 @@ async def get_knowledge_type_stats_api(
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
api_logger.info(
|
||||
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
try:
|
||||
from app.db import get_db
|
||||
|
||||
# 获取数据库会话
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
# 调用service层函数
|
||||
result = await memory_agent_service.get_knowledge_type_stats(
|
||||
end_user_id=end_user_id,
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -85,6 +85,7 @@ def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -99,7 +100,29 @@ def create_config(
|
||||
svc = DataConfigService(db)
|
||||
result = svc.create(payload)
|
||||
return success(data=result, msg="创建成功")
|
||||
except ValueError as e:
|
||||
err_str = str(e)
|
||||
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
|
||||
config_name = err_str.split(":", 1)[1]
|
||||
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {err_str}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||
except Exception as e:
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
|
||||
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from typing import Dict, Optional, List
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -289,7 +289,8 @@ async def extract_ontology(
|
||||
async def create_scene(
|
||||
request: SceneCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
|
||||
):
|
||||
"""创建本体场景
|
||||
|
||||
@@ -360,8 +361,18 @@ async def create_scene(
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e))
|
||||
err_str = str(e)
|
||||
if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str:
|
||||
api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
|
||||
@@ -661,7 +672,8 @@ async def get_scenes(
|
||||
async def create_class(
|
||||
request: ClassCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
|
||||
):
|
||||
"""创建本体类型
|
||||
|
||||
@@ -676,7 +688,7 @@ async def create_class(
|
||||
ApiResponse: 包含创建的类型信息
|
||||
"""
|
||||
from app.controllers.ontology_secondary_routes import create_class_handler
|
||||
return await create_class_handler(request, db, current_user)
|
||||
return await create_class_handler(request, db, current_user, x_language_type)
|
||||
|
||||
|
||||
@router.put("/class/{class_id}", response_model=ApiResponse)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import Depends, Header
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -58,7 +58,7 @@ async def scenes_handler(
|
||||
workspace_id: Optional[str] = None,
|
||||
scene_name: Optional[str] = None,
|
||||
page: Optional[int] = None,
|
||||
page_size: Optional[int] = None,
|
||||
pagesize: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -71,14 +71,14 @@ async def scenes_handler(
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
||||
page_size: 每页数量(可选,仅在全量查询时有效)
|
||||
pagesize: 每页数量(可选,仅在全量查询时有效)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if scene_name else "list"
|
||||
api_logger.info(
|
||||
f"Scene {operation} requested by user {current_user.id}, "
|
||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
|
||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -105,13 +105,13 @@ async def scenes_handler(
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
if pagesize is not None and pagesize < 1:
|
||||
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
# 如果只提供了page或pagesize中的一个,返回错误
|
||||
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
# 模糊搜索场景(支持分页)
|
||||
@@ -119,17 +119,15 @@ async def scenes_handler(
|
||||
total = len(scenes)
|
||||
|
||||
# 如果提供了分页参数,进行分页处理
|
||||
if page is not None and page_size is not None:
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
if page is not None and pagesize is not None:
|
||||
start_idx = (page - 1) * pagesize
|
||||
end_idx = start_idx + pagesize
|
||||
scenes = scenes[start_idx:end_idx]
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
@@ -141,17 +139,16 @@ async def scenes_handler(
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
classes_count=type_num,
|
||||
is_system_default=scene.is_system_default
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
if page is not None and pagesize is not None:
|
||||
hasnext = (page * pagesize) < total
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
pagesize=pagesize,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
@@ -165,28 +162,25 @@ async def scenes_handler(
|
||||
)
|
||||
else:
|
||||
# 获取所有场景(支持分页)
|
||||
# 验证分页参数
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
if pagesize is not None and pagesize < 1:
|
||||
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
# 如果只提供了page或pagesize中的一个,返回错误
|
||||
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
scenes, total = service.list_scenes(ws_uuid, page, page_size)
|
||||
scenes, total = service.list_scenes(ws_uuid, page, pagesize)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
@@ -198,17 +192,16 @@ async def scenes_handler(
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
classes_count=type_num,
|
||||
is_system_default=scene.is_system_default
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
if page is not None and pagesize is not None:
|
||||
hasnext = (page * pagesize) < total
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
pagesize=pagesize,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
@@ -238,7 +231,8 @@ async def scenes_handler(
|
||||
async def create_class_handler(
|
||||
request: ClassCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
x_language_type: Optional[str] = None
|
||||
):
|
||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
||||
|
||||
@@ -271,8 +265,11 @@ async def create_class_handler(
|
||||
]
|
||||
|
||||
if count == 1:
|
||||
# 单个创建
|
||||
# 单个创建 - 先检查重名
|
||||
class_data = classes_data[0]
|
||||
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
|
||||
if existing:
|
||||
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
|
||||
ontology_class = service.create_class(
|
||||
scene_id=request.scene_id,
|
||||
class_name=class_data["class_name"],
|
||||
@@ -330,12 +327,36 @@ async def create_class_handler(
|
||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class creation: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
err_str = str(e)
|
||||
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
|
||||
class_name = err_str.split(":", 1)[1]
|
||||
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from fastapi.responses import JSONResponse
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.warning(f"Validation error in class creation: {err_str}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
err_str = str(e)
|
||||
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
|
||||
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from fastapi.responses import JSONResponse
|
||||
lang = get_language_from_header(x_language_type)
|
||||
class_name = request.classes[0].class_name if request.classes else ""
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
||||
@@ -615,6 +636,7 @@ async def classes_handler(
|
||||
scene_id=scene_uuid,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
is_system_default=scene.is_system_default,
|
||||
items=items
|
||||
)
|
||||
|
||||
|
||||
@@ -190,8 +190,10 @@ class Settings:
|
||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||
|
||||
# Celery configuration (internal)
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
||||
# 详见 docs/celery-env-bug-report.md
|
||||
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "1"))
|
||||
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "2"))
|
||||
|
||||
# SMTP Email Configuration
|
||||
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
@@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import (
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@@ -53,6 +52,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
|
||||
try:
|
||||
# 使用优化的LLM服务
|
||||
with get_db_context() as db_session:
|
||||
structured = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
@@ -171,6 +171,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
|
||||
try:
|
||||
# 使用优化的LLM服务
|
||||
with get_db_context() as db_session:
|
||||
response_content = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -6,31 +6,26 @@ import os
|
||||
# ===== 第三方库 =====
|
||||
from langchain.agents import create_agent
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db, get_db_context
|
||||
|
||||
from app.schemas import model_schema
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
COUNTState,
|
||||
ReadState,
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
||||
create_hybrid_retrieval_tool_sync,
|
||||
create_time_retrieval_tool,
|
||||
extract_tool_message_content,
|
||||
)
|
||||
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
ReadState,
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.schemas import model_schema
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
db = next(get_db())
|
||||
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
@@ -50,6 +45,8 @@ async def rag_config(state):
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
return kb_config
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
@@ -141,7 +138,6 @@ async def clean_databases(data) -> str:
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
|
||||
|
||||
return '\n'.join(text_parts).strip()
|
||||
|
||||
except Exception as e:
|
||||
@@ -150,7 +146,6 @@ async def clean_databases(data) -> str:
|
||||
|
||||
|
||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
'''
|
||||
|
||||
模型信息
|
||||
@@ -167,6 +162,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
for data in values:
|
||||
problem_list.append(data)
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
async def process_question_nodes(idx, question):
|
||||
try:
|
||||
@@ -260,8 +256,6 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
return {'retrieve': dup_databases}
|
||||
|
||||
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取end_user_id
|
||||
import time
|
||||
@@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
with get_db_context() as db: # 使用同步数据库上下文管理器
|
||||
config_service = MemoryConfigService(db)
|
||||
return await llm_infomation(state)
|
||||
|
||||
llm_config = await get_llm_info()
|
||||
api_key_obj = llm_config.api_keys[0]
|
||||
api_key = api_key_obj.api_key
|
||||
@@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
async with SEMAPHORE: # 限制并发
|
||||
try:
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||
question)
|
||||
else:
|
||||
cleaned_query = question
|
||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||
@@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
# json.dump(dup_databases, f, indent=4)
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
return {'retrieve': dup_databases}
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
@@ -18,12 +16,11 @@ from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
|
||||
from app.db import get_db
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
db_session = next(get_db())
|
||||
|
||||
|
||||
class SummaryNodeService(LLMServiceMixin):
|
||||
"""总结节点服务类"""
|
||||
@@ -32,8 +29,11 @@ class SummaryNodeService(LLMServiceMixin):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
@@ -51,6 +51,8 @@ async def rag_config(state):
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
return kb_config
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
@@ -70,12 +72,15 @@ async def rag_knowledge(state,question):
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||
search_mode) -> str:
|
||||
"""
|
||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||
"""
|
||||
@@ -99,6 +104,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
)
|
||||
try:
|
||||
# 使用优化的LLM服务进行结构化输出
|
||||
with get_db_context() as db_session:
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
@@ -157,6 +163,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
logger.error(f"Fallback也失败: {fallback_error}")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
|
||||
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
@@ -169,6 +176,8 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||
)
|
||||
await SessionService(store).cleanup_duplicates()
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
|
||||
|
||||
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
@@ -206,6 +215,7 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
||||
|
||||
return input_summary, retrieve
|
||||
|
||||
|
||||
async def Input_Summary(state: ReadState) -> ReadState:
|
||||
start = time.time()
|
||||
storage_type = state.get("storage_type", '')
|
||||
@@ -224,7 +234,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
|
||||
try:
|
||||
if storage_type != "rag":
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
|
||||
memory_config=memory_config)
|
||||
else:
|
||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||
except Exception as e:
|
||||
@@ -253,6 +264,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
log_time('检索', duration)
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
retrieve = state.get("retrieve", '')
|
||||
history = await summary_history(state)
|
||||
@@ -328,6 +340,8 @@ async def Summary(state: ReadState)-> ReadState:
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary_fails(state: ReadState) -> ReadState:
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import os
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.models.verification_models import VerificationResult
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
@@ -10,12 +11,12 @@ from app.core.memory.agent.utils.llm_tools import (
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class VerificationNodeService(LLMServiceMixin):
|
||||
"""验证节点服务类"""
|
||||
|
||||
@@ -23,9 +24,11 @@ class VerificationNodeService(LLMServiceMixin):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
verification_service = VerificationNodeService()
|
||||
|
||||
|
||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
"""处理验证结果并生成输出格式"""
|
||||
storage_type = state.get('storage_type', '')
|
||||
@@ -58,6 +61,8 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
}
|
||||
}
|
||||
return Verify_result
|
||||
|
||||
|
||||
async def Verify(state: ReadState):
|
||||
logger.info("=== Verify 节点开始执行 ===")
|
||||
try:
|
||||
@@ -71,7 +76,8 @@ async def Verify(state: ReadState):
|
||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||
|
||||
retrieve = state.get("retrieve", {})
|
||||
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||
logger.info(
|
||||
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||
|
||||
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
||||
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
||||
@@ -100,7 +106,8 @@ async def Verify(state: ReadState):
|
||||
try:
|
||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||
import asyncio
|
||||
|
||||
with get_db_context() as db_session:
|
||||
structured = await asyncio.wait_for(
|
||||
verification_service.call_llm_structured(
|
||||
state=state,
|
||||
|
||||
@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
@@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
)
|
||||
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph():
|
||||
"""创建并返回 LangGraph 工作流"""
|
||||
@@ -62,7 +60,6 @@ async def make_read_graph():
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
@@ -76,6 +73,7 @@ async def make_read_graph():
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
@@ -97,8 +95,10 @@ async def main():
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||
"end_user_id": end_user_id
|
||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||
"memory_config": memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
summary = ''
|
||||
@@ -141,7 +141,6 @@ async def main():
|
||||
if verify_n and verify_n != [] and verify_n != {}:
|
||||
_intermediate_outputs.append(verify_n)
|
||||
|
||||
|
||||
# Summary 节点
|
||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||
if summary_n and summary_n != [] and summary_n != {}:
|
||||
@@ -165,6 +164,8 @@ async def main():
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
end = time.time()
|
||||
print(100 * 'y')
|
||||
@@ -174,4 +175,5 @@ async def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
class LLMClientPool:
|
||||
"""LLM客户端连接池"""
|
||||
|
||||
def __init__(self, max_size: int = 5):
|
||||
self.max_size = max_size
|
||||
self.pools: Dict[str, asyncio.Queue] = {}
|
||||
self.active_clients: Dict[str, int] = {}
|
||||
|
||||
async def get_client(self, llm_model_id: str):
|
||||
"""获取LLM客户端"""
|
||||
if llm_model_id not in self.pools:
|
||||
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
|
||||
self.active_clients[llm_model_id] = 0
|
||||
|
||||
pool = self.pools[llm_model_id]
|
||||
|
||||
try:
|
||||
# 尝试从池中获取客户端
|
||||
client = pool.get_nowait()
|
||||
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
|
||||
return client
|
||||
except asyncio.QueueEmpty:
|
||||
# 池为空,创建新客户端
|
||||
if self.active_clients[llm_model_id] < self.max_size:
|
||||
db_session = next(get_db())
|
||||
client = get_llm_client_fast(llm_model_id, db_session)
|
||||
self.active_clients[llm_model_id] += 1
|
||||
logger.debug(f"创建新LLM客户端: {llm_model_id}")
|
||||
return client
|
||||
else:
|
||||
# 等待可用客户端
|
||||
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
|
||||
return await pool.get()
|
||||
|
||||
async def return_client(self, llm_model_id: str, client):
|
||||
"""归还LLM客户端到池中"""
|
||||
if llm_model_id in self.pools:
|
||||
try:
|
||||
self.pools[llm_model_id].put_nowait(client)
|
||||
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
|
||||
except asyncio.QueueFull:
|
||||
# 池已满,丢弃客户端
|
||||
self.active_clients[llm_model_id] -= 1
|
||||
logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}")
|
||||
|
||||
# 全局客户端池
|
||||
llm_client_pool = LLMClientPool()
|
||||
@@ -14,7 +14,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db
|
||||
from app.db import get_db_context
|
||||
from app.models import AppRelease
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
|
||||
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]:
|
||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]:
|
||||
"""准备 Agent(公共逻辑)
|
||||
|
||||
Args:
|
||||
@@ -57,7 +57,7 @@ class AgentNode(BaseNode):
|
||||
if not agent_id:
|
||||
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
|
||||
|
||||
db = next(get_db())
|
||||
with get_db_context() as db:
|
||||
release = db.query(AppRelease).filter(
|
||||
AppRelease.id == agent_id
|
||||
).first()
|
||||
@@ -65,9 +65,9 @@ class AgentNode(BaseNode):
|
||||
if not release:
|
||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||
|
||||
draft_service = AgentRunService(db)
|
||||
|
||||
return draft_service, release, message
|
||||
|
||||
return release, message
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""非流式执行
|
||||
@@ -79,9 +79,11 @@ class AgentNode(BaseNode):
|
||||
Returns:
|
||||
状态更新字典
|
||||
"""
|
||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
||||
release, message = self._prepare_agent(variable_pool)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
||||
with get_db_context() as db:
|
||||
draft_service = AgentRunService(db)
|
||||
|
||||
# 执行 Agent(非流式)
|
||||
result = await draft_service.run(
|
||||
@@ -118,13 +120,14 @@ class AgentNode(BaseNode):
|
||||
Yields:
|
||||
流式事件字典
|
||||
"""
|
||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
||||
release, message = self._prepare_agent(variable_pool)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
|
||||
with get_db_context() as db:
|
||||
draft_service = AgentRunService(db)
|
||||
# 执行 Agent(流式)
|
||||
async for chunk in draft_service.run_stream(
|
||||
agent_config=release.config,
|
||||
|
||||
@@ -374,7 +374,7 @@ class OntologySceneRepository:
|
||||
|
||||
count = self.db.query(OntologyScene).filter(
|
||||
OntologyScene.scene_id == scene_id,
|
||||
OntologyScene.workspace_id == workspace_id
|
||||
(OntologyScene.workspace_id == workspace_id) | (OntologyScene.is_system_default == True)
|
||||
).count()
|
||||
|
||||
is_owner = count > 0
|
||||
|
||||
@@ -116,8 +116,8 @@ class ModelApiKeyBase(BaseModel):
|
||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
||||
is_omni: bool = Field(False, description="是否为Omni模型")
|
||||
capability: Optional[List[str]] = Field(None, description="模型能力列表")
|
||||
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
|
||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
priority: str = Field("1", description="优先级", max_length=10)
|
||||
|
||||
@@ -241,6 +241,7 @@ class SceneResponse(BaseModel):
|
||||
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
|
||||
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
|
||||
classes_count: int = Field(0, description="类型数量")
|
||||
is_system_default: bool = Field(False, description="是否为系统默认场景")
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
@@ -462,6 +463,7 @@ class ClassListResponse(BaseModel):
|
||||
scene_id: UUID = Field(..., description="所属场景ID")
|
||||
scene_name: str = Field(..., description="场景名称")
|
||||
scene_description: Optional[str] = Field(None, description="场景描述")
|
||||
is_system_default: bool = Field(False, description="是否为系统默认场景")
|
||||
items: List[ClassResponse] = Field(..., description="类型列表")
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput
|
||||
@@ -103,9 +104,7 @@ def create_long_term_memory_tool(
|
||||
"""
|
||||
logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}")
|
||||
try:
|
||||
from app.db import get_db
|
||||
db = next(get_db())
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
memory_content = asyncio.run(
|
||||
MemoryAgentService().read_memory(
|
||||
end_user_id=end_user_id,
|
||||
@@ -127,9 +126,6 @@ def create_long_term_memory_tool(
|
||||
logger.info(f"读取任务状态:{status}")
|
||||
if memory_content:
|
||||
memory_content = memory_content['answer']
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
logger.info(f'用户ID:Agent:{end_user_id}')
|
||||
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
|
||||
reorder_output_results,
|
||||
)
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
|
||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_agent_schema import Write_UserInput
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
@@ -69,7 +66,8 @@ class MemoryAgentService:
|
||||
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||
# 记录成功的操作
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration, details={"message_length": len(message)})
|
||||
return context
|
||||
else:
|
||||
@@ -88,8 +86,6 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(f"写入失败: {messages}")
|
||||
|
||||
|
||||
|
||||
def extract_tool_call_info(self, event: Dict) -> bool:
|
||||
"""Extract tool call information from event"""
|
||||
last_message = event["messages"][-1]
|
||||
@@ -271,7 +267,8 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
|
||||
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
@@ -300,7 +297,8 @@ class MemoryAgentService:
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||
if config_id is None and workspace_id is None:
|
||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||
raise ValueError(
|
||||
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
@@ -331,7 +329,8 @@ class MemoryAgentService:
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
@@ -375,19 +374,18 @@ class MemoryAgentService:
|
||||
contents = massages.get('write_result')
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
|
||||
contents)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
|
||||
|
||||
async def read_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
@@ -437,7 +435,8 @@ class MemoryAgentService:
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||
if config_id is None and workspace_id is None:
|
||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||
raise ValueError(
|
||||
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
@@ -454,7 +453,6 @@ class MemoryAgentService:
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
|
||||
|
||||
config_load_start = time.time()
|
||||
try:
|
||||
# Use a separate database session to avoid transaction failures
|
||||
@@ -576,7 +574,8 @@ class MemoryAgentService:
|
||||
raw_results = intermediate.get('raw_results', {})
|
||||
try:
|
||||
reranked_results = raw_results.get('reranked_results', [])
|
||||
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||
statements = [statement['statement'] for statement in
|
||||
reranked_results.get('statements', [])]
|
||||
except Exception:
|
||||
statements = []
|
||||
|
||||
@@ -602,7 +601,8 @@ class MemoryAgentService:
|
||||
)
|
||||
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
|
||||
else:
|
||||
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
||||
logger.debug(
|
||||
f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
||||
|
||||
except Exception as save_error:
|
||||
# 保存失败不应该影响主流程,只记录错误
|
||||
@@ -610,7 +610,8 @@ class MemoryAgentService:
|
||||
|
||||
# Log successful operation
|
||||
total_time = time.time() - start_time
|
||||
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||
logger.info(
|
||||
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
@@ -641,7 +642,6 @@ class MemoryAgentService:
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||||
"""
|
||||
Get standardized message list from user input.
|
||||
@@ -665,7 +665,8 @@ class MemoryAgentService:
|
||||
for idx, msg in enumerate(user_input.messages):
|
||||
if not isinstance(msg, dict):
|
||||
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
|
||||
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
||||
raise ValueError(
|
||||
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
||||
|
||||
if 'role' not in msg:
|
||||
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
|
||||
@@ -673,7 +674,8 @@ class MemoryAgentService:
|
||||
|
||||
if 'content' not in msg:
|
||||
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
|
||||
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
||||
raise ValueError(
|
||||
f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
||||
|
||||
if msg['role'] not in ['user', 'assistant']:
|
||||
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
|
||||
@@ -719,6 +721,7 @@ class MemoryAgentService:
|
||||
status = await status_typle(message, memory_config.llm_model_id)
|
||||
logger.debug(f"Message type: {status}")
|
||||
return status
|
||||
|
||||
async def generate_summary_from_retrieve(
|
||||
self,
|
||||
end_user_id: str,
|
||||
@@ -805,13 +808,12 @@ class MemoryAgentService:
|
||||
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
|
||||
return "信息不足,无法回答。"
|
||||
|
||||
|
||||
async def get_knowledge_type_stats(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: Optional[str] = None,
|
||||
only_active: bool = True,
|
||||
current_workspace_id: Optional[uuid.UUID] = None,
|
||||
db: Session = None
|
||||
current_workspace_id: Optional[uuid.UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
统计知识库类型分布,包含:
|
||||
@@ -837,11 +839,6 @@ class MemoryAgentService:
|
||||
|
||||
# 1. 统计 PostgreSQL 中的知识库类型
|
||||
try:
|
||||
if db is None:
|
||||
from app.db import get_db
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
# 初始化所有标准类型为 0
|
||||
for kb_type in KnowledgeType:
|
||||
result[kb_type.value] = 0
|
||||
@@ -889,8 +886,6 @@ class MemoryAgentService:
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
async def get_interest_distribution_by_user(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
@@ -921,7 +916,6 @@ class MemoryAgentService:
|
||||
logger.error(f"兴趣分布标签查询失败: {e}")
|
||||
raise Exception(f"兴趣分布标签查询失败: {e}")
|
||||
|
||||
|
||||
async def get_user_profile(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
@@ -1017,7 +1011,8 @@ class MemoryAgentService:
|
||||
|
||||
# 定义标签提取的结构
|
||||
class UserTags(BaseModel):
|
||||
tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")
|
||||
tags: list[str] = Field(...,
|
||||
description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
ValueError: 当终端用户不存在或应用未发布时
|
||||
"""
|
||||
import json as json_module
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
"workspace_id": str(app.workspace_id)
|
||||
}
|
||||
|
||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
|
||||
logger.info(
|
||||
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -1,45 +1,42 @@
|
||||
# 修改 memory_konwledges_server.py 文件
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas import file_schema, document_schema
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
||||
from app.db import get_db_context
|
||||
from app.models.document_model import Document
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.user_model import User
|
||||
from app.schemas import file_schema, document_schema
|
||||
from app.schemas.file_schema import CustomTextFileCreate
|
||||
from app.services import document_service, file_service, knowledge_service
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.schemas.file_schema import CustomTextFileCreate
|
||||
from app.db import get_db
|
||||
|
||||
# 创建一个简单的用户类用于测试
|
||||
api_logger = get_api_logger()
|
||||
|
||||
|
||||
class ChunkCreate(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class SimpleUser:
|
||||
def __init__(self, user_id: str):
|
||||
# 确保ID是UUID类型
|
||||
self.id = user_id
|
||||
self.username = user_id
|
||||
|
||||
'''解析'''
|
||||
|
||||
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
|
||||
"""
|
||||
解析指定文档
|
||||
@@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user
|
||||
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
'''获取块ID'''
|
||||
|
||||
async def get_document_chunks(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
@@ -198,7 +195,7 @@ async def get_document_chunks(
|
||||
|
||||
return success(data=result, msg="文档块列表查询成功")
|
||||
|
||||
'''查找文档ID'''
|
||||
|
||||
def find_document_id_by_kb_and_filename(
|
||||
db: Session,
|
||||
kb_id: str,
|
||||
@@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename(
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
'''获取知识库ID'''
|
||||
|
||||
def find_documents_by_kb_id(
|
||||
db: Session,
|
||||
kb_id: str,
|
||||
@@ -268,18 +265,14 @@ def find_documents_by_kb_id(
|
||||
except Exception as e:
|
||||
return []
|
||||
|
||||
''''上传文件'''
|
||||
|
||||
async def memory_konwledges_up(
|
||||
kb_id: str,
|
||||
parent_id: str,
|
||||
create_data: file_schema.CustomTextFileCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: SimpleUser = None, # 修改为SimpleUser
|
||||
db: Session,
|
||||
current_user: SimpleUser,
|
||||
):
|
||||
# 如果没有提供current_user,则创建一个默认的
|
||||
if current_user is None:
|
||||
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
|
||||
|
||||
content_bytes = create_data.content.encode('utf-8')
|
||||
file_size = len(content_bytes)
|
||||
print(f"file size: {file_size} byte")
|
||||
@@ -350,8 +343,6 @@ async def memory_konwledges_up(
|
||||
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
||||
|
||||
'''添加新块'''
|
||||
|
||||
|
||||
async def create_document_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
@@ -450,6 +441,7 @@ async def create_document_chunk(
|
||||
|
||||
return success(data=chunk, msg="文档块创建成功")
|
||||
|
||||
|
||||
async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||
"""
|
||||
将消息写入 RAG 知识库
|
||||
@@ -483,10 +475,7 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||
detail=f"知识库ID格式无效: {user_rag_memory_id}"
|
||||
)
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
create_data = CustomTextFileCreate(title=end_user_id, content=message)
|
||||
current_user = SimpleUser(user_rag_memory_id)
|
||||
# 检查文档是否已存在
|
||||
@@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||
else:
|
||||
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
|
||||
return result
|
||||
finally:
|
||||
# 确保数据库会话被关闭
|
||||
db.close()
|
||||
@@ -115,6 +115,17 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
|
||||
# --- Create ---
|
||||
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
||||
# 业务层检查同一工作空间下是否已存在同名配置
|
||||
if params.workspace_id and params.config_name:
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
existing = (
|
||||
self.db.query(MemoryConfig)
|
||||
.filter_by(workspace_id=params.workspace_id, config_name=params.config_name)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"DUPLICATE_CONFIG_NAME:{params.config_name}")
|
||||
|
||||
# 如果workspace_id存在且模型字段未全部指定,则自动获取
|
||||
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
|
||||
configs = self._get_workspace_configs(params.workspace_id)
|
||||
@@ -211,6 +222,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"apply_id": config.apply_id,
|
||||
"scene_id": str(config.scene_id) if config.scene_id else None,
|
||||
"scene_name": scene_name, # 新增:场景名称
|
||||
"is_system_default": config.is_default, # 是否为系统默认配置
|
||||
"llm_id": config.llm_id,
|
||||
"embedding_id": config.embedding_id,
|
||||
"rerank_id": config.rerank_id,
|
||||
|
||||
@@ -116,24 +116,12 @@ class ModelConfigService:
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# dashscope 的 omni 模型需要使用 compatible-mode
|
||||
if provider.lower() == ModelProvider.DASHSCOPE and is_omni:
|
||||
if not api_base:
|
||||
api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=ModelProvider.OPENAI,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
else:
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
@@ -493,6 +481,9 @@ class ModelApiKeyService:
|
||||
if not model_config:
|
||||
continue
|
||||
|
||||
data.is_omni = model_config.is_omni
|
||||
data.capability = model_config.capability
|
||||
|
||||
# 从ModelBase获取model_name
|
||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||
|
||||
@@ -550,8 +541,8 @@ class ModelApiKeyService:
|
||||
provider=data.provider,
|
||||
api_key=data.api_key,
|
||||
api_base=data.api_base,
|
||||
capability=data.capability if data.capability is not None else model_config.capability,
|
||||
is_omni=data.is_omni if data.is_omni is not None else model_config.is_omni,
|
||||
capability=data.capability,
|
||||
is_omni=data.is_omni,
|
||||
config=data.config,
|
||||
is_active=data.is_active,
|
||||
priority=data.priority
|
||||
@@ -574,6 +565,10 @@ class ModelApiKeyService:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
if api_key_data.is_omni is None:
|
||||
api_key_data.is_omni = model_config.is_omni
|
||||
if api_key_data.capability is None:
|
||||
api_key_data.capability = model_config.capability
|
||||
|
||||
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
@@ -616,7 +611,7 @@ class ModelApiKeyService:
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello",
|
||||
is_omni=model_config.is_omni
|
||||
is_omni=api_key_data.is_omni
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
|
||||
@@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.cypher_queries import Graph_Node_query
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from app.services.memory_base_service import MemoryBaseService, MemoryTransService
|
||||
from app.services.memory_base_service import MemoryBaseService
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
from app.services.memory_short_service import ShortService
|
||||
@@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
||||
|
||||
from app.core.language_utils import validate_language
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
||||
from app.db import get_db
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
# 验证语言参数
|
||||
@@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
||||
if end_user_id:
|
||||
try:
|
||||
# 获取数据库会话并查询用户信息
|
||||
db = next(get_db())
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||
if end_user and end_user.other_name:
|
||||
@@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
||||
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
|
||||
else:
|
||||
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")
|
||||
|
||||
|
||||
@@ -107,6 +107,7 @@ def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
|
||||
for workspace in workspaces:
|
||||
if workspace.storage_type == 'neo4j':
|
||||
_ensure_default_memory_config(db, workspace)
|
||||
_ensure_default_ontology_scenes(db, workspace)
|
||||
|
||||
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
|
||||
return workspaces
|
||||
@@ -1104,6 +1105,52 @@ def _fill_workspace_configs_model_defaults(
|
||||
)
|
||||
|
||||
|
||||
def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
|
||||
"""Ensure a workspace has default ontology scenes, creating them if missing.
|
||||
|
||||
Checks whether any is_system_default scene exists for the workspace.
|
||||
If not, runs the DefaultOntologyInitializer to create them.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
workspace: The workspace to check
|
||||
"""
|
||||
from app.models.ontology_scene import OntologyScene
|
||||
|
||||
# 幂等检查:是否已存在系统默认场景
|
||||
existing = db.query(OntologyScene).filter(
|
||||
OntologyScene.workspace_id == workspace.id,
|
||||
OntologyScene.is_system_default.is_(True)
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
return
|
||||
|
||||
business_logger.info(
|
||||
f"Workspace {workspace.id} missing default ontology scenes, creating them"
|
||||
)
|
||||
|
||||
try:
|
||||
initializer = DefaultOntologyInitializer(db)
|
||||
success, error_msg = initializer.initialize_default_scenes(
|
||||
workspace.id, language="zh"
|
||||
)
|
||||
if success:
|
||||
db.commit()
|
||||
business_logger.info(
|
||||
f"为工作空间 {workspace.id} 补建默认本体场景成功"
|
||||
)
|
||||
else:
|
||||
business_logger.warning(
|
||||
f"为工作空间 {workspace.id} 补建默认本体场景失败: {error_msg}"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(
|
||||
f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def _create_default_memory_config(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
|
||||
@@ -29,10 +29,10 @@ REDIS_DB=
|
||||
REDIS_PASSWORD=password
|
||||
|
||||
#celery
|
||||
BROKER_URL=
|
||||
RESULT_BACKEND=
|
||||
CELERY_BROKER=
|
||||
CELERY_BACKEND=
|
||||
# NOTE: 不要使用 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||
# 这些名称会被 Celery CLI 劫持,详见 docs/celery-env-bug-report.md
|
||||
REDIS_DB_CELERY_BROKER=
|
||||
REDIS_DB_CELERY_BACKEND=
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
# Interval in hours for regenerating memory insight and user summary cache
|
||||
|
||||
@@ -440,7 +440,6 @@ export const en = {
|
||||
logoutApiCannotRefreshToken: 'Logout API cannot refresh token',
|
||||
publicApiCannotRefreshToken: 'Public API cannot refresh token',
|
||||
refreshTokenNotExist: 'Refresh token does not exist',
|
||||
SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: 'This is a system preset scene and cannot be deleted',
|
||||
reset: 'Reset',
|
||||
refresh: 'Refresh',
|
||||
return: 'Return',
|
||||
@@ -454,6 +453,7 @@ export const en = {
|
||||
prevStep: 'Previous Step',
|
||||
exportSuccess: 'Export successful',
|
||||
recommend: 'Recommend',
|
||||
default: 'Default',
|
||||
logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`,
|
||||
imageSquareRequired: 'Please upload a square image',
|
||||
nameInvalid: 'Name cannot start or end with a space',
|
||||
@@ -2616,6 +2616,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
updated_at: 'Updated At',
|
||||
entityTypes: 'Entity Types',
|
||||
|
||||
classSearchPlaceholder: 'Search types',
|
||||
addClass: 'Add Type',
|
||||
class_name: 'Type Name',
|
||||
class_description: 'Type Definition',
|
||||
|
||||
@@ -1020,7 +1020,6 @@ export const zh = {
|
||||
logoutApiCannotRefreshToken: '退出登录接口不能刷新token',
|
||||
publicApiCannotRefreshToken: '公共接口不能刷新token',
|
||||
refreshTokenNotExist: '刷新token不存在',
|
||||
SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: '该场景为系统预设场景,不允许删除',
|
||||
reset: '重置',
|
||||
refresh: '刷新',
|
||||
return: '返回',
|
||||
@@ -1034,6 +1033,7 @@ export const zh = {
|
||||
prevStep: '上一步',
|
||||
exportSuccess: '导出成功',
|
||||
recommend: '推荐',
|
||||
default: '默认',
|
||||
logoTip: `支持图片格式(JPG、PNG)\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`,
|
||||
imageSquareRequired: '请上传正方形比例图片',
|
||||
nameInvalid: '不能是空格开头或结尾',
|
||||
@@ -2617,6 +2617,7 @@ export const zh = {
|
||||
updated_at: '更新时间',
|
||||
entityTypes: '实体类型',
|
||||
|
||||
classSearchPlaceholder: '搜索类型',
|
||||
addClass: '添加类型',
|
||||
class_name: '类型名称',
|
||||
class_description: '类型定义',
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-02 16:35:15
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-02 16:35:15
|
||||
* @Last Modified time: 2026-03-06 10:39:00
|
||||
*/
|
||||
/**
|
||||
* HTTP Request Utility Module
|
||||
@@ -183,7 +183,7 @@ service.interceptors.response.use(
|
||||
msg = msg || i18n.t('common.serverError');
|
||||
break;
|
||||
default:
|
||||
if (msg === 'SYSTEM_DEFAULT_SCENE_CANNOT_DELETE') {
|
||||
if (['SYSTEM_DEFAULT_SCENE_CANNOT_DELETE', 'SYSTEM_DEFAULT_CLASS_CANNOT_DELETE', 'SYSTEM_DEFAULT_SCENE_CANNOT_UPDATE'].includes(msg)) {
|
||||
msg = i18n.t(`common.${msg}`)
|
||||
} else if (!msg && Array.isArray(error.response?.data?.detail)) {
|
||||
msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';')
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:27:39
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-04 18:51:20
|
||||
* @Last Modified time: 2026-03-05 17:03:46
|
||||
*/
|
||||
/**
|
||||
* Chat debugging component for application testing
|
||||
@@ -171,6 +171,29 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
|
||||
.then(() => {
|
||||
const message = msg
|
||||
if (!message?.trim()) return
|
||||
// Validate required variables before sending
|
||||
let isCanSend = true
|
||||
const params: Record<string, any> = {}
|
||||
if (chatVariables && chatVariables.length > 0) {
|
||||
const needRequired: string[] = []
|
||||
chatVariables.forEach(vo => {
|
||||
params[vo.name] = vo.value
|
||||
|
||||
if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
|
||||
isCanSend = false
|
||||
needRequired.push(vo.name)
|
||||
}
|
||||
})
|
||||
|
||||
if (needRequired.length) {
|
||||
messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
|
||||
}
|
||||
}
|
||||
if (!isCanSend) {
|
||||
setLoading(false)
|
||||
setCompareLoading(false)
|
||||
return
|
||||
}
|
||||
|
||||
addUserMessage(message, fileList)
|
||||
setMessage(message)
|
||||
@@ -198,29 +221,6 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
|
||||
};
|
||||
|
||||
setTimeout(() => {
|
||||
// Validate required variables before sending
|
||||
let isCanSend = true
|
||||
const params: Record<string, any> = {}
|
||||
if (chatVariables && chatVariables.length > 0) {
|
||||
const needRequired: string[] = []
|
||||
chatVariables.forEach(vo => {
|
||||
params[vo.name] = vo.value
|
||||
|
||||
if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
|
||||
isCanSend = false
|
||||
needRequired.push(vo.name)
|
||||
}
|
||||
})
|
||||
|
||||
if (needRequired.length) {
|
||||
messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
|
||||
}
|
||||
}
|
||||
if (!isCanSend) {
|
||||
setLoading(false)
|
||||
setCompareLoading(false)
|
||||
return
|
||||
}
|
||||
runCompare(data.app_id, {
|
||||
message,
|
||||
files: fileList.map(file => {
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-06 21:09:42
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-04 18:54:47
|
||||
* @Last Modified time: 2026-03-05 15:09:22
|
||||
*/
|
||||
/**
|
||||
* File Upload Component
|
||||
@@ -206,7 +206,7 @@ const UploadFiles = forwardRef<UploadFilesRef, UploadFilesProps>(({
|
||||
*/
|
||||
const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => {
|
||||
newFileList.map(file => {
|
||||
const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type
|
||||
const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document'
|
||||
file.type = type
|
||||
})
|
||||
setFileList(newFileList);
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
} from '@/api/knowledgeBase'
|
||||
import RbModal from '@/components/RbModal'
|
||||
import SliderInput from '@/components/SliderInput'
|
||||
import { stringRegExp } from '@/utils/validator'
|
||||
const { TextArea } = Input;
|
||||
const { confirm } = Modal
|
||||
|
||||
@@ -519,12 +520,16 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
|
||||
<Form.Item
|
||||
name="name"
|
||||
label={t('knowledgeBase.createForm.name')}
|
||||
rules={[{ required: true, message: t('knowledgeBase.createForm.nameRequired') }]}
|
||||
rules={[
|
||||
{ required: true, message: t('knowledgeBase.createForm.nameRequired') },
|
||||
{ max: 50 },
|
||||
{ pattern: stringRegExp, message: t('common.nameInvalid') },
|
||||
]}
|
||||
>
|
||||
<Input placeholder={t('knowledgeBase.createForm.name')} />
|
||||
</Form.Item>
|
||||
)}
|
||||
<Form.Item name="description" label={t('knowledgeBase.createForm.description')}>
|
||||
<Form.Item name="description" label={t('knowledgeBase.createForm.description')} rules={[{ max: 500 }]}>
|
||||
<TextArea rows={2} placeholder={t('knowledgeBase.createForm.description')} />
|
||||
</Form.Item>
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 17:33:15
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 17:33:15
|
||||
* @Last Modified time: 2026-03-05 16:28:58
|
||||
*/
|
||||
/**
|
||||
* Memory Management Page
|
||||
@@ -110,9 +110,15 @@ const MemoryManagement: React.FC = () => {
|
||||
<List.Item key={item.config_id}>
|
||||
<RbCard
|
||||
title={item.config_name}
|
||||
className="rb:relative"
|
||||
>
|
||||
{item.is_system_default &&
|
||||
<div className="rb:absolute rb:-right-px rb:-top-px rb:bg-[#FF5D34] rb:rounded-[0px_7px_0px_8px] rb:text-[12px] rb:text-white rb:font-regular rb:leading-4 rb:py-0.5 rb:px-1">
|
||||
{t('common.default')}
|
||||
</div>
|
||||
}
|
||||
<Tooltip title={item.config_desc}>
|
||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-[17px]">{item.config_desc}</div>
|
||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-4.25">{item.config_desc}</div>
|
||||
</Tooltip>
|
||||
<RbAlert className="rb:mt-3 ">
|
||||
<div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 17:33:01
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 17:33:24
|
||||
* @Last Modified time: 2026-03-05 16:33:53
|
||||
*/
|
||||
/**
|
||||
* Memory management form data type
|
||||
@@ -42,6 +42,7 @@ export interface Memory {
|
||||
workspace_id: string;
|
||||
scene_id: string;
|
||||
scene_name: string;
|
||||
is_system_default: boolean;
|
||||
[key: string]: string | number | boolean;
|
||||
}
|
||||
/**
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 14:10:24
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-09 18:02:13
|
||||
* @Last Modified time: 2026-03-06 11:25:59
|
||||
*/
|
||||
import { type FC, type ReactNode } from 'react';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
@@ -17,7 +17,7 @@ const { Header } = Layout;
|
||||
*/
|
||||
interface ConfigHeaderProps {
|
||||
/** Page title/name */
|
||||
name?: string;
|
||||
name?: string | ReactNode;
|
||||
/** Subtitle content displayed below the title */
|
||||
subTitle?: ReactNode | string;
|
||||
/** Extra content displayed on the right side */
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 14:10:15
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-05 10:57:53
|
||||
* @Last Modified time: 2026-03-06 10:56:44
|
||||
*/
|
||||
import { type FC, useState, useRef, type MouseEvent } from 'react';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
@@ -144,8 +144,13 @@ const Ontology: FC = () => {
|
||||
title={item.scene_name}
|
||||
extra={<Tag>{item.type_num} {t('ontology.typeCount')}</Tag>}
|
||||
onClick={() => handleJump(item)}
|
||||
className="rb:cursor-pointer"
|
||||
className="rb:cursor-pointer rb:relative"
|
||||
>
|
||||
{item.is_system_default &&
|
||||
<div className="rb:absolute rb:-right-px rb:-top-px rb:bg-[#FF5D34] rb:rounded-[0px_7px_0px_8px] rb:text-[12px] rb:text-white rb:font-regular rb:leading-4 rb:py-0.5 rb:px-1">
|
||||
{t('common.default')}
|
||||
</div>
|
||||
}
|
||||
<div
|
||||
className="rb:flex rb:gap-2 rb:justify-between rb:text-[#5B6167] rb:text-[14px] rb:leading-5 rb:mb-3"
|
||||
>
|
||||
@@ -176,8 +181,8 @@ const Ontology: FC = () => {
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
<div className="rb:mt-4 rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167] rb:flex rb:items-center rb:justify-end">
|
||||
<Space size={16}>
|
||||
<div className="rb:mt-4 rb:h-5 rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167] rb:flex rb:items-center rb:justify-end">
|
||||
{!item.is_system_default && <Space size={16}>
|
||||
<div
|
||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]"
|
||||
onClick={(e) => handleEdit(item, e)}
|
||||
@@ -186,7 +191,7 @@ const Ontology: FC = () => {
|
||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
||||
onClick={(e) => handleDelete(item, e)}
|
||||
></div>
|
||||
</Space>
|
||||
</Space>}
|
||||
</div>
|
||||
</RbCard>
|
||||
)}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 14:10:20
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-09 17:56:35
|
||||
* @Last Modified time: 2026-03-06 11:26:49
|
||||
*/
|
||||
import { type FC, useEffect, useState, useRef } from 'react'
|
||||
import { useParams } from 'react-router-dom';
|
||||
@@ -17,6 +17,7 @@ import OntologyClassModal from '../components/OntologyClassModal'
|
||||
import SearchInput from '@/components/SearchInput';
|
||||
import OntologyClassExtractModal from '../components/OntologyClassExtractModal'
|
||||
import BodyWrapper from '@/components/Empty/BodyWrapper'
|
||||
import Tag from '@/components/Tag'
|
||||
|
||||
/**
|
||||
* Ontology detail page component
|
||||
@@ -99,19 +100,22 @@ const Detail: FC = () => {
|
||||
return (
|
||||
<>
|
||||
<PageHeader
|
||||
name={data.scene_name}
|
||||
name={<Space>
|
||||
{data.scene_name}
|
||||
<Tag color="warning">{t('common.default')}</Tag>
|
||||
</Space>}
|
||||
subTitle={<Tooltip title={data.scene_description}><div className="rb:h-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{data.scene_description}</div></Tooltip>}
|
||||
extra={<Space>
|
||||
extra={data.is_system_default ? undefined : (<Space>
|
||||
<Button type="primary" ghost className="rb:h-6! rb:px-2! rb:leading-5.5!" onClick={handleAdd}>+ {t('ontology.addClass')}</Button>
|
||||
<Button className="rb:h-6! rb:px-2! rb:leading-5.5!" type="primary" onClick={handleExtract}>+ {t('ontology.extract')}</Button>
|
||||
</Space>}
|
||||
</Space>)}
|
||||
/>
|
||||
|
||||
<div className="rb:h-[calc(100vh-64px)] rb:overflow-y-auto rb:py-3 rb:px-4">
|
||||
<Row gutter={16} className="rb:mb-4">
|
||||
<Col span={6} offset={18}>
|
||||
<SearchInput
|
||||
placeholder={t('ontology.searchPlaceholder')}
|
||||
placeholder={t('ontology.classSearchPlaceholder')}
|
||||
onSearch={(value) => setQuery({ class_name: value })}
|
||||
className="rb:w-full!"
|
||||
/>
|
||||
@@ -123,10 +127,10 @@ const Detail: FC = () => {
|
||||
<Col key={item.class_id} span={6}>
|
||||
<RbCard
|
||||
title={item.class_name}
|
||||
extra={<div
|
||||
extra={data.is_system_default ? undefined : (<div
|
||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
||||
onClick={() => handleDelete(item)}
|
||||
></div>}
|
||||
></div>)}
|
||||
className="rb:bg-transparent!"
|
||||
>
|
||||
<Tooltip title={item.class_description}>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 14:10:10
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 14:10:10
|
||||
* @Last Modified time: 2026-03-06 10:55:23
|
||||
*/
|
||||
/**
|
||||
* Query parameters for ontology list pagination and filtering
|
||||
@@ -38,6 +38,8 @@ export interface OntologyItem {
|
||||
updated_at: number;
|
||||
/** Total count of classes in the scene */
|
||||
classes_count: number;
|
||||
/** Whether this is the system default configuration */
|
||||
is_system_default: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -92,6 +94,7 @@ export interface OntologyClassData {
|
||||
scene_description: string;
|
||||
/** Array of class items */
|
||||
items: OntologyClassItem[];
|
||||
is_system_default: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user