Merge pull request #486 from SuanmoSuanyangTechnology/release/v0.2.6

Release/v0.2.6
This commit is contained in:
Ke Sun
2026-03-06 12:29:07 +08:00
committed by GitHub
40 changed files with 818 additions and 713 deletions

View File

@@ -226,8 +226,8 @@ REDIS_PORT=6379
REDIS_DB=1 REDIS_DB=1
# Celery (Using Redis as broker) # Celery (Using Redis as broker)
BROKER_URL=redis://127.0.0.1:6379/0 REDIS_DB_CELERY_BROKER=1
RESULT_BACKEND=redis://127.0.0.1:6379/0 REDIS_DB_CELERY_BACKEND=2
# JWT Secret Key (Formation method: openssl rand -hex 32) # JWT Secret Key (Formation method: openssl rand -hex 32)
SECRET_KEY=your-secret-key-here SECRET_KEY=your-secret-key-here

View File

@@ -201,8 +201,8 @@ REDIS_PORT=6379
REDIS_DB=1 REDIS_DB=1
# Celery (使用Redis作为broker) # Celery (使用Redis作为broker)
BROKER_URL=redis://127.0.0.1:6379/0 REDIS_DB_CELERY_BROKER=1
RESULT_BACKEND=redis://127.0.0.1:6379/0 REDIS_DB_CELERY_BACKEND=2
# JWT密钥 (生成方式: openssl rand -hex 32) # JWT密钥 (生成方式: openssl rand -hex 32)
SECRET_KEY=your-secret-key-here SECRET_KEY=your-secret-key-here

View File

@@ -2,7 +2,6 @@
Cache 缓存模块 Cache 缓存模块
提供各种缓存功能的统一入口 提供各种缓存功能的统一入口
注意隐性记忆和情绪建议已迁移到数据库存储不再使用Redis缓存
""" """
from .memory import InterestMemoryCache from .memory import InterestMemoryCache

View File

@@ -2,7 +2,6 @@
Memory 缓存模块 Memory 缓存模块
提供记忆系统相关的缓存功能 提供记忆系统相关的缓存功能
注意隐性记忆和情绪建议已迁移到数据库存储不再使用Redis缓存
""" """
from .interest_memory import InterestMemoryCache from .interest_memory import InterestMemoryCache

View File

@@ -1,27 +1,54 @@
import os import os
import platform import platform
from datetime import timedelta from datetime import timedelta
from celery.schedules import crontab
from urllib.parse import quote from urllib.parse import quote
from celery import Celery from celery import Celery
from celery.schedules import crontab from celery.schedules import crontab
from app.core.config import settings from app.core.config import settings
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# macOS fork() safety - must be set before any Celery initialization # macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin': if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 创建 Celery 应用实例 # 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0 # broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定
# backend: 结果存储(使用 Redis DB 10 # backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
# Build canonical broker/backend URLs and force them into os.environ so that
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
# cannot be overridden by stray env vars.
# See: https://github.com/celery/celery/issues/4284
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
os.environ["CELERY_BROKER_URL"] = _broker_url
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
# integration and accidentally override our canonical URLs.
os.environ.pop("BROKER_URL", None)
os.environ.pop("RESULT_BACKEND", None)
os.environ.pop("CELERY_BROKER", None)
os.environ.pop("CELERY_BACKEND", None)
celery_app = Celery( celery_app = Celery(
"redbear_tasks", "redbear_tasks",
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}", broker=_broker_url,
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}", backend=_backend_url,
) )
logger.info(
"Celery app initialized",
extra={
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
},
)
# Default queue for unrouted tasks # Default queue for unrouted tasks
celery_app.conf.task_default_queue = 'memory_tasks' celery_app.conf.task_default_queue = 'memory_tasks'

View File

@@ -1,28 +1,29 @@
from typing import List, Optional from typing import List, Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.cache.memory.interest_memory import InterestMemoryCache from app.cache.memory.interest_memory import InterestMemoryCache
from app.celery_app import celery_app from app.celery_app import celery_app
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.rag.llm.cv_model import QWenCV from app.core.rag.llm.cv_model import QWenCV
from app.core.response_utils import fail, success from app.core.response_utils import fail, success
from app.db import get_db from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey from app.models import ModelApiKey
from app.models.user_model import User from app.models.user_model import User
from app.core.memory.agent.utils.session_tools import SessionService from app.repositories import knowledge_repository
from app.core.memory.agent.utils.redis_tool import store
from app.repositories import knowledge_repository, WorkspaceRepository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services import task_service, workspace_service from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
load_dotenv() load_dotenv()
api_logger = get_api_logger() api_logger = get_api_logger()
@@ -37,7 +38,7 @@ router = APIRouter(
@router.get("/health/status", response_model=ApiResponse) @router.get("/health/status", response_model=ApiResponse)
async def get_health_status( async def get_health_status(
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Get latest health status written by Celery periodic task Get latest health status written by Celery periodic task
@@ -55,8 +56,9 @@ async def get_health_status(
@router.get("/download_log") @router.get("/download_log")
async def download_log( async def download_log(
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"), log_type: str = Query("file", regex="^(file|transmission)$",
current_user: User = Depends(get_current_user) description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
): ):
""" """
Download or stream agent service log file Download or stream agent service log file
@@ -119,10 +121,10 @@ async def download_log(
@router.post("/writer_service", response_model=ApiResponse) @router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def write_server( async def write_server(
user_input: Write_UserInput, user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Write service endpoint - processes write operations synchronously Write service endpoint - processes write operations synchronously
@@ -161,13 +163,15 @@ async def write_server(
if knowledge: if knowledge:
user_rag_memory_id = str(knowledge.id) user_rag_memory_id = str(knowledge.id)
else: else:
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储") api_logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j' storage_type = 'neo4j'
else: else:
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j' storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") api_logger.info(
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try: try:
messages_list = memory_agent_service.get_messages_list(user_input) messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory( result = await memory_agent_service.write_memory(
@@ -195,10 +199,10 @@ async def write_server(
@router.post("/writer_service_async", response_model=ApiResponse) @router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def write_server_async( async def write_server_async(
user_input: Write_UserInput, user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Async write service endpoint - enqueues write processing to Celery Async write service endpoint - enqueues write processing to Celery
@@ -216,7 +220,8 @@ async def write_server_async(
config_id = user_input.config_id config_id = user_input.config_id
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") api_logger.info(
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
# 获取 storage_type如果为 None 则使用默认值 # 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type( storage_type = workspace_service.get_workspace_storage_type(
@@ -254,9 +259,9 @@ async def write_server_async(
@router.post("/read_service", response_model=ApiResponse) @router.post("/read_service", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def read_server( async def read_server(
user_input: UserInput, user_input: UserInput,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Read service endpoint - processes read operations synchronously Read service endpoint - processes read operations synchronously
@@ -292,7 +297,8 @@ async def read_server(
if knowledge: if knowledge:
user_rag_memory_id = str(knowledge.id) user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") api_logger.info(
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try: try:
result = await memory_agent_service.read_memory( result = await memory_agent_service.read_memory(
user_input.end_user_id, user_input.end_user_id,
@@ -306,7 +312,8 @@ async def read_server(
) )
if str(user_input.search_switch) == "2": if str(user_input.search_switch) == "2":
retrieve_info = result['answer'] retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id) history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
user_input.end_user_id)
query = user_input.message query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案 # 调用 memory_agent_service 的方法生成最终答案
@@ -319,7 +326,7 @@ async def read_server(
db=db db=db
) )
if "信息不足,无法回答" in result['answer']: if "信息不足,无法回答" in result['answer']:
result['answer']=retrieve_info result['answer'] = retrieve_info
return success(data=result, msg="回复对话消息成功") return success(data=result, msg="回复对话消息成功")
except BaseException as e: except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -335,9 +342,10 @@ async def read_server(
@router.post("/file", response_model=ApiResponse) @router.post("/file", response_model=ApiResponse)
async def file_update( async def file_update(
files: List[UploadFile] = File(..., description="要上传的文件"), files: List[UploadFile] = File(..., description="要上传的文件"),
model_id:str = Form(..., description="模型ID"), model_id: str = Form(..., description="模型ID"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"), metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
): ):
""" """
文件上传接口 - 支持图片识别 文件上传接口 - 支持图片识别
@@ -350,9 +358,6 @@ async def file_update(
Returns: Returns:
文件处理结果 文件处理结果
""" """
db_gen = get_db() # get_db 通常是一个生成器
db = next(db_gen)
api_logger.info(f"File upload requested, file count: {len(files)}") api_logger.info(f"File upload requested, file count: {len(files)}")
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
apiConfig: ModelApiKey = config.api_keys[0] apiConfig: ModelApiKey = config.api_keys[0]
@@ -430,8 +435,8 @@ async def read_server_async(
@router.get("/read_result/", response_model=ApiResponse) @router.get("/read_result/", response_model=ApiResponse)
async def get_read_task_result( async def get_read_task_result(
task_id: str, task_id: str,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Get the status and result of an async read task Get the status and result of an async read task
@@ -507,8 +512,8 @@ async def get_read_task_result(
@router.get("/write_result/", response_model=ApiResponse) @router.get("/write_result/", response_model=ApiResponse)
async def get_write_task_result( async def get_write_task_result(
task_id: str, task_id: str,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Get the status and result of an async write task Get the status and result of an async write task
@@ -584,9 +589,9 @@ async def get_write_task_result(
@router.post("/status_type", response_model=ApiResponse) @router.post("/status_type", response_model=ApiResponse)
async def status_type( async def status_type(
user_input: Write_UserInput, user_input: Write_UserInput,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Determine the type of user message (read or write) Determine the type of user message (read or write)
@@ -629,9 +634,10 @@ async def status_type(
@router.get("/stats/types", response_model=ApiResponse) @router.get("/stats/types", response_model=ApiResponse)
async def get_knowledge_type_stats_api( async def get_knowledge_type_stats_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"), end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"), only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
): ):
""" """
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。 统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
@@ -640,14 +646,9 @@ async def get_knowledge_type_stats_api(
- 知识库类型根据当前用户的 current_workspace_id 过滤 - 知识库类型根据当前用户的 current_workspace_id 过滤
- 如果用户没有当前工作空间,对应的统计返回 0 - 如果用户没有当前工作空间,对应的统计返回 0
""" """
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") api_logger.info(
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
try: try:
from app.db import get_db
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
# 调用service层函数 # 调用service层函数
result = await memory_agent_service.get_knowledge_type_stats( result = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id, end_user_id=end_user_id,
@@ -664,11 +665,11 @@ async def get_knowledge_type_stats_api(
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse) @router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
async def get_interest_distribution_by_user_api( async def get_interest_distribution_by_user_api(
end_user_id: str = Query(..., description="用户ID必填"), end_user_id: str = Query(..., description="用户ID必填"),
limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"), limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"),
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
""" """
获取指定用户的兴趣分布标签 获取指定用户的兴趣分布标签
@@ -716,9 +717,9 @@ async def get_interest_distribution_by_user_api(
@router.get("/analytics/user_profile", response_model=ApiResponse) @router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api( async def get_user_profile_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"), end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
获取用户详情,包含: 获取用户详情,包含:
@@ -782,9 +783,9 @@ async def get_user_profile_api(
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse) @router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
async def get_end_user_connected_config( async def get_end_user_connected_config(
end_user_id: str, end_user_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
获取终端用户关联的记忆配置 获取终端用户关联的记忆配置

View File

@@ -2,7 +2,7 @@ from typing import Optional
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse, JSONResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
@@ -85,6 +85,7 @@ def create_config(
payload: ConfigParamsCreate, payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
@@ -99,7 +100,29 @@ def create_config(
svc = DataConfigService(db) svc = DataConfigService(db)
result = svc.create(payload) result = svc.create(payload)
return success(data=result, msg="创建成功") return success(data=result, msg="创建成功")
except ValueError as e:
err_str = str(e)
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
config_name = err_str.split(":", 1)[1]
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {err_str}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
except Exception as e: except Exception as e:
from sqlalchemy.exc import IntegrityError
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {str(e)}") api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))

View File

@@ -25,7 +25,7 @@ from typing import Dict, Optional, List
from urllib.parse import quote from urllib.parse import quote
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse, JSONResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.config import settings from app.core.config import settings
@@ -289,7 +289,8 @@ async def extract_ontology(
async def create_scene( async def create_scene(
request: SceneCreateRequest, request: SceneCreateRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
): ):
"""创建本体场景 """创建本体场景
@@ -360,8 +361,18 @@ async def create_scene(
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e: except RuntimeError as e:
api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True) err_str = str(e)
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e)) if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str:
api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}")
from app.core.language_utils import get_language_from_header
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str)
except Exception as e: except Exception as e:
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True) api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
@@ -661,7 +672,8 @@ async def get_scenes(
async def create_class( async def create_class(
request: ClassCreateRequest, request: ClassCreateRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
): ):
"""创建本体类型 """创建本体类型
@@ -676,7 +688,7 @@ async def create_class(
ApiResponse: 包含创建的类型信息 ApiResponse: 包含创建的类型信息
""" """
from app.controllers.ontology_secondary_routes import create_class_handler from app.controllers.ontology_secondary_routes import create_class_handler
return await create_class_handler(request, db, current_user) return await create_class_handler(request, db, current_user, x_language_type)
@router.put("/class/{class_id}", response_model=ApiResponse) @router.put("/class/{class_id}", response_model=ApiResponse)

View File

@@ -7,7 +7,7 @@
from uuid import UUID from uuid import UUID
from typing import Optional from typing import Optional
from fastapi import Depends from fastapi import Depends, Header
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
@@ -58,7 +58,7 @@ async def scenes_handler(
workspace_id: Optional[str] = None, workspace_id: Optional[str] = None,
scene_name: Optional[str] = None, scene_name: Optional[str] = None,
page: Optional[int] = None, page: Optional[int] = None,
page_size: Optional[int] = None, pagesize: Optional[int] = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
@@ -71,14 +71,14 @@ async def scenes_handler(
workspace_id: 工作空间ID可选默认当前用户工作空间 workspace_id: 工作空间ID可选默认当前用户工作空间
scene_name: 场景名称关键词(可选,支持模糊匹配) scene_name: 场景名称关键词(可选,支持模糊匹配)
page: 页码可选从1开始仅在全量查询时有效 page: 页码可选从1开始仅在全量查询时有效
page_size: 每页数量(可选,仅在全量查询时有效) pagesize: 每页数量(可选,仅在全量查询时有效)
db: 数据库会话 db: 数据库会话
current_user: 当前用户 current_user: 当前用户
""" """
operation = "search" if scene_name else "list" operation = "search" if scene_name else "list"
api_logger.info( api_logger.info(
f"Scene {operation} requested by user {current_user.id}, " f"Scene {operation} requested by user {current_user.id}, "
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}" f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
) )
try: try:
@@ -105,13 +105,13 @@ async def scenes_handler(
api_logger.warning(f"Invalid page number: {page}") api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0") return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if page_size is not None and page_size < 1: if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid page_size: {page_size}") api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0") return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或page_size中的一个返回错误 # 如果只提供了page或pagesize中的一个返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None): if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}") api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供") return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
# 模糊搜索场景(支持分页) # 模糊搜索场景(支持分页)
@@ -119,17 +119,15 @@ async def scenes_handler(
total = len(scenes) total = len(scenes)
# 如果提供了分页参数,进行分页处理 # 如果提供了分页参数,进行分页处理
if page is not None and page_size is not None: if page is not None and pagesize is not None:
start_idx = (page - 1) * page_size start_idx = (page - 1) * pagesize
end_idx = start_idx + page_size end_idx = start_idx + pagesize
scenes = scenes[start_idx:end_idx] scenes = scenes[start_idx:end_idx]
# 构建响应 # 构建响应
items = [] items = []
for scene in scenes: for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0 type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse( items.append(SceneResponse(
@@ -141,17 +139,16 @@ async def scenes_handler(
workspace_id=scene.workspace_id, workspace_id=scene.workspace_id,
created_at=scene.created_at, created_at=scene.created_at,
updated_at=scene.updated_at, updated_at=scene.updated_at,
classes_count=type_num classes_count=type_num,
is_system_default=scene.is_system_default
)) ))
# 构建响应(包含分页信息) # 构建响应(包含分页信息)
if page is not None and page_size is not None: if page is not None and pagesize is not None:
# 计算是否有下一页 hasnext = (page * pagesize) < total
hasnext = (page * page_size) < total
pagination_info = PaginationInfo( pagination_info = PaginationInfo(
page=page, page=page,
pagesize=page_size, pagesize=pagesize,
total=total, total=total,
hasnext=hasnext hasnext=hasnext
) )
@@ -165,28 +162,25 @@ async def scenes_handler(
) )
else: else:
# 获取所有场景(支持分页) # 获取所有场景(支持分页)
# 验证分页参数
if page is not None and page < 1: if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}") api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0") return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if page_size is not None and page_size < 1: if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid page_size: {page_size}") api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0") return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或page_size中的一个返回错误 # 如果只提供了page或pagesize中的一个返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None): if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}") api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供") return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
scenes, total = service.list_scenes(ws_uuid, page, page_size) scenes, total = service.list_scenes(ws_uuid, page, pagesize)
# 构建响应 # 构建响应
items = [] items = []
for scene in scenes: for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0 type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse( items.append(SceneResponse(
@@ -198,17 +192,16 @@ async def scenes_handler(
workspace_id=scene.workspace_id, workspace_id=scene.workspace_id,
created_at=scene.created_at, created_at=scene.created_at,
updated_at=scene.updated_at, updated_at=scene.updated_at,
classes_count=type_num classes_count=type_num,
is_system_default=scene.is_system_default
)) ))
# 构建响应(包含分页信息) # 构建响应(包含分页信息)
if page is not None and page_size is not None: if page is not None and pagesize is not None:
# 计算是否有下一页 hasnext = (page * pagesize) < total
hasnext = (page * page_size) < total
pagination_info = PaginationInfo( pagination_info = PaginationInfo(
page=page, page=page,
pagesize=page_size, pagesize=pagesize,
total=total, total=total,
hasnext=hasnext hasnext=hasnext
) )
@@ -238,7 +231,8 @@ async def scenes_handler(
async def create_class_handler( async def create_class_handler(
request: ClassCreateRequest, request: ClassCreateRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = None
): ):
"""创建本体类型(统一使用列表形式,支持单个或批量)""" """创建本体类型(统一使用列表形式,支持单个或批量)"""
@@ -271,8 +265,11 @@ async def create_class_handler(
] ]
if count == 1: if count == 1:
# 单个创建 # 单个创建 - 先检查重名
class_data = classes_data[0] class_data = classes_data[0]
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
if existing:
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
ontology_class = service.create_class( ontology_class = service.create_class(
scene_id=request.scene_id, scene_id=request.scene_id,
class_name=class_data["class_name"], class_name=class_data["class_name"],
@@ -330,12 +327,36 @@ async def create_class_handler(
return success(data=response.model_dump(mode='json'), msg="批量创建完成") return success(data=response.model_dump(mode='json'), msg="批量创建完成")
except ValueError as e: except ValueError as e:
api_logger.warning(f"Validation error in class creation: {str(e)}") err_str = str(e)
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) if err_str.startswith("DUPLICATE_CLASS_NAME:"):
class_name = err_str.split(":", 1)[1]
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.warning(f"Validation error in class creation: {err_str}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
except RuntimeError as e: except RuntimeError as e:
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True) err_str = str(e)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e)) if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
class_name = request.classes[0].class_name if request.classes else ""
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
except Exception as e: except Exception as e:
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True) api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
@@ -615,6 +636,7 @@ async def classes_handler(
scene_id=scene_uuid, scene_id=scene_uuid,
scene_name=scene.scene_name, scene_name=scene.scene_name,
scene_description=scene.scene_description, scene_description=scene.scene_description,
is_system_default=scene.is_system_default,
items=items items=items
) )

View File

@@ -190,8 +190,10 @@ class Settings:
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal) # Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) # NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) # 详见 docs/celery-env-bug-report.md
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "1"))
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "2"))
# SMTP Email Configuration # SMTP Email Configuration
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com") SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")

View File

@@ -1,10 +1,10 @@
import os
import json import json
import os
import time import time
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_, PROJECT_ROOT_,
ReadState, ReadState,
@@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
try: try:
# 使用优化的LLM服务 # 使用优化的LLM服务
structured = await problem_service.call_llm_structured( with get_db_context() as db_session:
state=state, structured = await problem_service.call_llm_structured(
db_session=db_session, state=state,
system_prompt=system_prompt, db_session=db_session,
response_model=ProblemExtensionResponse, system_prompt=system_prompt,
fallback_value=[] response_model=ProblemExtensionResponse,
) fallback_value=[]
)
# 添加更详细的日志记录 # 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
@@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState:
try: try:
# 使用优化的LLM服务 # 使用优化的LLM服务
response_content = await problem_service.call_llm_structured( with get_db_context() as db_session:
state=state, response_content = await problem_service.call_llm_structured(
db_session=db_session, state=state,
system_prompt=system_prompt, db_session=db_session,
response_model=ProblemExtensionResponse, system_prompt=system_prompt,
fallback_value=[] response_model=ProblemExtensionResponse,
) fallback_value=[]
)
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}") logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")

View File

@@ -6,31 +6,26 @@ import os
# ===== 第三方库 ===== # ===== 第三方库 =====
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.db import get_db, get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
COUNTState,
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.langgraph_graph.tools.tool import ( from app.core.memory.agent.langgraph_graph.tools.tool import (
create_hybrid_retrieval_tool_sync, create_hybrid_retrieval_tool_sync,
create_time_retrieval_tool, create_time_retrieval_tool,
extract_tool_message_content, extract_tool_message_content,
) )
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
db = next(get_db())
async def rag_config(state): async def rag_config(state):
@@ -50,10 +45,12 @@ async def rag_config(state):
"reranker_top_k": 10 "reranker_top_k": 10
} }
return kb_config return kb_config
async def rag_knowledge(state,question):
async def rag_knowledge(state, question):
kb_config = await rag_config(state) kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try: try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
@@ -61,13 +58,13 @@ async def rag_knowledge(state,question):
cleaned_query = question cleaned_query = question
raw_results = clean_content raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception : except Exception:
retrieval_knowledge=[] retrieval_knowledge = []
clean_content = '' clean_content = ''
raw_results = '' raw_results = ''
cleaned_query = question cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge,clean_content,cleaned_query,raw_results return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def llm_infomation(state: ReadState) -> ReadState: async def llm_infomation(state: ReadState) -> ReadState:
@@ -141,7 +138,6 @@ async def clean_databases(data) -> str:
elif isinstance(item, str): elif isinstance(item, str):
text_parts.append(item) text_parts.append(item)
return '\n'.join(text_parts).strip() return '\n'.join(text_parts).strip()
except Exception as e: except Exception as e:
@@ -150,23 +146,23 @@ async def clean_databases(data) -> str:
async def retrieve_nodes(state: ReadState) -> ReadState: async def retrieve_nodes(state: ReadState) -> ReadState:
''' '''
模型信息 模型信息
''' '''
problem_extension=state.get('problem_extension', '')['context'] problem_extension = state.get('problem_extension', '')['context']
storage_type=state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id=state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
end_user_id=state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
original=state.get('data', '') original = state.get('data', '')
problem_list=[] problem_list = []
for key,values in problem_extension.items(): for key, values in problem_extension.items():
for data in values: for data in values:
problem_list.append(data) problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# 创建异步任务处理单个问题 # 创建异步任务处理单个问题
async def process_question_nodes(idx, question): async def process_question_nodes(idx, question):
try: try:
@@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
send_verify = [] send_verify = []
for i, j in zip(keys, val, strict=False): for i, j in zip(keys, val, strict=False):
if j!=['']: if j != ['']:
send_verify.append({ send_verify.append({
"Query_small": i, "Query_small": i,
"Answer_Small": j "Answer_Small": j
@@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
} }
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve':dup_databases} return {'retrieve': dup_databases}
async def retrieve(state: ReadState) -> ReadState: async def retrieve(state: ReadState) -> ReadState:
# 从state中获取end_user_id # 从state中获取end_user_id
import time import time
start=time.time() start = time.time()
problem_extension = state.get('problem_extension', '')['context'] problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
@@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState:
with get_db_context() as db: # 使用同步数据库上下文管理器 with get_db_context() as db: # 使用同步数据库上下文管理器
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
return await llm_infomation(state) return await llm_infomation(state)
llm_config = await get_llm_info() llm_config = await get_llm_info()
api_key_obj = llm_config.api_keys[0] api_key_obj = llm_config.api_keys[0]
api_key = api_key_obj.api_key api_key = api_key_obj.api_key
@@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState:
) )
time_retrieval_tool = create_time_retrieval_tool(end_user_id) time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { "end_user_id": end_user_id, "return_raw_results": True } search_params = {"end_user_id": end_user_id, "return_raw_results": True}
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent( agent = create_agent(
llm, llm,
tools=[time_retrieval_tool,hybrid_retrieval], tools=[time_retrieval_tool, hybrid_retrieval],
system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}" system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
) )
@@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState:
async with SEMAPHORE: # 限制并发 async with SEMAPHORE: # 限制并发
try: try:
if storage_type == "rag" and user_rag_memory_id: if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
question)
else: else:
cleaned_query = question cleaned_query = question
# 使用 asyncio 在线程池中运行同步的 agent.invoke # 使用 asyncio 在线程池中运行同步的 agent.invoke
@@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState:
# json.dump(dup_databases, f, indent=4) # json.dump(dup_databases, f, indent=4)
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases} return {'retrieve': dup_databases}

View File

@@ -1,5 +1,3 @@
import os import os
import time import time
@@ -18,12 +16,11 @@ from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.db import get_db
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
db_session = next(get_db())
class SummaryNodeService(LLMServiceMixin): class SummaryNodeService(LLMServiceMixin):
"""总结节点服务类""" """总结节点服务类"""
@@ -32,8 +29,11 @@ class SummaryNodeService(LLMServiceMixin):
super().__init__() super().__init__()
self.template_service = TemplateService(template_root) self.template_service = TemplateService(template_root)
# 创建全局服务实例 # 创建全局服务实例
summary_service = SummaryNodeService() summary_service = SummaryNodeService()
async def rag_config(state): async def rag_config(state):
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = { kb_config = {
@@ -51,10 +51,12 @@ async def rag_config(state):
"reranker_top_k": 10 "reranker_top_k": 10
} }
return kb_config return kb_config
async def rag_knowledge(state,question):
async def rag_knowledge(state, question):
kb_config = await rag_config(state) kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try: try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
@@ -62,20 +64,23 @@ async def rag_knowledge(state,question):
cleaned_query = question cleaned_query = question
raw_results = clean_content raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception : except Exception:
retrieval_knowledge=[] retrieval_knowledge = []
clean_content = '' clean_content = ''
raw_results = '' raw_results = ''
cleaned_query = question cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge,clean_content,cleaned_query,raw_results return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def summary_history(state: ReadState) -> ReadState: async def summary_history(state: ReadState) -> ReadState:
end_user_id = state.get("end_user_id", '') end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
search_mode) -> str:
""" """
增强的summary_llm函数包含更好的错误处理和数据验证 增强的summary_llm函数包含更好的错误处理和数据验证
""" """
@@ -99,13 +104,14 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
) )
try: try:
# 使用优化的LLM服务进行结构化输出 # 使用优化的LLM服务进行结构化输出
structured = await summary_service.call_llm_structured( with get_db_context() as db_session:
state=state, structured = await summary_service.call_llm_structured(
db_session=db_session, state=state,
system_prompt=system_prompt, db_session=db_session,
response_model=response_model, system_prompt=system_prompt,
fallback_value=None response_model=response_model,
) fallback_value=None
)
# 验证结构化响应 # 验证结构化响应
if structured is None: if structured is None:
logger.warning("LLM返回None使用默认回答") logger.warning("LLM返回None使用默认回答")
@@ -157,7 +163,8 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
logger.error(f"Fallback也失败: {fallback_error}") logger.error(f"Fallback也失败: {fallback_error}")
return "信息不足,无法回答" return "信息不足,无法回答"
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
data = state.get("data", '') data = state.get("data", '')
end_user_id = state.get("end_user_id", '') end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session( await SessionService(store).save_session(
@@ -169,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
) )
await SessionService(store).cleanup_duplicates() await SessionService(store).cleanup_duplicates()
logger.info(f"sessionid: {aimessages} 写入成功") logger.info(f"sessionid: {aimessages} 写入成功")
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
storage_type=state.get("storage_type",'')
user_rag_memory_id=state.get("user_rag_memory_id",'') async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
data=state.get("data", '') storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
input_summary = { input_summary = {
"status": "success", "status": "success",
"summary_result": aimessages, "summary_result": aimessages,
@@ -189,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
"user_rag_memory_id": user_rag_memory_id "user_rag_memory_id": user_rag_memory_id
} }
} }
retrieve={ retrieve = {
"status": "success", "status": "success",
"summary_result": aimessages, "summary_result": aimessages,
"storage_type": storage_type, "storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id, "user_rag_memory_id": user_rag_memory_id,
"_intermediate": { "_intermediate": {
"type": "retrieval_summary", "type": "retrieval_summary",
"title":"快速检索", "title": "快速检索",
"summary": aimessages, "summary": aimessages,
"query": data, "query": data,
"storage_type": storage_type, "storage_type": storage_type,
@@ -204,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
} }
} }
return input_summary,retrieve return input_summary, retrieve
async def Input_Summary(state: ReadState) -> ReadState: async def Input_Summary(state: ReadState) -> ReadState:
start=time.time() start = time.time()
storage_type=state.get("storage_type",'') storage_type = state.get("storage_type", '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id = state.get("user_rag_memory_id", '')
data=state.get("data", '') data = state.get("data", '')
end_user_id=state.get("end_user_id", '') end_user_id = state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history( state) history = await summary_history(state)
search_params = { search_params = {
"end_user_id": end_user_id, "end_user_id": end_user_id,
"question": data, "question": data,
@@ -223,12 +233,13 @@ async def Input_Summary(state: ReadState) -> ReadState:
} }
try: try:
if storage_type!="rag": if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
memory_config=memory_config)
else: else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data) retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e: except Exception as e:
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True ) logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
retrieve_info, question, raw_results = "", data, [] retrieve_info, question, raw_results = "", data, []
try: try:
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2', # aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
@@ -237,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
summary_result = await summary_prompt(state, retrieve_info, retrieve_info) summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
summary = summary_result[0] summary = summary_result[0]
except Exception as e: except Exception as e:
logger.error( f"Input_Summary failed: {e}", exc_info=True ) logger.error(f"Input_Summary failed: {e}", exc_info=True)
summary= { summary = {
"status": "fail", "status": "fail",
"summary_result": "信息不足,无法回答", "summary_result": "信息不足,无法回答",
"storage_type": storage_type, "storage_type": storage_type,
@@ -251,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState:
except Exception: except Exception:
duration = 0.0 duration = 0.0
log_time('检索', duration) log_time('检索', duration)
return {"summary":summary} return {"summary": summary}
async def Retrieve_Summary(state: ReadState)-> ReadState:
retrieve=state.get("retrieve", '') async def Retrieve_Summary(state: ReadState) -> ReadState:
history = await summary_history( state) retrieve = state.get("retrieve", '')
history = await summary_history(state)
import json import json
with open("检索.json","w",encoding='utf-8') as f: with open("检索.json", "w", encoding='utf-8') as f:
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False)) f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
retrieve=retrieve.get("Expansion_issue", []) retrieve = retrieve.get("Expansion_issue", [])
start=time.time() start = time.time()
retrieve_info_str=[] retrieve_info_str = []
for data in retrieve: for data in retrieve:
if data=='': if data == '':
retrieve_info_str='' retrieve_info_str = ''
else: else:
for key, value in data.items(): for key, value in data.items():
if key=='Answer_Small': if key == 'Answer_Small':
for i in value: for i in value:
retrieve_info_str.append(i) retrieve_info_str.append(i)
retrieve_info_str=list(set(retrieve_info_str)) retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str='\n'.join(retrieve_info_str) retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages=await summary_llm(state,history,retrieve_info_str, aimessages = await summary_llm(state, history, retrieve_info_str,
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1") 'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages) await summary_redis_save(state, aimessages)
if aimessages == '': if aimessages == '':
@@ -290,29 +302,29 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
# 修复协程调用 - 先await然后访问返回值 # 修复协程调用 - 先await然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1] summary = summary_result[1]
return {"summary":summary} return {"summary": summary}
async def Summary(state: ReadState)-> ReadState: async def Summary(state: ReadState) -> ReadState:
start=time.time() start = time.time()
query = state.get("data", '') query = state.get("data", '')
verify=state.get("verify", '') verify = state.get("verify", '')
verify_expansion_issue=verify.get("verified_data", '') verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str='' retrieve_info_str = ''
for data in verify_expansion_issue: for data in verify_expansion_issue:
for key, value in data.items(): for key, value in data.items():
if key=='answer_small': if key == 'answer_small':
for i in value: for i in value:
retrieve_info_str+=i+'\n' retrieve_info_str += i + '\n'
history=await summary_history(state) history = await summary_history(state)
data = { data = {
"query": query, "query": query,
"history": history, "history": history,
"retrieve_info": retrieve_info_str "retrieve_info": retrieve_info_str
} }
aimessages=await summary_llm(state,history,data, aimessages = await summary_llm(state, history, data,
'summary_prompt.jinja2','summary',SummaryResponse,0) 'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages) await summary_redis_save(state, aimessages)
@@ -327,10 +339,12 @@ async def Summary(state: ReadState)-> ReadState:
# 修复协程调用 - 先await然后访问返回值 # 修复协程调用 - 先await然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1] summary = summary_result[1]
return {"summary":summary} return {"summary": summary}
async def Summary_fails(state: ReadState)-> ReadState:
storage_type=state.get("storage_type", '')
user_rag_memory_id=state.get("user_rag_memory_id", '') async def Summary_fails(state: ReadState) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
history = await summary_history(state) history = await summary_history(state)
query = state.get("data", '') query = state.get("data", '')
verify = state.get("verify", '') verify = state.get("verify", '')
@@ -346,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState:
"history": history, "history": history,
"retrieve_info": retrieve_info_str "retrieve_info": retrieve_info_str
} }
aimessages = await summary_llm(state, history, data, aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0) 'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result= { result = {
"status": "success", "status": "success",
"summary_result": aimessages, "summary_result": aimessages,
"storage_type": storage_type, "storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id "user_rag_memory_id": user_rag_memory_id
} }
return {"summary":result} return {"summary": result}

View File

@@ -1,8 +1,9 @@
import asyncio
import os import os
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.verification_models import VerificationResult from app.core.memory.agent.models.verification_models import VerificationResult
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_, PROJECT_ROOT_,
ReadState, ReadState,
@@ -10,12 +11,12 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
class VerificationNodeService(LLMServiceMixin): class VerificationNodeService(LLMServiceMixin):
"""验证节点服务类""" """验证节点服务类"""
@@ -23,9 +24,11 @@ class VerificationNodeService(LLMServiceMixin):
super().__init__() super().__init__()
self.template_service = TemplateService(template_root) self.template_service = TemplateService(template_root)
# 创建全局服务实例 # 创建全局服务实例
verification_service = VerificationNodeService() verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""处理验证结果并生成输出格式""" """处理验证结果并生成输出格式"""
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
@@ -58,6 +61,8 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
} }
} }
return Verify_result return Verify_result
async def Verify(state: ReadState): async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===") logger.info("=== Verify 节点开始执行 ===")
try: try:
@@ -71,7 +76,8 @@ async def Verify(state: ReadState):
logger.info(f"Verify: 获取历史记录完成history length={len(history)}") logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {}) retrieve = state.get("retrieve", {})
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}") logger.info(
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else [] retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}") logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
@@ -100,23 +106,24 @@ async def Verify(state: ReadState):
try: try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待 # 添加 asyncio.wait_for 超时包裹,防止无限等待
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长) # 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
import asyncio
structured = await asyncio.wait_for( with get_db_context() as db_session:
verification_service.call_llm_structured( structured = await asyncio.wait_for(
state=state, verification_service.call_llm_structured(
db_session=db_session, state=state,
system_prompt=system_prompt, db_session=db_session,
response_model=VerificationResult, system_prompt=system_prompt,
fallback_value={ response_model=VerificationResult,
"query": content, fallback_value={
"history": history if isinstance(history, list) else [], "query": content,
"expansion_issue": [], "history": history if isinstance(history, list) else [],
"split_result": "failed", "expansion_issue": [],
"reason": "验证失败或超时" "split_result": "failed",
} "reason": "验证失败或超时"
), }
timeout=150.0 # 150秒超时 ),
) timeout=150.0 # 150秒超时
)
logger.info(f"Verify: LLM 调用完成result={structured}") logger.info(f"Verify: LLM 调用完成result={structured}")
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.error("Verify: LLM 调用超时150秒使用 fallback 值") logger.error("Verify: LLM 调用超时150秒使用 fallback 值")

View File

@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
from langgraph.constants import START, END from langgraph.constants import START, END
from langgraph.graph import StateGraph from langgraph.graph import StateGraph
from app.db import get_db from app.db import get_db
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
@@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
) )
@asynccontextmanager @asynccontextmanager
async def make_read_graph(): async def make_read_graph():
"""创建并返回 LangGraph 工作流""" """创建并返回 LangGraph 工作流"""
@@ -62,7 +60,6 @@ async def make_read_graph():
workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END) workflow.add_edge("Summary", END)
'''-----''' '''-----'''
# workflow.add_edge("Retrieve", END) # workflow.add_edge("Retrieve", END)
@@ -76,6 +73,7 @@ async def make_read_graph():
finally: finally:
print("工作流创建完成") print("工作流创建完成")
async def main(): async def main():
"""主函数 - 运行工作流""" """主函数 - 运行工作流"""
message = "昨天有什么好看的电影" message = "昨天有什么好看的电影"
@@ -92,13 +90,15 @@ async def main():
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
import time import time
start=time.time() start = time.time()
try: try:
async with make_read_graph() as graph: async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} "end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息 # 获取节点更新信息
_intermediate_outputs = [] _intermediate_outputs = []
summary = '' summary = ''
@@ -141,7 +141,6 @@ async def main():
if verify_n and verify_n != [] and verify_n != {}: if verify_n and verify_n != [] and verify_n != {}:
_intermediate_outputs.append(verify_n) _intermediate_outputs.append(verify_n)
# Summary 节点 # Summary 节点
summary_n = node_data.get('summary', {}).get('_intermediate', None) summary_n = node_data.get('summary', {}).get('_intermediate', None)
if summary_n and summary_n != [] and summary_n != {}: if summary_n and summary_n != [] and summary_n != {}:
@@ -165,13 +164,16 @@ async def main():
except Exception as e: except Exception as e:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
finally:
db_session.close()
end=time.time() end = time.time()
print(100*'y') print(100 * 'y')
print(f"总耗时: {end-start}s") print(f"总耗时: {end - start}s")
print(100*'y') print(100 * 'y')
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,56 +0,0 @@
import asyncio
from typing import Dict, Optional
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
from app.db import get_db
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
class LLMClientPool:
"""LLM客户端连接池"""
def __init__(self, max_size: int = 5):
self.max_size = max_size
self.pools: Dict[str, asyncio.Queue] = {}
self.active_clients: Dict[str, int] = {}
async def get_client(self, llm_model_id: str):
"""获取LLM客户端"""
if llm_model_id not in self.pools:
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
self.active_clients[llm_model_id] = 0
pool = self.pools[llm_model_id]
try:
# 尝试从池中获取客户端
client = pool.get_nowait()
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
return client
except asyncio.QueueEmpty:
# 池为空,创建新客户端
if self.active_clients[llm_model_id] < self.max_size:
db_session = next(get_db())
client = get_llm_client_fast(llm_model_id, db_session)
self.active_clients[llm_model_id] += 1
logger.debug(f"创建新LLM客户端: {llm_model_id}")
return client
else:
# 等待可用客户端
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
return await pool.get()
async def return_client(self, llm_model_id: str, client):
"""归还LLM客户端到池中"""
if llm_model_id in self.pools:
try:
self.pools[llm_model_id].put_nowait(client)
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
except asyncio.QueueFull:
# 池已满,丢弃客户端
self.active_clients[llm_model_id] -= 1
logger.debug(f"池已满丢弃LLM客户端: {llm_model_id}")
# 全局客户端池
llm_client_pool = LLMClientPool()

View File

@@ -14,7 +14,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db from app.db import get_db_context
from app.models import AppRelease from app.models import AppRelease
from app.services.draft_run_service import AgentRunService from app.services.draft_run_service import AgentRunService
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING} return {"output": VariableType.STRING}
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]: def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]:
"""准备 Agent公共逻辑 """准备 Agent公共逻辑
Args: Args:
@@ -57,17 +57,17 @@ class AgentNode(BaseNode):
if not agent_id: if not agent_id:
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置") raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
db = next(get_db()) with get_db_context() as db:
release = db.query(AppRelease).filter( release = db.query(AppRelease).filter(
AppRelease.id == agent_id AppRelease.id == agent_id
).first() ).first()
if not release: if not release:
raise ValueError(f"Agent 不存在: {agent_id}") raise ValueError(f"Agent 不存在: {agent_id}")
draft_service = AgentRunService(db)
return draft_service, release, message
return release, message
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""非流式执行 """非流式执行
@@ -79,19 +79,21 @@ class AgentNode(BaseNode):
Returns: Returns:
状态更新字典 状态更新字典
""" """
draft_service, release, message = self._prepare_agent(variable_pool) release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)") logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
with get_db_context() as db:
draft_service = AgentRunService(db)
# 执行 Agent非流式 # 执行 Agent非流式
result = await draft_service.run( result = await draft_service.run(
agent_config=release.config, agent_config=release.config,
model_config=None, model_config=None,
message=message, message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"), workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"), user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars() variables=variable_pool.get_all_conversation_vars()
) )
response = result.get("response", "") response = result.get("response", "")
@@ -118,34 +120,35 @@ class AgentNode(BaseNode):
Yields: Yields:
流式事件字典 流式事件字典
""" """
draft_service, release, message = self._prepare_agent(variable_pool) release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)") logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
# 累积完整响应 # 累积完整响应
full_response = "" full_response = ""
with get_db_context() as db:
draft_service = AgentRunService(db)
# 执行 Agent流式 # 执行 Agent流式
async for chunk in draft_service.run_stream( async for chunk in draft_service.run_stream(
agent_config=release.config, agent_config=release.config,
model_config=None, model_config=None,
message=message, message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"), workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"), user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars() variables=variable_pool.get_all_conversation_vars()
): ):
# 提取内容 # 提取内容
content = chunk.get("content", "") content = chunk.get("content", "")
full_response += content full_response += content
# 流式返回每个 chunk # 流式返回每个 chunk
yield { yield {
"type": "chunk", "type": "chunk",
"node_id": self.node_id, "node_id": self.node_id,
"content": content, "content": content,
"full_content": full_response, "full_content": full_response,
"meta_data": chunk.get("meta_data", {}) "meta_data": chunk.get("meta_data", {})
} }
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}") logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")

View File

@@ -374,7 +374,7 @@ class OntologySceneRepository:
count = self.db.query(OntologyScene).filter( count = self.db.query(OntologyScene).filter(
OntologyScene.scene_id == scene_id, OntologyScene.scene_id == scene_id,
OntologyScene.workspace_id == workspace_id (OntologyScene.workspace_id == workspace_id) | (OntologyScene.is_system_default == True)
).count() ).count()
is_owner = count > 0 is_owner = count > 0

View File

@@ -116,8 +116,8 @@ class ModelApiKeyBase(BaseModel):
provider: ModelProvider = Field(..., description="API Key提供商") provider: ModelProvider = Field(..., description="API Key提供商")
api_key: str = Field(..., description="API密钥", max_length=500) api_key: str = Field(..., description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
capability: List[str] = Field(default_factory=list, description="模型能力列表") capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: bool = Field(False, description="是否为Omni模型") is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
is_active: bool = Field(True, description="是否激活") is_active: bool = Field(True, description="是否激活")
priority: str = Field("1", description="优先级", max_length=10) priority: str = Field("1", description="优先级", max_length=10)

View File

@@ -241,6 +241,7 @@ class SceneResponse(BaseModel):
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)") created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)") updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
classes_count: int = Field(0, description="类型数量") classes_count: int = Field(0, description="类型数量")
is_system_default: bool = Field(False, description="是否为系统默认场景")
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime): def _serialize_created_at(self, dt: datetime.datetime):
@@ -462,6 +463,7 @@ class ClassListResponse(BaseModel):
scene_id: UUID = Field(..., description="所属场景ID") scene_id: UUID = Field(..., description="所属场景ID")
scene_name: str = Field(..., description="场景名称") scene_name: str = Field(..., description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述") scene_description: Optional[str] = Field(None, description="场景描述")
is_system_default: bool = Field(False, description="是否为系统默认场景")
items: List[ClassResponse] = Field(..., description="类型列表") items: List[ClassResponse] = Field(..., description="类型列表")

View File

@@ -22,6 +22,7 @@ from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput from app.schemas.app_schema import FileInput
@@ -103,9 +104,7 @@ def create_long_term_memory_tool(
""" """
logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}") logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
try: try:
from app.db import get_db with get_db_context() as db:
db = next(get_db())
try:
memory_content = asyncio.run( memory_content = asyncio.run(
MemoryAgentService().read_memory( MemoryAgentService().read_memory(
end_user_id=end_user_id, end_user_id=end_user_id,
@@ -127,9 +126,6 @@ def create_long_term_memory_tool(
logger.info(f"读取任务状态:{status}") logger.info(f"读取任务状态:{status}")
if memory_content: if memory_content:
memory_content = memory_content['answer'] memory_content = memory_content['answer']
finally:
db.close()
logger.info(f'用户IDAgent:{end_user_id}') logger.info(f'用户IDAgent:{end_user_id}')
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id}) logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})

View File

@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
""" """
import json import json
import os import os
import re
import time import time
import uuid import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
reorder_output_results, reorder_output_results,
) )
from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError from app.schemas.memory_config_schema import ConfigurationError
@@ -69,7 +66,8 @@ class MemoryAgentService:
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作 # 记录成功的操作
if audit_logger: if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True, audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=True,
duration=duration, details={"message_length": len(message)}) duration=duration, details={"message_length": len(message)})
return context return context
else: else:
@@ -88,8 +86,6 @@ class MemoryAgentService:
raise ValueError(f"写入失败: {messages}") raise ValueError(f"写入失败: {messages}")
def extract_tool_call_info(self, event: Dict) -> bool: def extract_tool_call_info(self, event: Dict) -> bool:
"""Extract tool call information from event""" """Extract tool call information from event"""
last_message = event["messages"][-1] last_message = event["messages"][-1]
@@ -271,7 +267,8 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources") logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic # LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
""" """
Process write operation with config_id Process write operation with config_id
@@ -300,7 +297,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None: if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e: except Exception as e:
if "No memory configuration found" in str(e): if "No memory configuration found" in str(e):
raise # Re-raise our specific error raise # Re-raise our specific error
@@ -331,7 +329,8 @@ class MemoryAgentService:
# Log failed operation # Log failed operation
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
@@ -351,9 +350,9 @@ class MemoryAgentService:
langchain_messages.append(HumanMessage(content=msg['content'])) langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant': elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content'])) langchain_messages.append(AIMessage(content=msg['content']))
print(100*'-') print(100 * '-')
print(langchain_messages) print(langchain_messages)
print(100*'-') print(100 * '-')
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = { initial_state = {
"messages": langchain_messages, "messages": langchain_messages,
@@ -375,29 +374,28 @@ class MemoryAgentService:
contents = massages.get('write_result') contents = massages.get('write_result')
# Convert messages back to string for logging # Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
contents)
except Exception as e: except Exception as e:
# Ensure proper error handling and logging # Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}" error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
async def read_memory( async def read_memory(
self, self,
end_user_id: str, end_user_id: str,
message: str, message: str,
history: List[Dict], history: List[Dict],
search_switch: str, search_switch: str,
config_id: Optional[uuid.UUID]|int, config_id: Optional[uuid.UUID] | int,
db: Session, db: Session,
storage_type: str, storage_type: str,
user_rag_memory_id: str) -> Dict: user_rag_memory_id: str) -> Dict:
""" """
Process read operation with config_id Process read operation with config_id
@@ -425,7 +423,7 @@ class MemoryAgentService:
import time import time
start_time = time.time() start_time = time.time()
ori_message= message ori_message = message
# Resolve config_id and workspace_id # Resolve config_id and workspace_id
# Always get workspace_id from end_user for fallback, even if config_id is provided # Always get workspace_id from end_user for fallback, even if config_id is provided
@@ -437,7 +435,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None: if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e: except Exception as e:
if "No memory configuration found" in str(e): if "No memory configuration found" in str(e):
raise # Re-raise our specific error raise # Re-raise our specific error
@@ -454,7 +453,6 @@ class MemoryAgentService:
except ImportError: except ImportError:
audit_logger = None audit_logger = None
config_load_start = time.time() config_load_start = time.time()
try: try:
# Use a separate database session to avoid transaction failures # Use a separate database session to avoid transaction failures
@@ -576,7 +574,8 @@ class MemoryAgentService:
raw_results = intermediate.get('raw_results', {}) raw_results = intermediate.get('raw_results', {})
try: try:
reranked_results = raw_results.get('reranked_results', []) reranked_results = raw_results.get('reranked_results', [])
statements = [statement['statement'] for statement in reranked_results.get('statements', [])] statements = [statement['statement'] for statement in
reranked_results.get('statements', [])]
except Exception: except Exception:
statements = [] statements = []
@@ -602,7 +601,8 @@ class MemoryAgentService:
) )
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}") logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
else: else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") logger.debug(
f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
except Exception as save_error: except Exception as save_error:
# 保存失败不应该影响主流程,只记录错误 # 保存失败不应该影响主流程,只记录错误
@@ -610,7 +610,8 @@ class MemoryAgentService:
# Log successful operation # Log successful operation
total_time = time.time() - start_time total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") logger.info(
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(
@@ -641,7 +642,6 @@ class MemoryAgentService:
) )
raise ValueError(error_msg) raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
""" """
Get standardized message list from user input. Get standardized message list from user input.
@@ -665,7 +665,8 @@ class MemoryAgentService:
for idx, msg in enumerate(user_input.messages): for idx, msg in enumerate(user_input.messages):
if not isinstance(msg, dict): if not isinstance(msg, dict):
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}") logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}") raise ValueError(
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
if 'role' not in msg: if 'role' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}") logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
@@ -673,7 +674,8 @@ class MemoryAgentService:
if 'content' not in msg: if 'content' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}") logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}") raise ValueError(
f"Message format error: Message must contain 'content' field. Error message index: {idx}")
if msg['role'] not in ['user', 'assistant']: if msg['role'] not in ['user', 'assistant']:
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}") logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
@@ -687,11 +689,11 @@ class MemoryAgentService:
return user_input.messages return user_input.messages
async def classify_message_type( async def classify_message_type(
self, self,
message: str, message: str,
config_id: UUID, config_id: UUID,
db: Session, db: Session,
workspace_id: Optional[UUID] = None workspace_id: Optional[UUID] = None
) -> Dict: ) -> Dict:
""" """
Determine the type of user message (read or write) Determine the type of user message (read or write)
@@ -719,14 +721,15 @@ class MemoryAgentService:
status = await status_typle(message, memory_config.llm_model_id) status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}") logger.debug(f"Message type: {status}")
return status return status
async def generate_summary_from_retrieve( async def generate_summary_from_retrieve(
self, self,
end_user_id: str, end_user_id: str,
retrieve_info: str, retrieve_info: str,
history: List[Dict], history: List[Dict],
query: str, query: str,
config_id: str, config_id: str,
db: Session db: Session
) -> str: ) -> str:
""" """
基于检索信息、历史对话和查询生成最终答案 基于检索信息、历史对话和查询生成最终答案
@@ -805,13 +808,12 @@ class MemoryAgentService:
logger.error(f"生成摘要失败: {str(e)}", exc_info=True) logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
return "信息不足,无法回答。" return "信息不足,无法回答。"
async def get_knowledge_type_stats( async def get_knowledge_type_stats(
self, self,
end_user_id: Optional[str] = None, db: Session,
only_active: bool = True, end_user_id: Optional[str] = None,
current_workspace_id: Optional[uuid.UUID] = None, only_active: bool = True,
db: Session = None current_workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
统计知识库类型分布,包含: 统计知识库类型分布,包含:
@@ -837,11 +839,6 @@ class MemoryAgentService:
# 1. 统计 PostgreSQL 中的知识库类型 # 1. 统计 PostgreSQL 中的知识库类型
try: try:
if db is None:
from app.db import get_db
db_gen = get_db()
db = next(db_gen)
# 初始化所有标准类型为 0 # 初始化所有标准类型为 0
for kb_type in KnowledgeType: for kb_type in KnowledgeType:
result[kb_type.value] = 0 result[kb_type.value] = 0
@@ -881,21 +878,19 @@ class MemoryAgentService:
# 3. 计算知识库类型总和(不包括 memory # 3. 计算知识库类型总和(不包括 memory
result["total"] = ( result["total"] = (
result.get("General", 0) + result.get("General", 0) +
result.get("Web", 0) + result.get("Web", 0) +
result.get("Third-party", 0) + result.get("Third-party", 0) +
result.get("Folder", 0) result.get("Folder", 0)
) )
return result return result
async def get_interest_distribution_by_user( async def get_interest_distribution_by_user(
self, self,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 5, limit: int = 5,
language: str = "zh" language: str = "zh"
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
获取指定用户的兴趣分布标签。 获取指定用户的兴趣分布标签。
@@ -921,13 +916,12 @@ class MemoryAgentService:
logger.error(f"兴趣分布标签查询失败: {e}") logger.error(f"兴趣分布标签查询失败: {e}")
raise Exception(f"兴趣分布标签查询失败: {e}") raise Exception(f"兴趣分布标签查询失败: {e}")
async def get_user_profile( async def get_user_profile(
self, self,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None, current_user_id: Optional[str] = None,
llm_id: Optional[str] = None, llm_id: Optional[str] = None,
db: Session = None db: Session = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
获取用户详情,包含: 获取用户详情,包含:
@@ -1017,7 +1011,8 @@ class MemoryAgentService:
# 定义标签提取的结构 # 定义标签提取的结构
class UserTags(BaseModel): class UserTags(BaseModel):
tags: list[str] = Field(..., description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友") tags: list[str] = Field(...,
description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友")
messages = [ messages = [
{ {
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
ValueError: 当终端用户不存在或应用未发布时 ValueError: 当终端用户不存在或应用未发布时
""" """
import json as json_module import json as json_module
import uuid
from sqlalchemy import select from sqlalchemy import select
@@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
"workspace_id": str(app.workspace_id) "workspace_id": str(app.workspace_id)
} }
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}") logger.info(
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
return result return result

View File

@@ -1,45 +1,42 @@
# 修改 memory_konwledges_server.py 文件 # 修改 memory_konwledges_server.py 文件
import asyncio
import os import os
import re
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from fastapi import HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.config import settings
from app.core.logging_config import get_api_logger
from app.core.rag.models.chunk import DocumentChunk from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success from app.core.response_utils import success
from app.db import get_db from app.db import get_db_context
from app.schemas import file_schema, document_schema
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from app.models.document_model import Document from app.models.document_model import Document
import uuid
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.core.config import settings
from app.models.user_model import User from app.models.user_model import User
from app.schemas import file_schema, document_schema
from app.schemas.file_schema import CustomTextFileCreate from app.schemas.file_schema import CustomTextFileCreate
from app.services import document_service, file_service, knowledge_service from app.services import document_service, file_service, knowledge_service
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.schemas.file_schema import CustomTextFileCreate
from app.db import get_db
# 创建一个简单的用户类用于测试 # 创建一个简单的用户类用于测试
api_logger = get_api_logger() api_logger = get_api_logger()
class ChunkCreate(BaseModel): class ChunkCreate(BaseModel):
content: str content: str
class SimpleUser: class SimpleUser:
def __init__(self, user_id: str): def __init__(self, user_id: str):
# 确保ID是UUID类型 # 确保ID是UUID类型
self.id = user_id self.id = user_id
self.username = user_id self.username = user_id
'''解析'''
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User): async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
""" """
解析指定文档 解析指定文档
@@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}") api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
raise raise
'''获取块ID'''
async def get_document_chunks( async def get_document_chunks(
kb_id: uuid.UUID, kb_id: uuid.UUID,
document_id: uuid.UUID, document_id: uuid.UUID,
@@ -198,7 +195,7 @@ async def get_document_chunks(
return success(data=result, msg="文档块列表查询成功") return success(data=result, msg="文档块列表查询成功")
'''查找文档ID'''
def find_document_id_by_kb_and_filename( def find_document_id_by_kb_and_filename(
db: Session, db: Session,
kb_id: str, kb_id: str,
@@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename(
except Exception as e: except Exception as e:
return None return None
'''获取知识库ID'''
def find_documents_by_kb_id( def find_documents_by_kb_id(
db: Session, db: Session,
kb_id: str, kb_id: str,
@@ -268,18 +265,14 @@ def find_documents_by_kb_id(
except Exception as e: except Exception as e:
return [] return []
''''上传文件'''
async def memory_konwledges_up( async def memory_konwledges_up(
kb_id: str, kb_id: str,
parent_id: str, parent_id: str,
create_data: file_schema.CustomTextFileCreate, create_data: file_schema.CustomTextFileCreate,
db: Session = Depends(get_db), db: Session,
current_user: SimpleUser = None, # 修改为SimpleUser current_user: SimpleUser,
): ):
# 如果没有提供current_user则创建一个默认的
if current_user is None:
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
content_bytes = create_data.content.encode('utf-8') content_bytes = create_data.content.encode('utf-8')
file_size = len(content_bytes) file_size = len(content_bytes)
print(f"file size: {file_size} byte") print(f"file size: {file_size} byte")
@@ -350,8 +343,6 @@ async def memory_konwledges_up(
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
'''添加新块'''
async def create_document_chunk( async def create_document_chunk(
kb_id: uuid.UUID, kb_id: uuid.UUID,
@@ -450,6 +441,7 @@ async def create_document_chunk(
return success(data=chunk, msg="文档块创建成功") return success(data=chunk, msg="文档块创建成功")
async def write_rag(end_user_id, message, user_rag_memory_id): async def write_rag(end_user_id, message, user_rag_memory_id):
""" """
将消息写入 RAG 知识库 将消息写入 RAG 知识库
@@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
detail=f"知识库ID格式无效: {user_rag_memory_id}" detail=f"知识库ID格式无效: {user_rag_memory_id}"
) )
db_gen = get_db() with get_db_context() as db:
db = next(db_gen)
try:
create_data = CustomTextFileCreate(title=end_user_id, content=message) create_data = CustomTextFileCreate(title=end_user_id, content=message)
current_user = SimpleUser(user_rag_memory_id) current_user = SimpleUser(user_rag_memory_id)
# 检查文档是否已存在 # 检查文档是否已存在
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
print('======',document) print('======', document)
api_logger.info(f"查找文档结果: document_id={document}") api_logger.info(f"查找文档结果: document_id={document}")
if document is not None: if document is not None:
# 文档已存在,直接添加新块 # 文档已存在,直接添加新块
@@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
else: else:
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}") api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
return result return result
finally:
# 确保数据库会话被关闭
db.close()

View File

@@ -115,6 +115,17 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# --- Create --- # --- Create ---
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
# 业务层检查同一工作空间下是否已存在同名配置
if params.workspace_id and params.config_name:
from app.models.memory_config_model import MemoryConfig
existing = (
self.db.query(MemoryConfig)
.filter_by(workspace_id=params.workspace_id, config_name=params.config_name)
.first()
)
if existing:
raise ValueError(f"DUPLICATE_CONFIG_NAME:{params.config_name}")
# 如果workspace_id存在且模型字段未全部指定则自动获取 # 如果workspace_id存在且模型字段未全部指定则自动获取
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]): if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
configs = self._get_workspace_configs(params.workspace_id) configs = self._get_workspace_configs(params.workspace_id)
@@ -211,6 +222,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"apply_id": config.apply_id, "apply_id": config.apply_id,
"scene_id": str(config.scene_id) if config.scene_id else None, "scene_id": str(config.scene_id) if config.scene_id else None,
"scene_name": scene_name, # 新增:场景名称 "scene_name": scene_name, # 新增:场景名称
"is_system_default": config.is_default, # 是否为系统默认配置
"llm_id": config.llm_id, "llm_id": config.llm_id,
"embedding_id": config.embedding_id, "embedding_id": config.embedding_id,
"rerank_id": config.rerank_id, "rerank_id": config.rerank_id,

View File

@@ -116,27 +116,15 @@ class ModelConfigService:
try: try:
start_time = time.time() start_time = time.time()
# dashscope 的 omni 模型需要使用 compatible-mode model_config = RedBearModelConfig(
if provider.lower() == ModelProvider.DASHSCOPE and is_omni: model_name=model_name,
if not api_base: provider=provider,
api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1" api_key=api_key,
model_config = RedBearModelConfig( base_url=api_base,
model_name=model_name, is_omni=is_omni,
provider=ModelProvider.OPENAI, temperature=0.7,
api_key=api_key, max_tokens=100
base_url=api_base, )
temperature=0.7,
max_tokens=100
)
else:
model_config = RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
temperature=0.7,
max_tokens=100
)
# 根据模型类型选择不同的验证方式 # 根据模型类型选择不同的验证方式
model_type_lower = model_type.lower() model_type_lower = model_type.lower()
@@ -493,6 +481,9 @@ class ModelApiKeyService:
if not model_config: if not model_config:
continue continue
data.is_omni = model_config.is_omni
data.capability = model_config.capability
# 从ModelBase获取model_name # 从ModelBase获取model_name
model_name = model_config.model_base.name if model_config.model_base else model_config.name model_name = model_config.model_base.name if model_config.model_base else model_config.name
@@ -550,8 +541,8 @@ class ModelApiKeyService:
provider=data.provider, provider=data.provider,
api_key=data.api_key, api_key=data.api_key,
api_base=data.api_base, api_base=data.api_base,
capability=data.capability if data.capability is not None else model_config.capability, capability=data.capability,
is_omni=data.is_omni if data.is_omni is not None else model_config.is_omni, is_omni=data.is_omni,
config=data.config, config=data.config,
is_active=data.is_active, is_active=data.is_active,
priority=data.priority priority=data.priority
@@ -574,6 +565,10 @@ class ModelApiKeyService:
model_config = ModelConfigRepository.get_by_id(db, model_config_id) model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config: if not model_config:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if api_key_data.is_omni is None:
api_key_data.is_omni = model_config.is_omni
if api_key_data.capability is None:
api_key_data.capability = model_config.capability
# 检查API Key是否已存在(包括软删除)需要考虑tenant_id # 检查API Key是否已存在(包括软删除)需要考虑tenant_id
existing_key = db.query(ModelApiKey).join( existing_key = db.query(ModelApiKey).join(
@@ -616,7 +611,7 @@ class ModelApiKeyService:
api_base=api_key_data.api_base, api_base=api_key_data.api_base,
model_type=model_config.type, model_type=model_config.type,
test_message="Hello", test_message="Hello",
is_omni=model_config.is_omni is_omni=api_key_data.is_omni
) )
if not validation_result["valid"]: if not validation_result["valid"]:
raise BusinessException( raise BusinessException(

View File

@@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.cypher_queries import Graph_Node_query from app.repositories.neo4j.cypher_queries import Graph_Node_query
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService from app.services.memory_base_service import MemoryBaseService
from app.services.memory_base_service import MemoryBaseService, MemoryTransService
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService from app.services.memory_short_service import ShortService
@@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
from app.core.language_utils import validate_language from app.core.language_utils import validate_language
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
# 验证语言参数 # 验证语言参数
@@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
if end_user_id: if end_user_id:
try: try:
# 获取数据库会话并查询用户信息 # 获取数据库会话并查询用户信息
db = next(get_db()) with get_db_context() as db:
try:
repo = EndUserRepository(db) repo = EndUserRepository(db)
end_user = repo.get_by_id(uuid.UUID(end_user_id)) end_user = repo.get_by_id(uuid.UUID(end_user_id))
if end_user and end_user.other_name: if end_user and end_user.other_name:
@@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}") logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
else: else:
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}") logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
finally:
db.close()
except Exception as e: except Exception as e:
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}") logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")

View File

@@ -107,6 +107,7 @@ def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
for workspace in workspaces: for workspace in workspaces:
if workspace.storage_type == 'neo4j': if workspace.storage_type == 'neo4j':
_ensure_default_memory_config(db, workspace) _ensure_default_memory_config(db, workspace)
_ensure_default_ontology_scenes(db, workspace)
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}") business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
return workspaces return workspaces
@@ -1104,6 +1105,52 @@ def _fill_workspace_configs_model_defaults(
) )
def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
"""Ensure a workspace has default ontology scenes, creating them if missing.
Checks whether any is_system_default scene exists for the workspace.
If not, runs the DefaultOntologyInitializer to create them.
Args:
db: Database session
workspace: The workspace to check
"""
from app.models.ontology_scene import OntologyScene
# 幂等检查:是否已存在系统默认场景
existing = db.query(OntologyScene).filter(
OntologyScene.workspace_id == workspace.id,
OntologyScene.is_system_default.is_(True)
).first()
if existing:
return
business_logger.info(
f"Workspace {workspace.id} missing default ontology scenes, creating them"
)
try:
initializer = DefaultOntologyInitializer(db)
success, error_msg = initializer.initialize_default_scenes(
workspace.id, language="zh"
)
if success:
db.commit()
business_logger.info(
f"为工作空间 {workspace.id} 补建默认本体场景成功"
)
else:
business_logger.warning(
f"为工作空间 {workspace.id} 补建默认本体场景失败: {error_msg}"
)
except Exception as e:
db.rollback()
business_logger.error(
f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}"
)
def _create_default_memory_config( def _create_default_memory_config(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,

View File

@@ -29,10 +29,10 @@ REDIS_DB=
REDIS_PASSWORD=password REDIS_PASSWORD=password
#celery #celery
BROKER_URL= # NOTE: 不要使用 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND
RESULT_BACKEND= # 这些名称会被 Celery CLI 劫持,详见 docs/celery-env-bug-report.md
CELERY_BROKER= REDIS_DB_CELERY_BROKER=
CELERY_BACKEND= REDIS_DB_CELERY_BACKEND=
# Memory Cache Regeneration Configuration # Memory Cache Regeneration Configuration
# Interval in hours for regenerating memory insight and user summary cache # Interval in hours for regenerating memory insight and user summary cache

View File

@@ -440,7 +440,6 @@ export const en = {
logoutApiCannotRefreshToken: 'Logout API cannot refresh token', logoutApiCannotRefreshToken: 'Logout API cannot refresh token',
publicApiCannotRefreshToken: 'Public API cannot refresh token', publicApiCannotRefreshToken: 'Public API cannot refresh token',
refreshTokenNotExist: 'Refresh token does not exist', refreshTokenNotExist: 'Refresh token does not exist',
SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: 'This is a system preset scene and cannot be deleted',
reset: 'Reset', reset: 'Reset',
refresh: 'Refresh', refresh: 'Refresh',
return: 'Return', return: 'Return',
@@ -454,6 +453,7 @@ export const en = {
prevStep: 'Previous Step', prevStep: 'Previous Step',
exportSuccess: 'Export successful', exportSuccess: 'Export successful',
recommend: 'Recommend', recommend: 'Recommend',
default: 'Default',
logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`, logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`,
imageSquareRequired: 'Please upload a square image', imageSquareRequired: 'Please upload a square image',
nameInvalid: 'Name cannot start or end with a space', nameInvalid: 'Name cannot start or end with a space',
@@ -2616,6 +2616,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
updated_at: 'Updated At', updated_at: 'Updated At',
entityTypes: 'Entity Types', entityTypes: 'Entity Types',
classSearchPlaceholder: 'Search types',
addClass: 'Add Type', addClass: 'Add Type',
class_name: 'Type Name', class_name: 'Type Name',
class_description: 'Type Definition', class_description: 'Type Definition',

View File

@@ -1020,7 +1020,6 @@ export const zh = {
logoutApiCannotRefreshToken: '退出登录接口不能刷新token', logoutApiCannotRefreshToken: '退出登录接口不能刷新token',
publicApiCannotRefreshToken: '公共接口不能刷新token', publicApiCannotRefreshToken: '公共接口不能刷新token',
refreshTokenNotExist: '刷新token不存在', refreshTokenNotExist: '刷新token不存在',
SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: '该场景为系统预设场景,不允许删除',
reset: '重置', reset: '重置',
refresh: '刷新', refresh: '刷新',
return: '返回', return: '返回',
@@ -1034,6 +1033,7 @@ export const zh = {
prevStep: '上一步', prevStep: '上一步',
exportSuccess: '导出成功', exportSuccess: '导出成功',
recommend: '推荐', recommend: '推荐',
default: '默认',
logoTip: `支持图片格式JPG、PNG\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`, logoTip: `支持图片格式JPG、PNG\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`,
imageSquareRequired: '请上传正方形比例图片', imageSquareRequired: '请上传正方形比例图片',
nameInvalid: '不能是空格开头或结尾', nameInvalid: '不能是空格开头或结尾',
@@ -2617,6 +2617,7 @@ export const zh = {
updated_at: '更新时间', updated_at: '更新时间',
entityTypes: '实体类型', entityTypes: '实体类型',
classSearchPlaceholder: '搜索类型',
addClass: '添加类型', addClass: '添加类型',
class_name: '类型名称', class_name: '类型名称',
class_description: '类型定义', class_description: '类型定义',

View File

@@ -1,8 +1,8 @@
/* /*
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-02 16:35:15 * @Date: 2026-02-02 16:35:15
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-02 16:35:15 * @Last Modified time: 2026-03-06 10:39:00
*/ */
/** /**
* HTTP Request Utility Module * HTTP Request Utility Module
@@ -183,7 +183,7 @@ service.interceptors.response.use(
msg = msg || i18n.t('common.serverError'); msg = msg || i18n.t('common.serverError');
break; break;
default: default:
if (msg === 'SYSTEM_DEFAULT_SCENE_CANNOT_DELETE') { if (['SYSTEM_DEFAULT_SCENE_CANNOT_DELETE', 'SYSTEM_DEFAULT_CLASS_CANNOT_DELETE', 'SYSTEM_DEFAULT_SCENE_CANNOT_UPDATE'].includes(msg)) {
msg = i18n.t(`common.${msg}`) msg = i18n.t(`common.${msg}`)
} else if (!msg && Array.isArray(error.response?.data?.detail)) { } else if (!msg && Array.isArray(error.response?.data?.detail)) {
msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';') msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';')

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 16:27:39 * @Date: 2026-02-03 16:27:39
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-04 18:51:20 * @Last Modified time: 2026-03-05 17:03:46
*/ */
/** /**
* Chat debugging component for application testing * Chat debugging component for application testing
@@ -171,6 +171,29 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
.then(() => { .then(() => {
const message = msg const message = msg
if (!message?.trim()) return if (!message?.trim()) return
// Validate required variables before sending
let isCanSend = true
const params: Record<string, any> = {}
if (chatVariables && chatVariables.length > 0) {
const needRequired: string[] = []
chatVariables.forEach(vo => {
params[vo.name] = vo.value
if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
isCanSend = false
needRequired.push(vo.name)
}
})
if (needRequired.length) {
messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
}
}
if (!isCanSend) {
setLoading(false)
setCompareLoading(false)
return
}
addUserMessage(message, fileList) addUserMessage(message, fileList)
setMessage(message) setMessage(message)
@@ -198,29 +221,6 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
}; };
setTimeout(() => { setTimeout(() => {
// Validate required variables before sending
let isCanSend = true
const params: Record<string, any> = {}
if (chatVariables && chatVariables.length > 0) {
const needRequired: string[] = []
chatVariables.forEach(vo => {
params[vo.name] = vo.value
if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
isCanSend = false
needRequired.push(vo.name)
}
})
if (needRequired.length) {
messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
}
}
if (!isCanSend) {
setLoading(false)
setCompareLoading(false)
return
}
runCompare(data.app_id, { runCompare(data.app_id, {
message, message,
files: fileList.map(file => { files: fileList.map(file => {

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-06 21:09:42 * @Date: 2026-02-06 21:09:42
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-04 18:54:47 * @Last Modified time: 2026-03-05 15:09:22
*/ */
/** /**
* File Upload Component * File Upload Component
@@ -206,7 +206,7 @@ const UploadFiles = forwardRef<UploadFilesRef, UploadFilesProps>(({
*/ */
const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => { const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => {
newFileList.map(file => { newFileList.map(file => {
const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document'
file.type = type file.type = type
}) })
setFileList(newFileList); setFileList(newFileList);

View File

@@ -15,6 +15,7 @@ import {
} from '@/api/knowledgeBase' } from '@/api/knowledgeBase'
import RbModal from '@/components/RbModal' import RbModal from '@/components/RbModal'
import SliderInput from '@/components/SliderInput' import SliderInput from '@/components/SliderInput'
import { stringRegExp } from '@/utils/validator'
const { TextArea } = Input; const { TextArea } = Input;
const { confirm } = Modal const { confirm } = Modal
@@ -519,12 +520,16 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
<Form.Item <Form.Item
name="name" name="name"
label={t('knowledgeBase.createForm.name')} label={t('knowledgeBase.createForm.name')}
rules={[{ required: true, message: t('knowledgeBase.createForm.nameRequired') }]} rules={[
{ required: true, message: t('knowledgeBase.createForm.nameRequired') },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
> >
<Input placeholder={t('knowledgeBase.createForm.name')} /> <Input placeholder={t('knowledgeBase.createForm.name')} />
</Form.Item> </Form.Item>
)} )}
<Form.Item name="description" label={t('knowledgeBase.createForm.description')}> <Form.Item name="description" label={t('knowledgeBase.createForm.description')} rules={[{ max: 500 }]}>
<TextArea rows={2} placeholder={t('knowledgeBase.createForm.description')} /> <TextArea rows={2} placeholder={t('knowledgeBase.createForm.description')} />
</Form.Item> </Form.Item>

View File

@@ -1,8 +1,8 @@
/* /*
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 17:33:15 * @Date: 2026-02-03 17:33:15
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 17:33:15 * @Last Modified time: 2026-03-05 16:28:58
*/ */
/** /**
* Memory Management Page * Memory Management Page
@@ -110,9 +110,15 @@ const MemoryManagement: React.FC = () => {
<List.Item key={item.config_id}> <List.Item key={item.config_id}>
<RbCard <RbCard
title={item.config_name} title={item.config_name}
className="rb:relative"
> >
{item.is_system_default &&
<div className="rb:absolute rb:-right-px rb:-top-px rb:bg-[#FF5D34] rb:rounded-[0px_7px_0px_8px] rb:text-[12px] rb:text-white rb:font-regular rb:leading-4 rb:py-0.5 rb:px-1">
{t('common.default')}
</div>
}
<Tooltip title={item.config_desc}> <Tooltip title={item.config_desc}>
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-[17px]">{item.config_desc}</div> <div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-4.25">{item.config_desc}</div>
</Tooltip> </Tooltip>
<RbAlert className="rb:mt-3 "> <RbAlert className="rb:mt-3 ">
<div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}> <div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}>

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 17:33:01 * @Date: 2026-02-03 17:33:01
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 17:33:24 * @Last Modified time: 2026-03-05 16:33:53
*/ */
/** /**
* Memory management form data type * Memory management form data type
@@ -42,6 +42,7 @@ export interface Memory {
workspace_id: string; workspace_id: string;
scene_id: string; scene_id: string;
scene_name: string; scene_name: string;
is_system_default: boolean;
[key: string]: string | number | boolean; [key: string]: string | number | boolean;
} }
/** /**

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 14:10:24 * @Date: 2026-02-03 14:10:24
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-09 18:02:13 * @Last Modified time: 2026-03-06 11:25:59
*/ */
import { type FC, type ReactNode } from 'react'; import { type FC, type ReactNode } from 'react';
import { useNavigate } from 'react-router-dom'; import { useNavigate } from 'react-router-dom';
@@ -17,7 +17,7 @@ const { Header } = Layout;
*/ */
interface ConfigHeaderProps { interface ConfigHeaderProps {
/** Page title/name */ /** Page title/name */
name?: string; name?: string | ReactNode;
/** Subtitle content displayed below the title */ /** Subtitle content displayed below the title */
subTitle?: ReactNode | string; subTitle?: ReactNode | string;
/** Extra content displayed on the right side */ /** Extra content displayed on the right side */

View File

@@ -1,8 +1,8 @@
/* /*
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 14:10:15 * @Date: 2026-02-03 14:10:15
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-05 10:57:53 * @Last Modified time: 2026-03-06 10:56:44
*/ */
import { type FC, useState, useRef, type MouseEvent } from 'react'; import { type FC, useState, useRef, type MouseEvent } from 'react';
import { useNavigate } from 'react-router-dom'; import { useNavigate } from 'react-router-dom';
@@ -144,8 +144,13 @@ const Ontology: FC = () => {
title={item.scene_name} title={item.scene_name}
extra={<Tag>{item.type_num} {t('ontology.typeCount')}</Tag>} extra={<Tag>{item.type_num} {t('ontology.typeCount')}</Tag>}
onClick={() => handleJump(item)} onClick={() => handleJump(item)}
className="rb:cursor-pointer" className="rb:cursor-pointer rb:relative"
> >
{item.is_system_default &&
<div className="rb:absolute rb:-right-px rb:-top-px rb:bg-[#FF5D34] rb:rounded-[0px_7px_0px_8px] rb:text-[12px] rb:text-white rb:font-regular rb:leading-4 rb:py-0.5 rb:px-1">
{t('common.default')}
</div>
}
<div <div
className="rb:flex rb:gap-2 rb:justify-between rb:text-[#5B6167] rb:text-[14px] rb:leading-5 rb:mb-3" className="rb:flex rb:gap-2 rb:justify-between rb:text-[#5B6167] rb:text-[14px] rb:leading-5 rb:mb-3"
> >
@@ -176,8 +181,8 @@ const Ontology: FC = () => {
)} )}
</Flex> </Flex>
<div className="rb:mt-4 rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167] rb:flex rb:items-center rb:justify-end"> <div className="rb:mt-4 rb:h-5 rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167] rb:flex rb:items-center rb:justify-end">
<Space size={16}> {!item.is_system_default && <Space size={16}>
<div <div
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]" className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]"
onClick={(e) => handleEdit(item, e)} onClick={(e) => handleEdit(item, e)}
@@ -186,7 +191,7 @@ const Ontology: FC = () => {
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]" className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
onClick={(e) => handleDelete(item, e)} onClick={(e) => handleDelete(item, e)}
></div> ></div>
</Space> </Space>}
</div> </div>
</RbCard> </RbCard>
)} )}

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 14:10:20 * @Date: 2026-02-03 14:10:20
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-09 17:56:35 * @Last Modified time: 2026-03-06 11:26:49
*/ */
import { type FC, useEffect, useState, useRef } from 'react' import { type FC, useEffect, useState, useRef } from 'react'
import { useParams } from 'react-router-dom'; import { useParams } from 'react-router-dom';
@@ -17,6 +17,7 @@ import OntologyClassModal from '../components/OntologyClassModal'
import SearchInput from '@/components/SearchInput'; import SearchInput from '@/components/SearchInput';
import OntologyClassExtractModal from '../components/OntologyClassExtractModal' import OntologyClassExtractModal from '../components/OntologyClassExtractModal'
import BodyWrapper from '@/components/Empty/BodyWrapper' import BodyWrapper from '@/components/Empty/BodyWrapper'
import Tag from '@/components/Tag'
/** /**
* Ontology detail page component * Ontology detail page component
@@ -99,19 +100,22 @@ const Detail: FC = () => {
return ( return (
<> <>
<PageHeader <PageHeader
name={data.scene_name} name={<Space>
{data.scene_name}
<Tag color="warning">{t('common.default')}</Tag>
</Space>}
subTitle={<Tooltip title={data.scene_description}><div className="rb:h-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{data.scene_description}</div></Tooltip>} subTitle={<Tooltip title={data.scene_description}><div className="rb:h-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{data.scene_description}</div></Tooltip>}
extra={<Space> extra={data.is_system_default ? undefined : (<Space>
<Button type="primary" ghost className="rb:h-6! rb:px-2! rb:leading-5.5!" onClick={handleAdd}>+ {t('ontology.addClass')}</Button> <Button type="primary" ghost className="rb:h-6! rb:px-2! rb:leading-5.5!" onClick={handleAdd}>+ {t('ontology.addClass')}</Button>
<Button className="rb:h-6! rb:px-2! rb:leading-5.5!" type="primary" onClick={handleExtract}>+ {t('ontology.extract')}</Button> <Button className="rb:h-6! rb:px-2! rb:leading-5.5!" type="primary" onClick={handleExtract}>+ {t('ontology.extract')}</Button>
</Space>} </Space>)}
/> />
<div className="rb:h-[calc(100vh-64px)] rb:overflow-y-auto rb:py-3 rb:px-4"> <div className="rb:h-[calc(100vh-64px)] rb:overflow-y-auto rb:py-3 rb:px-4">
<Row gutter={16} className="rb:mb-4"> <Row gutter={16} className="rb:mb-4">
<Col span={6} offset={18}> <Col span={6} offset={18}>
<SearchInput <SearchInput
placeholder={t('ontology.searchPlaceholder')} placeholder={t('ontology.classSearchPlaceholder')}
onSearch={(value) => setQuery({ class_name: value })} onSearch={(value) => setQuery({ class_name: value })}
className="rb:w-full!" className="rb:w-full!"
/> />
@@ -123,10 +127,10 @@ const Detail: FC = () => {
<Col key={item.class_id} span={6}> <Col key={item.class_id} span={6}>
<RbCard <RbCard
title={item.class_name} title={item.class_name}
extra={<div extra={data.is_system_default ? undefined : (<div
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]" className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
onClick={() => handleDelete(item)} onClick={() => handleDelete(item)}
></div>} ></div>)}
className="rb:bg-transparent!" className="rb:bg-transparent!"
> >
<Tooltip title={item.class_description}> <Tooltip title={item.class_description}>

View File

@@ -1,8 +1,8 @@
/* /*
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 14:10:10 * @Date: 2026-02-03 14:10:10
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 14:10:10 * @Last Modified time: 2026-03-06 10:55:23
*/ */
/** /**
* Query parameters for ontology list pagination and filtering * Query parameters for ontology list pagination and filtering
@@ -38,6 +38,8 @@ export interface OntologyItem {
updated_at: number; updated_at: number;
/** Total count of classes in the scene */ /** Total count of classes in the scene */
classes_count: number; classes_count: number;
/** Whether this is the system default configuration */
is_system_default: boolean;
} }
/** /**
@@ -92,6 +94,7 @@ export interface OntologyClassData {
scene_description: string; scene_description: string;
/** Array of class items */ /** Array of class items */
items: OntologyClassItem[]; items: OntologyClassItem[];
is_system_default: boolean;
} }
/** /**