Merge pull request #486 from SuanmoSuanyangTechnology/release/v0.2.6
Release/v0.2.6
This commit is contained in:
@@ -226,8 +226,8 @@ REDIS_PORT=6379
|
|||||||
REDIS_DB=1
|
REDIS_DB=1
|
||||||
|
|
||||||
# Celery (Using Redis as broker)
|
# Celery (Using Redis as broker)
|
||||||
BROKER_URL=redis://127.0.0.1:6379/0
|
REDIS_DB_CELERY_BROKER=1
|
||||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
REDIS_DB_CELERY_BACKEND=2
|
||||||
|
|
||||||
# JWT Secret Key (Formation method: openssl rand -hex 32)
|
# JWT Secret Key (Formation method: openssl rand -hex 32)
|
||||||
SECRET_KEY=your-secret-key-here
|
SECRET_KEY=your-secret-key-here
|
||||||
|
|||||||
@@ -201,8 +201,8 @@ REDIS_PORT=6379
|
|||||||
REDIS_DB=1
|
REDIS_DB=1
|
||||||
|
|
||||||
# Celery (使用Redis作为broker)
|
# Celery (使用Redis作为broker)
|
||||||
BROKER_URL=redis://127.0.0.1:6379/0
|
REDIS_DB_CELERY_BROKER=1
|
||||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
REDIS_DB_CELERY_BACKEND=2
|
||||||
|
|
||||||
# JWT密钥 (生成方式: openssl rand -hex 32)
|
# JWT密钥 (生成方式: openssl rand -hex 32)
|
||||||
SECRET_KEY=your-secret-key-here
|
SECRET_KEY=your-secret-key-here
|
||||||
|
|||||||
1
api/app/cache/__init__.py
vendored
1
api/app/cache/__init__.py
vendored
@@ -2,7 +2,6 @@
|
|||||||
Cache 缓存模块
|
Cache 缓存模块
|
||||||
|
|
||||||
提供各种缓存功能的统一入口
|
提供各种缓存功能的统一入口
|
||||||
注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存
|
|
||||||
"""
|
"""
|
||||||
from .memory import InterestMemoryCache
|
from .memory import InterestMemoryCache
|
||||||
|
|
||||||
|
|||||||
1
api/app/cache/memory/__init__.py
vendored
1
api/app/cache/memory/__init__.py
vendored
@@ -2,7 +2,6 @@
|
|||||||
Memory 缓存模块
|
Memory 缓存模块
|
||||||
|
|
||||||
提供记忆系统相关的缓存功能
|
提供记忆系统相关的缓存功能
|
||||||
注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存
|
|
||||||
"""
|
"""
|
||||||
from .interest_memory import InterestMemoryCache
|
from .interest_memory import InterestMemoryCache
|
||||||
|
|
||||||
|
|||||||
@@ -1,27 +1,54 @@
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from celery.schedules import crontab
|
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from celery import Celery
|
from celery import Celery
|
||||||
from celery.schedules import crontab
|
from celery.schedules import crontab
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
# macOS fork() safety - must be set before any Celery initialization
|
# macOS fork() safety - must be set before any Celery initialization
|
||||||
if platform.system() == 'Darwin':
|
if platform.system() == 'Darwin':
|
||||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||||
|
|
||||||
# 创建 Celery 应用实例
|
# 创建 Celery 应用实例
|
||||||
# broker: 任务队列(使用 Redis DB 0)
|
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定)
|
||||||
# backend: 结果存储(使用 Redis DB 10)
|
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定)
|
||||||
|
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||||
|
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||||
|
|
||||||
|
# Build canonical broker/backend URLs and force them into os.environ so that
|
||||||
|
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
|
||||||
|
# cannot be overridden by stray env vars.
|
||||||
|
# See: https://github.com/celery/celery/issues/4284
|
||||||
|
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||||
|
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||||
|
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||||
|
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||||
|
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
|
||||||
|
# integration and accidentally override our canonical URLs.
|
||||||
|
os.environ.pop("BROKER_URL", None)
|
||||||
|
os.environ.pop("RESULT_BACKEND", None)
|
||||||
|
os.environ.pop("CELERY_BROKER", None)
|
||||||
|
os.environ.pop("CELERY_BACKEND", None)
|
||||||
|
|
||||||
celery_app = Celery(
|
celery_app = Celery(
|
||||||
"redbear_tasks",
|
"redbear_tasks",
|
||||||
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
|
broker=_broker_url,
|
||||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
backend=_backend_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Celery app initialized",
|
||||||
|
extra={
|
||||||
|
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||||
|
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||||
|
},
|
||||||
|
)
|
||||||
# Default queue for unrouted tasks
|
# Default queue for unrouted tasks
|
||||||
celery_app.conf.task_default_queue = 'memory_tasks'
|
celery_app.conf.task_default_queue = 'memory_tasks'
|
||||||
|
|
||||||
|
|||||||
@@ -1,28 +1,29 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.rag.llm.cv_model import QWenCV
|
from app.core.rag.llm.cv_model import QWenCV
|
||||||
from app.core.response_utils import fail, success
|
from app.core.response_utils import fail, success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||||
from app.models import ModelApiKey
|
from app.models import ModelApiKey
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.repositories import knowledge_repository
|
||||||
from app.core.memory.agent.utils.redis_tool import store
|
|
||||||
from app.repositories import knowledge_repository, WorkspaceRepository
|
|
||||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import task_service, workspace_service
|
from app.services import task_service, workspace_service
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
from dotenv import load_dotenv
|
|
||||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from starlette.responses import StreamingResponse
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -37,7 +38,7 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/health/status", response_model=ApiResponse)
|
@router.get("/health/status", response_model=ApiResponse)
|
||||||
async def get_health_status(
|
async def get_health_status(
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get latest health status written by Celery periodic task
|
Get latest health status written by Celery periodic task
|
||||||
@@ -55,8 +56,9 @@ async def get_health_status(
|
|||||||
|
|
||||||
@router.get("/download_log")
|
@router.get("/download_log")
|
||||||
async def download_log(
|
async def download_log(
|
||||||
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
|
log_type: str = Query("file", regex="^(file|transmission)$",
|
||||||
current_user: User = Depends(get_current_user)
|
description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Download or stream agent service log file
|
Download or stream agent service log file
|
||||||
@@ -119,10 +121,10 @@ async def download_log(
|
|||||||
@router.post("/writer_service", response_model=ApiResponse)
|
@router.post("/writer_service", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def write_server(
|
async def write_server(
|
||||||
user_input: Write_UserInput,
|
user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Write service endpoint - processes write operations synchronously
|
Write service endpoint - processes write operations synchronously
|
||||||
@@ -161,13 +163,15 @@ async def write_server(
|
|||||||
if knowledge:
|
if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
user_rag_memory_id = str(knowledge.id)
|
||||||
else:
|
else:
|
||||||
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
api_logger.warning(
|
||||||
|
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
else:
|
else:
|
||||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
|
|
||||||
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
api_logger.info(
|
||||||
|
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||||
try:
|
try:
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
result = await memory_agent_service.write_memory(
|
result = await memory_agent_service.write_memory(
|
||||||
@@ -195,10 +199,10 @@ async def write_server(
|
|||||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
@router.post("/writer_service_async", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def write_server_async(
|
async def write_server_async(
|
||||||
user_input: Write_UserInput,
|
user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Async write service endpoint - enqueues write processing to Celery
|
Async write service endpoint - enqueues write processing to Celery
|
||||||
@@ -216,7 +220,8 @@ async def write_server_async(
|
|||||||
|
|
||||||
config_id = user_input.config_id
|
config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
api_logger.info(
|
||||||
|
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||||
|
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
storage_type = workspace_service.get_workspace_storage_type(
|
||||||
@@ -254,9 +259,9 @@ async def write_server_async(
|
|||||||
@router.post("/read_service", response_model=ApiResponse)
|
@router.post("/read_service", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def read_server(
|
async def read_server(
|
||||||
user_input: UserInput,
|
user_input: UserInput,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Read service endpoint - processes read operations synchronously
|
Read service endpoint - processes read operations synchronously
|
||||||
@@ -292,7 +297,8 @@ async def read_server(
|
|||||||
if knowledge:
|
if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
user_rag_memory_id = str(knowledge.id)
|
||||||
|
|
||||||
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
api_logger.info(
|
||||||
|
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||||
try:
|
try:
|
||||||
result = await memory_agent_service.read_memory(
|
result = await memory_agent_service.read_memory(
|
||||||
user_input.end_user_id,
|
user_input.end_user_id,
|
||||||
@@ -306,7 +312,8 @@ async def read_server(
|
|||||||
)
|
)
|
||||||
if str(user_input.search_switch) == "2":
|
if str(user_input.search_switch) == "2":
|
||||||
retrieve_info = result['answer']
|
retrieve_info = result['answer']
|
||||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
|
||||||
|
user_input.end_user_id)
|
||||||
query = user_input.message
|
query = user_input.message
|
||||||
|
|
||||||
# 调用 memory_agent_service 的方法生成最终答案
|
# 调用 memory_agent_service 的方法生成最终答案
|
||||||
@@ -319,7 +326,7 @@ async def read_server(
|
|||||||
db=db
|
db=db
|
||||||
)
|
)
|
||||||
if "信息不足,无法回答" in result['answer']:
|
if "信息不足,无法回答" in result['answer']:
|
||||||
result['answer']=retrieve_info
|
result['answer'] = retrieve_info
|
||||||
return success(data=result, msg="回复对话消息成功")
|
return success(data=result, msg="回复对话消息成功")
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
@@ -335,9 +342,10 @@ async def read_server(
|
|||||||
@router.post("/file", response_model=ApiResponse)
|
@router.post("/file", response_model=ApiResponse)
|
||||||
async def file_update(
|
async def file_update(
|
||||||
files: List[UploadFile] = File(..., description="要上传的文件"),
|
files: List[UploadFile] = File(..., description="要上传的文件"),
|
||||||
model_id:str = Form(..., description="模型ID"),
|
model_id: str = Form(..., description="模型ID"),
|
||||||
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
文件上传接口 - 支持图片识别
|
文件上传接口 - 支持图片识别
|
||||||
@@ -350,9 +358,6 @@ async def file_update(
|
|||||||
Returns:
|
Returns:
|
||||||
文件处理结果
|
文件处理结果
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db_gen = get_db() # get_db 通常是一个生成器
|
|
||||||
db = next(db_gen)
|
|
||||||
api_logger.info(f"File upload requested, file count: {len(files)}")
|
api_logger.info(f"File upload requested, file count: {len(files)}")
|
||||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||||
apiConfig: ModelApiKey = config.api_keys[0]
|
apiConfig: ModelApiKey = config.api_keys[0]
|
||||||
@@ -430,8 +435,8 @@ async def read_server_async(
|
|||||||
|
|
||||||
@router.get("/read_result/", response_model=ApiResponse)
|
@router.get("/read_result/", response_model=ApiResponse)
|
||||||
async def get_read_task_result(
|
async def get_read_task_result(
|
||||||
task_id: str,
|
task_id: str,
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get the status and result of an async read task
|
Get the status and result of an async read task
|
||||||
@@ -507,8 +512,8 @@ async def get_read_task_result(
|
|||||||
|
|
||||||
@router.get("/write_result/", response_model=ApiResponse)
|
@router.get("/write_result/", response_model=ApiResponse)
|
||||||
async def get_write_task_result(
|
async def get_write_task_result(
|
||||||
task_id: str,
|
task_id: str,
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get the status and result of an async write task
|
Get the status and result of an async write task
|
||||||
@@ -584,9 +589,9 @@ async def get_write_task_result(
|
|||||||
|
|
||||||
@router.post("/status_type", response_model=ApiResponse)
|
@router.post("/status_type", response_model=ApiResponse)
|
||||||
async def status_type(
|
async def status_type(
|
||||||
user_input: Write_UserInput,
|
user_input: Write_UserInput,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Determine the type of user message (read or write)
|
Determine the type of user message (read or write)
|
||||||
@@ -629,9 +634,10 @@ async def status_type(
|
|||||||
|
|
||||||
@router.get("/stats/types", response_model=ApiResponse)
|
@router.get("/stats/types", response_model=ApiResponse)
|
||||||
async def get_knowledge_type_stats_api(
|
async def get_knowledge_type_stats_api(
|
||||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
||||||
@@ -640,14 +646,9 @@ async def get_knowledge_type_stats_api(
|
|||||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
- 如果用户没有当前工作空间,对应的统计返回 0
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
api_logger.info(
|
||||||
|
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
# 获取数据库会话
|
|
||||||
db_gen = get_db()
|
|
||||||
db = next(db_gen)
|
|
||||||
|
|
||||||
# 调用service层函数
|
# 调用service层函数
|
||||||
result = await memory_agent_service.get_knowledge_type_stats(
|
result = await memory_agent_service.get_knowledge_type_stats(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -664,11 +665,11 @@ async def get_knowledge_type_stats_api(
|
|||||||
|
|
||||||
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
||||||
async def get_interest_distribution_by_user_api(
|
async def get_interest_distribution_by_user_api(
|
||||||
end_user_id: str = Query(..., description="用户ID(必填)"),
|
end_user_id: str = Query(..., description="用户ID(必填)"),
|
||||||
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取指定用户的兴趣分布标签
|
获取指定用户的兴趣分布标签
|
||||||
@@ -716,9 +717,9 @@ async def get_interest_distribution_by_user_api(
|
|||||||
|
|
||||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||||
async def get_user_profile_api(
|
async def get_user_profile_api(
|
||||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取用户详情,包含:
|
获取用户详情,包含:
|
||||||
@@ -782,9 +783,9 @@ async def get_user_profile_api(
|
|||||||
|
|
||||||
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
|
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
|
||||||
async def get_end_user_connected_config(
|
async def get_end_user_connected_config(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取终端用户关联的记忆配置
|
获取终端用户关联的记忆配置
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Optional
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
@@ -85,6 +85,7 @@ def create_config(
|
|||||||
payload: ConfigParamsCreate,
|
payload: ConfigParamsCreate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
@@ -99,7 +100,29 @@ def create_config(
|
|||||||
svc = DataConfigService(db)
|
svc = DataConfigService(db)
|
||||||
result = svc.create(payload)
|
result = svc.create(payload)
|
||||||
return success(data=result, msg="创建成功")
|
return success(data=result, msg="创建成功")
|
||||||
|
except ValueError as e:
|
||||||
|
err_str = str(e)
|
||||||
|
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
|
||||||
|
config_name = err_str.split(":", 1)[1]
|
||||||
|
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||||
|
lang = get_language_from_header(x_language_type)
|
||||||
|
if lang == "en":
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||||
|
else:
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||||
|
return JSONResponse(status_code=400, content=msg)
|
||||||
|
api_logger.error(f"Create config failed: {err_str}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
|
||||||
|
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||||
|
lang = get_language_from_header(x_language_type)
|
||||||
|
if lang == "en":
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||||
|
else:
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||||
|
return JSONResponse(status_code=400, content=msg)
|
||||||
api_logger.error(f"Create config failed: {str(e)}")
|
api_logger.error(f"Create config failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from typing import Dict, Optional, List
|
|||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
|
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
@@ -289,7 +289,8 @@ async def extract_ontology(
|
|||||||
async def create_scene(
|
async def create_scene(
|
||||||
request: SceneCreateRequest,
|
request: SceneCreateRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
|
||||||
):
|
):
|
||||||
"""创建本体场景
|
"""创建本体场景
|
||||||
|
|
||||||
@@ -360,8 +361,18 @@ async def create_scene(
|
|||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True)
|
err_str = str(e)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e))
|
if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str:
|
||||||
|
api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}")
|
||||||
|
from app.core.language_utils import get_language_from_header
|
||||||
|
lang = get_language_from_header(x_language_type)
|
||||||
|
if lang == "en":
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.")
|
||||||
|
else:
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称")
|
||||||
|
return JSONResponse(status_code=400, content=msg)
|
||||||
|
api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True)
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
|
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
|
||||||
@@ -661,7 +672,8 @@ async def get_scenes(
|
|||||||
async def create_class(
|
async def create_class(
|
||||||
request: ClassCreateRequest,
|
request: ClassCreateRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
|
||||||
):
|
):
|
||||||
"""创建本体类型
|
"""创建本体类型
|
||||||
|
|
||||||
@@ -676,7 +688,7 @@ async def create_class(
|
|||||||
ApiResponse: 包含创建的类型信息
|
ApiResponse: 包含创建的类型信息
|
||||||
"""
|
"""
|
||||||
from app.controllers.ontology_secondary_routes import create_class_handler
|
from app.controllers.ontology_secondary_routes import create_class_handler
|
||||||
return await create_class_handler(request, db, current_user)
|
return await create_class_handler(request, db, current_user, x_language_type)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/class/{class_id}", response_model=ApiResponse)
|
@router.put("/class/{class_id}", response_model=ApiResponse)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends, Header
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
@@ -58,7 +58,7 @@ async def scenes_handler(
|
|||||||
workspace_id: Optional[str] = None,
|
workspace_id: Optional[str] = None,
|
||||||
scene_name: Optional[str] = None,
|
scene_name: Optional[str] = None,
|
||||||
page: Optional[int] = None,
|
page: Optional[int] = None,
|
||||||
page_size: Optional[int] = None,
|
pagesize: Optional[int] = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
@@ -71,14 +71,14 @@ async def scenes_handler(
|
|||||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
||||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
page: 页码(可选,从1开始,仅在全量查询时有效)
|
||||||
page_size: 每页数量(可选,仅在全量查询时有效)
|
pagesize: 每页数量(可选,仅在全量查询时有效)
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
current_user: 当前用户
|
current_user: 当前用户
|
||||||
"""
|
"""
|
||||||
operation = "search" if scene_name else "list"
|
operation = "search" if scene_name else "list"
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"Scene {operation} requested by user {current_user.id}, "
|
f"Scene {operation} requested by user {current_user.id}, "
|
||||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
|
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -105,13 +105,13 @@ async def scenes_handler(
|
|||||||
api_logger.warning(f"Invalid page number: {page}")
|
api_logger.warning(f"Invalid page number: {page}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||||
|
|
||||||
if page_size is not None and page_size < 1:
|
if pagesize is not None and pagesize < 1:
|
||||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||||
|
|
||||||
# 如果只提供了page或page_size中的一个,返回错误
|
# 如果只提供了page或pagesize中的一个,返回错误
|
||||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||||
|
|
||||||
# 模糊搜索场景(支持分页)
|
# 模糊搜索场景(支持分页)
|
||||||
@@ -119,17 +119,15 @@ async def scenes_handler(
|
|||||||
total = len(scenes)
|
total = len(scenes)
|
||||||
|
|
||||||
# 如果提供了分页参数,进行分页处理
|
# 如果提供了分页参数,进行分页处理
|
||||||
if page is not None and page_size is not None:
|
if page is not None and pagesize is not None:
|
||||||
start_idx = (page - 1) * page_size
|
start_idx = (page - 1) * pagesize
|
||||||
end_idx = start_idx + page_size
|
end_idx = start_idx + pagesize
|
||||||
scenes = scenes[start_idx:end_idx]
|
scenes = scenes[start_idx:end_idx]
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
items = []
|
items = []
|
||||||
for scene in scenes:
|
for scene in scenes:
|
||||||
# 获取前3个class_name作为entity_type
|
|
||||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||||
# 动态计算 type_num
|
|
||||||
type_num = len(scene.classes) if scene.classes else 0
|
type_num = len(scene.classes) if scene.classes else 0
|
||||||
|
|
||||||
items.append(SceneResponse(
|
items.append(SceneResponse(
|
||||||
@@ -141,17 +139,16 @@ async def scenes_handler(
|
|||||||
workspace_id=scene.workspace_id,
|
workspace_id=scene.workspace_id,
|
||||||
created_at=scene.created_at,
|
created_at=scene.created_at,
|
||||||
updated_at=scene.updated_at,
|
updated_at=scene.updated_at,
|
||||||
classes_count=type_num
|
classes_count=type_num,
|
||||||
|
is_system_default=scene.is_system_default
|
||||||
))
|
))
|
||||||
|
|
||||||
# 构建响应(包含分页信息)
|
# 构建响应(包含分页信息)
|
||||||
if page is not None and page_size is not None:
|
if page is not None and pagesize is not None:
|
||||||
# 计算是否有下一页
|
hasnext = (page * pagesize) < total
|
||||||
hasnext = (page * page_size) < total
|
|
||||||
|
|
||||||
pagination_info = PaginationInfo(
|
pagination_info = PaginationInfo(
|
||||||
page=page,
|
page=page,
|
||||||
pagesize=page_size,
|
pagesize=pagesize,
|
||||||
total=total,
|
total=total,
|
||||||
hasnext=hasnext
|
hasnext=hasnext
|
||||||
)
|
)
|
||||||
@@ -165,28 +162,25 @@ async def scenes_handler(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 获取所有场景(支持分页)
|
# 获取所有场景(支持分页)
|
||||||
# 验证分页参数
|
|
||||||
if page is not None and page < 1:
|
if page is not None and page < 1:
|
||||||
api_logger.warning(f"Invalid page number: {page}")
|
api_logger.warning(f"Invalid page number: {page}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||||
|
|
||||||
if page_size is not None and page_size < 1:
|
if pagesize is not None and pagesize < 1:
|
||||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||||
|
|
||||||
# 如果只提供了page或page_size中的一个,返回错误
|
# 如果只提供了page或pagesize中的一个,返回错误
|
||||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||||
|
|
||||||
scenes, total = service.list_scenes(ws_uuid, page, page_size)
|
scenes, total = service.list_scenes(ws_uuid, page, pagesize)
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
items = []
|
items = []
|
||||||
for scene in scenes:
|
for scene in scenes:
|
||||||
# 获取前3个class_name作为entity_type
|
|
||||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||||
# 动态计算 type_num
|
|
||||||
type_num = len(scene.classes) if scene.classes else 0
|
type_num = len(scene.classes) if scene.classes else 0
|
||||||
|
|
||||||
items.append(SceneResponse(
|
items.append(SceneResponse(
|
||||||
@@ -198,17 +192,16 @@ async def scenes_handler(
|
|||||||
workspace_id=scene.workspace_id,
|
workspace_id=scene.workspace_id,
|
||||||
created_at=scene.created_at,
|
created_at=scene.created_at,
|
||||||
updated_at=scene.updated_at,
|
updated_at=scene.updated_at,
|
||||||
classes_count=type_num
|
classes_count=type_num,
|
||||||
|
is_system_default=scene.is_system_default
|
||||||
))
|
))
|
||||||
|
|
||||||
# 构建响应(包含分页信息)
|
# 构建响应(包含分页信息)
|
||||||
if page is not None and page_size is not None:
|
if page is not None and pagesize is not None:
|
||||||
# 计算是否有下一页
|
hasnext = (page * pagesize) < total
|
||||||
hasnext = (page * page_size) < total
|
|
||||||
|
|
||||||
pagination_info = PaginationInfo(
|
pagination_info = PaginationInfo(
|
||||||
page=page,
|
page=page,
|
||||||
pagesize=page_size,
|
pagesize=pagesize,
|
||||||
total=total,
|
total=total,
|
||||||
hasnext=hasnext
|
hasnext=hasnext
|
||||||
)
|
)
|
||||||
@@ -238,7 +231,8 @@ async def scenes_handler(
|
|||||||
async def create_class_handler(
|
async def create_class_handler(
|
||||||
request: ClassCreateRequest,
|
request: ClassCreateRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
x_language_type: Optional[str] = None
|
||||||
):
|
):
|
||||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
||||||
|
|
||||||
@@ -271,8 +265,11 @@ async def create_class_handler(
|
|||||||
]
|
]
|
||||||
|
|
||||||
if count == 1:
|
if count == 1:
|
||||||
# 单个创建
|
# 单个创建 - 先检查重名
|
||||||
class_data = classes_data[0]
|
class_data = classes_data[0]
|
||||||
|
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
|
||||||
|
if existing:
|
||||||
|
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
|
||||||
ontology_class = service.create_class(
|
ontology_class = service.create_class(
|
||||||
scene_id=request.scene_id,
|
scene_id=request.scene_id,
|
||||||
class_name=class_data["class_name"],
|
class_name=class_data["class_name"],
|
||||||
@@ -330,12 +327,36 @@ async def create_class_handler(
|
|||||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
api_logger.warning(f"Validation error in class creation: {str(e)}")
|
err_str = str(e)
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
|
||||||
|
class_name = err_str.split(":", 1)[1]
|
||||||
|
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
|
||||||
|
from app.core.language_utils import get_language_from_header
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
lang = get_language_from_header(x_language_type)
|
||||||
|
if lang == "en":
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||||
|
else:
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||||
|
return JSONResponse(status_code=400, content=msg)
|
||||||
|
api_logger.warning(f"Validation error in class creation: {err_str}")
|
||||||
|
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
|
err_str = str(e)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
|
||||||
|
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
|
||||||
|
from app.core.language_utils import get_language_from_header
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
lang = get_language_from_header(x_language_type)
|
||||||
|
class_name = request.classes[0].class_name if request.classes else ""
|
||||||
|
if lang == "en":
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||||
|
else:
|
||||||
|
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||||
|
return JSONResponse(status_code=400, content=msg)
|
||||||
|
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
||||||
@@ -615,6 +636,7 @@ async def classes_handler(
|
|||||||
scene_id=scene_uuid,
|
scene_id=scene_uuid,
|
||||||
scene_name=scene.scene_name,
|
scene_name=scene.scene_name,
|
||||||
scene_description=scene.scene_description,
|
scene_description=scene.scene_description,
|
||||||
|
is_system_default=scene.is_system_default,
|
||||||
items=items
|
items=items
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -190,8 +190,10 @@ class Settings:
|
|||||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||||
|
|
||||||
# Celery configuration (internal)
|
# Celery configuration (internal)
|
||||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
||||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
# 详见 docs/celery-env-bug-report.md
|
||||||
|
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "1"))
|
||||||
|
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "2"))
|
||||||
|
|
||||||
# SMTP Email Configuration
|
# SMTP Email Configuration
|
||||||
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
||||||
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
PROJECT_ROOT_,
|
PROJECT_ROOT_,
|
||||||
ReadState,
|
ReadState,
|
||||||
@@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.db import get_db_context
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
db_session = next(get_db())
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务
|
# 使用优化的LLM服务
|
||||||
structured = await problem_service.call_llm_structured(
|
with get_db_context() as db_session:
|
||||||
state=state,
|
structured = await problem_service.call_llm_structured(
|
||||||
db_session=db_session,
|
state=state,
|
||||||
system_prompt=system_prompt,
|
db_session=db_session,
|
||||||
response_model=ProblemExtensionResponse,
|
system_prompt=system_prompt,
|
||||||
fallback_value=[]
|
response_model=ProblemExtensionResponse,
|
||||||
)
|
fallback_value=[]
|
||||||
|
)
|
||||||
|
|
||||||
# 添加更详细的日志记录
|
# 添加更详细的日志记录
|
||||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||||
@@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务
|
# 使用优化的LLM服务
|
||||||
response_content = await problem_service.call_llm_structured(
|
with get_db_context() as db_session:
|
||||||
state=state,
|
response_content = await problem_service.call_llm_structured(
|
||||||
db_session=db_session,
|
state=state,
|
||||||
system_prompt=system_prompt,
|
db_session=db_session,
|
||||||
response_model=ProblemExtensionResponse,
|
system_prompt=system_prompt,
|
||||||
fallback_value=[]
|
response_model=ProblemExtensionResponse,
|
||||||
)
|
fallback_value=[]
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||||
|
|
||||||
|
|||||||
@@ -6,31 +6,26 @@ import os
|
|||||||
# ===== 第三方库 =====
|
# ===== 第三方库 =====
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.db import get_db, get_db_context
|
|
||||||
|
|
||||||
from app.schemas import model_schema
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
from app.services.model_service import ModelConfigService
|
|
||||||
|
|
||||||
from app.core.memory.agent.services.search_service import SearchService
|
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
|
||||||
COUNTState,
|
|
||||||
ReadState,
|
|
||||||
deduplicate_entries,
|
|
||||||
merge_to_key_value_pairs,
|
|
||||||
)
|
|
||||||
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
||||||
create_hybrid_retrieval_tool_sync,
|
create_hybrid_retrieval_tool_sync,
|
||||||
create_time_retrieval_tool,
|
create_time_retrieval_tool,
|
||||||
extract_tool_message_content,
|
extract_tool_message_content,
|
||||||
)
|
)
|
||||||
|
from app.core.memory.agent.services.search_service import SearchService
|
||||||
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
|
ReadState,
|
||||||
|
deduplicate_entries,
|
||||||
|
merge_to_key_value_pairs,
|
||||||
|
)
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.schemas import model_schema
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
db = next(get_db())
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def rag_config(state):
|
async def rag_config(state):
|
||||||
@@ -50,10 +45,12 @@ async def rag_config(state):
|
|||||||
"reranker_top_k": 10
|
"reranker_top_k": 10
|
||||||
}
|
}
|
||||||
return kb_config
|
return kb_config
|
||||||
async def rag_knowledge(state,question):
|
|
||||||
|
|
||||||
|
async def rag_knowledge(state, question):
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||||
try:
|
try:
|
||||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||||
@@ -61,13 +58,13 @@ async def rag_knowledge(state,question):
|
|||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
raw_results = clean_content
|
raw_results = clean_content
|
||||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||||
except Exception :
|
except Exception:
|
||||||
retrieval_knowledge=[]
|
retrieval_knowledge = []
|
||||||
clean_content = ''
|
clean_content = ''
|
||||||
raw_results = ''
|
raw_results = ''
|
||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||||
return retrieval_knowledge,clean_content,cleaned_query,raw_results
|
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||||
|
|
||||||
|
|
||||||
async def llm_infomation(state: ReadState) -> ReadState:
|
async def llm_infomation(state: ReadState) -> ReadState:
|
||||||
@@ -141,7 +138,6 @@ async def clean_databases(data) -> str:
|
|||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
text_parts.append(item)
|
text_parts.append(item)
|
||||||
|
|
||||||
|
|
||||||
return '\n'.join(text_parts).strip()
|
return '\n'.join(text_parts).strip()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -150,23 +146,23 @@ async def clean_databases(data) -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
模型信息
|
模型信息
|
||||||
'''
|
'''
|
||||||
|
|
||||||
problem_extension=state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
storage_type=state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
end_user_id=state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
original=state.get('data', '')
|
original = state.get('data', '')
|
||||||
problem_list=[]
|
problem_list = []
|
||||||
for key,values in problem_extension.items():
|
for key, values in problem_extension.items():
|
||||||
for data in values:
|
for data in values:
|
||||||
problem_list.append(data)
|
problem_list.append(data)
|
||||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
|
|
||||||
# 创建异步任务处理单个问题
|
# 创建异步任务处理单个问题
|
||||||
async def process_question_nodes(idx, question):
|
async def process_question_nodes(idx, question):
|
||||||
try:
|
try:
|
||||||
@@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
send_verify = []
|
send_verify = []
|
||||||
for i, j in zip(keys, val, strict=False):
|
for i, j in zip(keys, val, strict=False):
|
||||||
if j!=['']:
|
if j != ['']:
|
||||||
send_verify.append({
|
send_verify.append({
|
||||||
"Query_small": i,
|
"Query_small": i,
|
||||||
"Answer_Small": j
|
"Answer_Small": j
|
||||||
@@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||||
return {'retrieve':dup_databases}
|
return {'retrieve': dup_databases}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def retrieve(state: ReadState) -> ReadState:
|
async def retrieve(state: ReadState) -> ReadState:
|
||||||
# 从state中获取end_user_id
|
# 从state中获取end_user_id
|
||||||
import time
|
import time
|
||||||
start=time.time()
|
start = time.time()
|
||||||
problem_extension = state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
@@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
with get_db_context() as db: # 使用同步数据库上下文管理器
|
with get_db_context() as db: # 使用同步数据库上下文管理器
|
||||||
config_service = MemoryConfigService(db)
|
config_service = MemoryConfigService(db)
|
||||||
return await llm_infomation(state)
|
return await llm_infomation(state)
|
||||||
|
|
||||||
llm_config = await get_llm_info()
|
llm_config = await get_llm_info()
|
||||||
api_key_obj = llm_config.api_keys[0]
|
api_key_obj = llm_config.api_keys[0]
|
||||||
api_key = api_key_obj.api_key
|
api_key = api_key_obj.api_key
|
||||||
@@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
)
|
)
|
||||||
|
|
||||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||||
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
|
||||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||||
agent = create_agent(
|
agent = create_agent(
|
||||||
llm,
|
llm,
|
||||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
tools=[time_retrieval_tool, hybrid_retrieval],
|
||||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
async with SEMAPHORE: # 限制并发
|
async with SEMAPHORE: # 限制并发
|
||||||
try:
|
try:
|
||||||
if storage_type == "rag" and user_rag_memory_id:
|
if storage_type == "rag" and user_rag_memory_id:
|
||||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
|
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||||
|
question)
|
||||||
else:
|
else:
|
||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||||
@@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
# json.dump(dup_databases, f, indent=4)
|
# json.dump(dup_databases, f, indent=4)
|
||||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||||
return {'retrieve': dup_databases}
|
return {'retrieve': dup_databases}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -18,12 +16,11 @@ from app.core.memory.agent.utils.redis_tool import store
|
|||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
|
from app.db import get_db_context
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
db_session = next(get_db())
|
|
||||||
|
|
||||||
class SummaryNodeService(LLMServiceMixin):
|
class SummaryNodeService(LLMServiceMixin):
|
||||||
"""总结节点服务类"""
|
"""总结节点服务类"""
|
||||||
@@ -32,8 +29,11 @@ class SummaryNodeService(LLMServiceMixin):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# 创建全局服务实例
|
||||||
summary_service = SummaryNodeService()
|
summary_service = SummaryNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def rag_config(state):
|
async def rag_config(state):
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
kb_config = {
|
kb_config = {
|
||||||
@@ -51,10 +51,12 @@ async def rag_config(state):
|
|||||||
"reranker_top_k": 10
|
"reranker_top_k": 10
|
||||||
}
|
}
|
||||||
return kb_config
|
return kb_config
|
||||||
async def rag_knowledge(state,question):
|
|
||||||
|
|
||||||
|
async def rag_knowledge(state, question):
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||||
try:
|
try:
|
||||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||||
@@ -62,20 +64,23 @@ async def rag_knowledge(state,question):
|
|||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
raw_results = clean_content
|
raw_results = clean_content
|
||||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||||
except Exception :
|
except Exception:
|
||||||
retrieval_knowledge=[]
|
retrieval_knowledge = []
|
||||||
clean_content = ''
|
clean_content = ''
|
||||||
raw_results = ''
|
raw_results = ''
|
||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||||
return retrieval_knowledge,clean_content,cleaned_query,raw_results
|
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||||
|
|
||||||
|
|
||||||
async def summary_history(state: ReadState) -> ReadState:
|
async def summary_history(state: ReadState) -> ReadState:
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
|
||||||
|
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||||
|
search_mode) -> str:
|
||||||
"""
|
"""
|
||||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||||
"""
|
"""
|
||||||
@@ -99,13 +104,14 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务进行结构化输出
|
# 使用优化的LLM服务进行结构化输出
|
||||||
structured = await summary_service.call_llm_structured(
|
with get_db_context() as db_session:
|
||||||
state=state,
|
structured = await summary_service.call_llm_structured(
|
||||||
db_session=db_session,
|
state=state,
|
||||||
system_prompt=system_prompt,
|
db_session=db_session,
|
||||||
response_model=response_model,
|
system_prompt=system_prompt,
|
||||||
fallback_value=None
|
response_model=response_model,
|
||||||
)
|
fallback_value=None
|
||||||
|
)
|
||||||
# 验证结构化响应
|
# 验证结构化响应
|
||||||
if structured is None:
|
if structured is None:
|
||||||
logger.warning("LLM返回None,使用默认回答")
|
logger.warning("LLM返回None,使用默认回答")
|
||||||
@@ -157,7 +163,8 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
logger.error(f"Fallback也失败: {fallback_error}")
|
logger.error(f"Fallback也失败: {fallback_error}")
|
||||||
return "信息不足,无法回答"
|
return "信息不足,无法回答"
|
||||||
|
|
||||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
|
||||||
|
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
await SessionService(store).save_session(
|
await SessionService(store).save_session(
|
||||||
@@ -169,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
|||||||
)
|
)
|
||||||
await SessionService(store).cleanup_duplicates()
|
await SessionService(store).cleanup_duplicates()
|
||||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||||
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
|
||||||
storage_type=state.get("storage_type",'')
|
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||||
data=state.get("data", '')
|
storage_type = state.get("storage_type", '')
|
||||||
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
|
data = state.get("data", '')
|
||||||
input_summary = {
|
input_summary = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"summary_result": aimessages,
|
"summary_result": aimessages,
|
||||||
@@ -189,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
|||||||
"user_rag_memory_id": user_rag_memory_id
|
"user_rag_memory_id": user_rag_memory_id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
retrieve={
|
retrieve = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"summary_result": aimessages,
|
"summary_result": aimessages,
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
"user_rag_memory_id": user_rag_memory_id,
|
"user_rag_memory_id": user_rag_memory_id,
|
||||||
"_intermediate": {
|
"_intermediate": {
|
||||||
"type": "retrieval_summary",
|
"type": "retrieval_summary",
|
||||||
"title":"快速检索",
|
"title": "快速检索",
|
||||||
"summary": aimessages,
|
"summary": aimessages,
|
||||||
"query": data,
|
"query": data,
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
@@ -204,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return input_summary,retrieve
|
return input_summary, retrieve
|
||||||
|
|
||||||
|
|
||||||
async def Input_Summary(state: ReadState) -> ReadState:
|
async def Input_Summary(state: ReadState) -> ReadState:
|
||||||
start=time.time()
|
start = time.time()
|
||||||
storage_type=state.get("storage_type",'')
|
storage_type = state.get("storage_type", '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
data=state.get("data", '')
|
data = state.get("data", '')
|
||||||
end_user_id=state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
history = await summary_history( state)
|
history = await summary_history(state)
|
||||||
search_params = {
|
search_params = {
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"question": data,
|
"question": data,
|
||||||
@@ -223,12 +233,13 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if storage_type!="rag":
|
if storage_type != "rag":
|
||||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
|
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
|
||||||
|
memory_config=memory_config)
|
||||||
else:
|
else:
|
||||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
|
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
|
||||||
retrieve_info, question, raw_results = "", data, []
|
retrieve_info, question, raw_results = "", data, []
|
||||||
try:
|
try:
|
||||||
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
||||||
@@ -237,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
|
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
|
||||||
summary = summary_result[0]
|
summary = summary_result[0]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error( f"Input_Summary failed: {e}", exc_info=True )
|
logger.error(f"Input_Summary failed: {e}", exc_info=True)
|
||||||
summary= {
|
summary = {
|
||||||
"status": "fail",
|
"status": "fail",
|
||||||
"summary_result": "信息不足,无法回答",
|
"summary_result": "信息不足,无法回答",
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
@@ -251,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
except Exception:
|
except Exception:
|
||||||
duration = 0.0
|
duration = 0.0
|
||||||
log_time('检索', duration)
|
log_time('检索', duration)
|
||||||
return {"summary":summary}
|
return {"summary": summary}
|
||||||
|
|
||||||
async def Retrieve_Summary(state: ReadState)-> ReadState:
|
|
||||||
retrieve=state.get("retrieve", '')
|
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||||
history = await summary_history( state)
|
retrieve = state.get("retrieve", '')
|
||||||
|
history = await summary_history(state)
|
||||||
import json
|
import json
|
||||||
with open("检索.json","w",encoding='utf-8') as f:
|
with open("检索.json", "w", encoding='utf-8') as f:
|
||||||
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
|
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
|
||||||
retrieve=retrieve.get("Expansion_issue", [])
|
retrieve = retrieve.get("Expansion_issue", [])
|
||||||
start=time.time()
|
start = time.time()
|
||||||
retrieve_info_str=[]
|
retrieve_info_str = []
|
||||||
for data in retrieve:
|
for data in retrieve:
|
||||||
if data=='':
|
if data == '':
|
||||||
retrieve_info_str=''
|
retrieve_info_str = ''
|
||||||
else:
|
else:
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key=='Answer_Small':
|
if key == 'Answer_Small':
|
||||||
for i in value:
|
for i in value:
|
||||||
retrieve_info_str.append(i)
|
retrieve_info_str.append(i)
|
||||||
retrieve_info_str=list(set(retrieve_info_str))
|
retrieve_info_str = list(set(retrieve_info_str))
|
||||||
retrieve_info_str='\n'.join(retrieve_info_str)
|
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||||
|
|
||||||
aimessages=await summary_llm(state,history,retrieve_info_str,
|
aimessages = await summary_llm(state, history, retrieve_info_str,
|
||||||
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
|
||||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||||
await summary_redis_save(state, aimessages)
|
await summary_redis_save(state, aimessages)
|
||||||
if aimessages == '':
|
if aimessages == '':
|
||||||
@@ -290,29 +302,29 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
|
|||||||
# 修复协程调用 - 先await,然后访问返回值
|
# 修复协程调用 - 先await,然后访问返回值
|
||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary":summary}
|
return {"summary": summary}
|
||||||
|
|
||||||
|
|
||||||
async def Summary(state: ReadState)-> ReadState:
|
async def Summary(state: ReadState) -> ReadState:
|
||||||
start=time.time()
|
start = time.time()
|
||||||
query = state.get("data", '')
|
query = state.get("data", '')
|
||||||
verify=state.get("verify", '')
|
verify = state.get("verify", '')
|
||||||
verify_expansion_issue=verify.get("verified_data", '')
|
verify_expansion_issue = verify.get("verified_data", '')
|
||||||
retrieve_info_str=''
|
retrieve_info_str = ''
|
||||||
for data in verify_expansion_issue:
|
for data in verify_expansion_issue:
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key=='answer_small':
|
if key == 'answer_small':
|
||||||
for i in value:
|
for i in value:
|
||||||
retrieve_info_str+=i+'\n'
|
retrieve_info_str += i + '\n'
|
||||||
history=await summary_history(state)
|
history = await summary_history(state)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
"retrieve_info": retrieve_info_str
|
"retrieve_info": retrieve_info_str
|
||||||
}
|
}
|
||||||
aimessages=await summary_llm(state,history,data,
|
aimessages = await summary_llm(state, history, data,
|
||||||
'summary_prompt.jinja2','summary',SummaryResponse,0)
|
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||||
|
|
||||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||||
await summary_redis_save(state, aimessages)
|
await summary_redis_save(state, aimessages)
|
||||||
@@ -327,10 +339,12 @@ async def Summary(state: ReadState)-> ReadState:
|
|||||||
# 修复协程调用 - 先await,然后访问返回值
|
# 修复协程调用 - 先await,然后访问返回值
|
||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary":summary}
|
return {"summary": summary}
|
||||||
async def Summary_fails(state: ReadState)-> ReadState:
|
|
||||||
storage_type=state.get("storage_type", '')
|
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
async def Summary_fails(state: ReadState) -> ReadState:
|
||||||
|
storage_type = state.get("storage_type", '')
|
||||||
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
history = await summary_history(state)
|
history = await summary_history(state)
|
||||||
query = state.get("data", '')
|
query = state.get("data", '')
|
||||||
verify = state.get("verify", '')
|
verify = state.get("verify", '')
|
||||||
@@ -346,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState:
|
|||||||
"history": history,
|
"history": history,
|
||||||
"retrieve_info": retrieve_info_str
|
"retrieve_info": retrieve_info_str
|
||||||
}
|
}
|
||||||
aimessages = await summary_llm(state, history, data,
|
aimessages = await summary_llm(state, history, data,
|
||||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||||
result= {
|
result = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"summary_result": aimessages,
|
"summary_result": aimessages,
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
"user_rag_memory_id": user_rag_memory_id
|
"user_rag_memory_id": user_rag_memory_id
|
||||||
}
|
}
|
||||||
return {"summary":result}
|
return {"summary": result}
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.models.verification_models import VerificationResult
|
from app.core.memory.agent.models.verification_models import VerificationResult
|
||||||
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
PROJECT_ROOT_,
|
PROJECT_ROOT_,
|
||||||
ReadState,
|
ReadState,
|
||||||
@@ -10,12 +11,12 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.db import get_db_context
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
db_session = next(get_db())
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class VerificationNodeService(LLMServiceMixin):
|
class VerificationNodeService(LLMServiceMixin):
|
||||||
"""验证节点服务类"""
|
"""验证节点服务类"""
|
||||||
|
|
||||||
@@ -23,9 +24,11 @@ class VerificationNodeService(LLMServiceMixin):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# 创建全局服务实例
|
||||||
verification_service = VerificationNodeService()
|
verification_service = VerificationNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||||
"""处理验证结果并生成输出格式"""
|
"""处理验证结果并生成输出格式"""
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
@@ -58,6 +61,8 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Verify_result
|
return Verify_result
|
||||||
|
|
||||||
|
|
||||||
async def Verify(state: ReadState):
|
async def Verify(state: ReadState):
|
||||||
logger.info("=== Verify 节点开始执行 ===")
|
logger.info("=== Verify 节点开始执行 ===")
|
||||||
try:
|
try:
|
||||||
@@ -71,7 +76,8 @@ async def Verify(state: ReadState):
|
|||||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||||
|
|
||||||
retrieve = state.get("retrieve", {})
|
retrieve = state.get("retrieve", {})
|
||||||
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
logger.info(
|
||||||
|
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||||
|
|
||||||
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
||||||
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
||||||
@@ -100,23 +106,24 @@ async def Verify(state: ReadState):
|
|||||||
try:
|
try:
|
||||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||||
import asyncio
|
|
||||||
structured = await asyncio.wait_for(
|
with get_db_context() as db_session:
|
||||||
verification_service.call_llm_structured(
|
structured = await asyncio.wait_for(
|
||||||
state=state,
|
verification_service.call_llm_structured(
|
||||||
db_session=db_session,
|
state=state,
|
||||||
system_prompt=system_prompt,
|
db_session=db_session,
|
||||||
response_model=VerificationResult,
|
system_prompt=system_prompt,
|
||||||
fallback_value={
|
response_model=VerificationResult,
|
||||||
"query": content,
|
fallback_value={
|
||||||
"history": history if isinstance(history, list) else [],
|
"query": content,
|
||||||
"expansion_issue": [],
|
"history": history if isinstance(history, list) else [],
|
||||||
"split_result": "failed",
|
"expansion_issue": [],
|
||||||
"reason": "验证失败或超时"
|
"split_result": "failed",
|
||||||
}
|
"reason": "验证失败或超时"
|
||||||
),
|
}
|
||||||
timeout=150.0 # 150秒超时
|
),
|
||||||
)
|
timeout=150.0 # 150秒超时
|
||||||
|
)
|
||||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
|
|||||||
from langgraph.constants import START, END
|
from langgraph.constants import START, END
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
@@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def make_read_graph():
|
async def make_read_graph():
|
||||||
"""创建并返回 LangGraph 工作流"""
|
"""创建并返回 LangGraph 工作流"""
|
||||||
@@ -62,7 +60,6 @@ async def make_read_graph():
|
|||||||
workflow.add_edge("Summary_fails", END)
|
workflow.add_edge("Summary_fails", END)
|
||||||
workflow.add_edge("Summary", END)
|
workflow.add_edge("Summary", END)
|
||||||
|
|
||||||
|
|
||||||
'''-----'''
|
'''-----'''
|
||||||
# workflow.add_edge("Retrieve", END)
|
# workflow.add_edge("Retrieve", END)
|
||||||
|
|
||||||
@@ -76,6 +73,7 @@ async def make_read_graph():
|
|||||||
finally:
|
finally:
|
||||||
print("工作流创建完成")
|
print("工作流创建完成")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""主函数 - 运行工作流"""
|
"""主函数 - 运行工作流"""
|
||||||
message = "昨天有什么好看的电影"
|
message = "昨天有什么好看的电影"
|
||||||
@@ -92,13 +90,15 @@ async def main():
|
|||||||
service_name="MemoryAgentService"
|
service_name="MemoryAgentService"
|
||||||
)
|
)
|
||||||
import time
|
import time
|
||||||
start=time.time()
|
start = time.time()
|
||||||
try:
|
try:
|
||||||
async with make_read_graph() as graph:
|
async with make_read_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
"end_user_id": end_user_id
|
||||||
|
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||||
|
"memory_config": memory_config}
|
||||||
# 获取节点更新信息
|
# 获取节点更新信息
|
||||||
_intermediate_outputs = []
|
_intermediate_outputs = []
|
||||||
summary = ''
|
summary = ''
|
||||||
@@ -141,7 +141,6 @@ async def main():
|
|||||||
if verify_n and verify_n != [] and verify_n != {}:
|
if verify_n and verify_n != [] and verify_n != {}:
|
||||||
_intermediate_outputs.append(verify_n)
|
_intermediate_outputs.append(verify_n)
|
||||||
|
|
||||||
|
|
||||||
# Summary 节点
|
# Summary 节点
|
||||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||||
if summary_n and summary_n != [] and summary_n != {}:
|
if summary_n and summary_n != [] and summary_n != {}:
|
||||||
@@ -165,13 +164,16 @@ async def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
db_session.close()
|
||||||
|
|
||||||
end=time.time()
|
end = time.time()
|
||||||
print(100*'y')
|
print(100 * 'y')
|
||||||
print(f"总耗时: {end-start}s")
|
print(f"总耗时: {end - start}s")
|
||||||
print(100*'y')
|
print(100 * 'y')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -1,56 +0,0 @@
|
|||||||
|
|
||||||
import asyncio
|
|
||||||
from typing import Dict, Optional
|
|
||||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
|
|
||||||
from app.db import get_db
|
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
|
||||||
|
|
||||||
class LLMClientPool:
|
|
||||||
"""LLM客户端连接池"""
|
|
||||||
|
|
||||||
def __init__(self, max_size: int = 5):
|
|
||||||
self.max_size = max_size
|
|
||||||
self.pools: Dict[str, asyncio.Queue] = {}
|
|
||||||
self.active_clients: Dict[str, int] = {}
|
|
||||||
|
|
||||||
async def get_client(self, llm_model_id: str):
|
|
||||||
"""获取LLM客户端"""
|
|
||||||
if llm_model_id not in self.pools:
|
|
||||||
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
|
|
||||||
self.active_clients[llm_model_id] = 0
|
|
||||||
|
|
||||||
pool = self.pools[llm_model_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 尝试从池中获取客户端
|
|
||||||
client = pool.get_nowait()
|
|
||||||
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
|
|
||||||
return client
|
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
# 池为空,创建新客户端
|
|
||||||
if self.active_clients[llm_model_id] < self.max_size:
|
|
||||||
db_session = next(get_db())
|
|
||||||
client = get_llm_client_fast(llm_model_id, db_session)
|
|
||||||
self.active_clients[llm_model_id] += 1
|
|
||||||
logger.debug(f"创建新LLM客户端: {llm_model_id}")
|
|
||||||
return client
|
|
||||||
else:
|
|
||||||
# 等待可用客户端
|
|
||||||
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
|
|
||||||
return await pool.get()
|
|
||||||
|
|
||||||
async def return_client(self, llm_model_id: str, client):
|
|
||||||
"""归还LLM客户端到池中"""
|
|
||||||
if llm_model_id in self.pools:
|
|
||||||
try:
|
|
||||||
self.pools[llm_model_id].put_nowait(client)
|
|
||||||
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
# 池已满,丢弃客户端
|
|
||||||
self.active_clients[llm_model_id] -= 1
|
|
||||||
logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}")
|
|
||||||
|
|
||||||
# 全局客户端池
|
|
||||||
llm_client_pool = LLMClientPool()
|
|
||||||
@@ -14,7 +14,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
|
|||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
from app.db import get_db
|
from app.db import get_db_context
|
||||||
from app.models import AppRelease
|
from app.models import AppRelease
|
||||||
from app.services.draft_run_service import AgentRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
|
|||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {"output": VariableType.STRING}
|
return {"output": VariableType.STRING}
|
||||||
|
|
||||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]:
|
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]:
|
||||||
"""准备 Agent(公共逻辑)
|
"""准备 Agent(公共逻辑)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -57,17 +57,17 @@ class AgentNode(BaseNode):
|
|||||||
if not agent_id:
|
if not agent_id:
|
||||||
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
|
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
|
||||||
|
|
||||||
db = next(get_db())
|
with get_db_context() as db:
|
||||||
release = db.query(AppRelease).filter(
|
release = db.query(AppRelease).filter(
|
||||||
AppRelease.id == agent_id
|
AppRelease.id == agent_id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not release:
|
if not release:
|
||||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||||
|
|
||||||
draft_service = AgentRunService(db)
|
|
||||||
|
|
||||||
return draft_service, release, message
|
|
||||||
|
return release, message
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""非流式执行
|
"""非流式执行
|
||||||
@@ -79,19 +79,21 @@ class AgentNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
状态更新字典
|
状态更新字典
|
||||||
"""
|
"""
|
||||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
release, message = self._prepare_agent(variable_pool)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
||||||
|
with get_db_context() as db:
|
||||||
|
draft_service = AgentRunService(db)
|
||||||
|
|
||||||
# 执行 Agent(非流式)
|
# 执行 Agent(非流式)
|
||||||
result = await draft_service.run(
|
result = await draft_service.run(
|
||||||
agent_config=release.config,
|
agent_config=release.config,
|
||||||
model_config=None,
|
model_config=None,
|
||||||
message=message,
|
message=message,
|
||||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||||
user_id=state.get("user_id"),
|
user_id=state.get("user_id"),
|
||||||
variables=variable_pool.get_all_conversation_vars()
|
variables=variable_pool.get_all_conversation_vars()
|
||||||
)
|
)
|
||||||
|
|
||||||
response = result.get("response", "")
|
response = result.get("response", "")
|
||||||
|
|
||||||
@@ -118,34 +120,35 @@ class AgentNode(BaseNode):
|
|||||||
Yields:
|
Yields:
|
||||||
流式事件字典
|
流式事件字典
|
||||||
"""
|
"""
|
||||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
release, message = self._prepare_agent(variable_pool)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
||||||
|
|
||||||
# 累积完整响应
|
# 累积完整响应
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
with get_db_context() as db:
|
||||||
|
draft_service = AgentRunService(db)
|
||||||
# 执行 Agent(流式)
|
# 执行 Agent(流式)
|
||||||
async for chunk in draft_service.run_stream(
|
async for chunk in draft_service.run_stream(
|
||||||
agent_config=release.config,
|
agent_config=release.config,
|
||||||
model_config=None,
|
model_config=None,
|
||||||
message=message,
|
message=message,
|
||||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||||
user_id=state.get("user_id"),
|
user_id=state.get("user_id"),
|
||||||
variables=variable_pool.get_all_conversation_vars()
|
variables=variable_pool.get_all_conversation_vars()
|
||||||
):
|
):
|
||||||
# 提取内容
|
# 提取内容
|
||||||
content = chunk.get("content", "")
|
content = chunk.get("content", "")
|
||||||
full_response += content
|
full_response += content
|
||||||
|
|
||||||
# 流式返回每个 chunk
|
# 流式返回每个 chunk
|
||||||
yield {
|
yield {
|
||||||
"type": "chunk",
|
"type": "chunk",
|
||||||
"node_id": self.node_id,
|
"node_id": self.node_id,
|
||||||
"content": content,
|
"content": content,
|
||||||
"full_content": full_response,
|
"full_content": full_response,
|
||||||
"meta_data": chunk.get("meta_data", {})
|
"meta_data": chunk.get("meta_data", {})
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
|
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
|
||||||
|
|
||||||
|
|||||||
@@ -374,7 +374,7 @@ class OntologySceneRepository:
|
|||||||
|
|
||||||
count = self.db.query(OntologyScene).filter(
|
count = self.db.query(OntologyScene).filter(
|
||||||
OntologyScene.scene_id == scene_id,
|
OntologyScene.scene_id == scene_id,
|
||||||
OntologyScene.workspace_id == workspace_id
|
(OntologyScene.workspace_id == workspace_id) | (OntologyScene.is_system_default == True)
|
||||||
).count()
|
).count()
|
||||||
|
|
||||||
is_owner = count > 0
|
is_owner = count > 0
|
||||||
|
|||||||
@@ -116,8 +116,8 @@ class ModelApiKeyBase(BaseModel):
|
|||||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
capability: Optional[List[str]] = Field(None, description="模型能力列表")
|
||||||
is_omni: bool = Field(False, description="是否为Omni模型")
|
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
|
||||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
||||||
is_active: bool = Field(True, description="是否激活")
|
is_active: bool = Field(True, description="是否激活")
|
||||||
priority: str = Field("1", description="优先级", max_length=10)
|
priority: str = Field("1", description="优先级", max_length=10)
|
||||||
|
|||||||
@@ -241,6 +241,7 @@ class SceneResponse(BaseModel):
|
|||||||
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
|
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
|
||||||
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
|
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
|
||||||
classes_count: int = Field(0, description="类型数量")
|
classes_count: int = Field(0, description="类型数量")
|
||||||
|
is_system_default: bool = Field(False, description="是否为系统默认场景")
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
@@ -462,6 +463,7 @@ class ClassListResponse(BaseModel):
|
|||||||
scene_id: UUID = Field(..., description="所属场景ID")
|
scene_id: UUID = Field(..., description="所属场景ID")
|
||||||
scene_name: str = Field(..., description="场景名称")
|
scene_name: str = Field(..., description="场景名称")
|
||||||
scene_description: Optional[str] = Field(None, description="场景描述")
|
scene_description: Optional[str] = Field(None, description="场景描述")
|
||||||
|
is_system_default: bool = Field(False, description="是否为系统默认场景")
|
||||||
items: List[ClassResponse] = Field(..., description="类型列表")
|
items: List[ClassResponse] = Field(..., description="类型列表")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from app.core.error_codes import BizCode
|
|||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
|
from app.db import get_db_context
|
||||||
from app.models import AgentConfig, ModelConfig
|
from app.models import AgentConfig, ModelConfig
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.schemas.app_schema import FileInput
|
from app.schemas.app_schema import FileInput
|
||||||
@@ -103,9 +104,7 @@ def create_long_term_memory_tool(
|
|||||||
"""
|
"""
|
||||||
logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}")
|
logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}")
|
||||||
try:
|
try:
|
||||||
from app.db import get_db
|
with get_db_context() as db:
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
memory_content = asyncio.run(
|
memory_content = asyncio.run(
|
||||||
MemoryAgentService().read_memory(
|
MemoryAgentService().read_memory(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -127,9 +126,6 @@ def create_long_term_memory_tool(
|
|||||||
logger.info(f"读取任务状态:{status}")
|
logger.info(f"读取任务状态:{status}")
|
||||||
if memory_content:
|
if memory_content:
|
||||||
memory_content = memory_content['answer']
|
memory_content = memory_content['answer']
|
||||||
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
logger.info(f'用户ID:Agent:{end_user_id}')
|
logger.info(f'用户ID:Agent:{end_user_id}')
|
||||||
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
|
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
|
|||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
|
|||||||
reorder_output_results,
|
reorder_output_results,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_agent_schema import Write_UserInput
|
from app.schemas.memory_agent_schema import Write_UserInput
|
||||||
from app.schemas.memory_config_schema import ConfigurationError
|
from app.schemas.memory_config_schema import ConfigurationError
|
||||||
@@ -69,7 +66,8 @@ class MemoryAgentService:
|
|||||||
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||||
# 记录成功的操作
|
# 记录成功的操作
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||||
|
success=True,
|
||||||
duration=duration, details={"message_length": len(message)})
|
duration=duration, details={"message_length": len(message)})
|
||||||
return context
|
return context
|
||||||
else:
|
else:
|
||||||
@@ -88,8 +86,6 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
raise ValueError(f"写入失败: {messages}")
|
raise ValueError(f"写入失败: {messages}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def extract_tool_call_info(self, event: Dict) -> bool:
|
def extract_tool_call_info(self, event: Dict) -> bool:
|
||||||
"""Extract tool call information from event"""
|
"""Extract tool call information from event"""
|
||||||
last_message = event["messages"][-1]
|
last_message = event["messages"][-1]
|
||||||
@@ -271,7 +267,8 @@ class MemoryAgentService:
|
|||||||
logger.info("Log streaming completed, cleaning up resources")
|
logger.info("Log streaming completed, cleaning up resources")
|
||||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||||
|
|
||||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
|
||||||
|
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
||||||
"""
|
"""
|
||||||
Process write operation with config_id
|
Process write operation with config_id
|
||||||
|
|
||||||
@@ -300,7 +297,8 @@ class MemoryAgentService:
|
|||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||||
if config_id is None and workspace_id is None:
|
if config_id is None and workspace_id is None:
|
||||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
raise ValueError(
|
||||||
|
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "No memory configuration found" in str(e):
|
if "No memory configuration found" in str(e):
|
||||||
raise # Re-raise our specific error
|
raise # Re-raise our specific error
|
||||||
@@ -331,7 +329,8 @@ class MemoryAgentService:
|
|||||||
# Log failed operation
|
# Log failed operation
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||||
|
success=False, duration=duration, error=error_msg)
|
||||||
|
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
@@ -351,9 +350,9 @@ class MemoryAgentService:
|
|||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||||
elif msg['role'] == 'assistant':
|
elif msg['role'] == 'assistant':
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
langchain_messages.append(AIMessage(content=msg['content']))
|
||||||
print(100*'-')
|
print(100 * '-')
|
||||||
print(langchain_messages)
|
print(langchain_messages)
|
||||||
print(100*'-')
|
print(100 * '-')
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {
|
initial_state = {
|
||||||
"messages": langchain_messages,
|
"messages": langchain_messages,
|
||||||
@@ -375,29 +374,28 @@ class MemoryAgentService:
|
|||||||
contents = massages.get('write_result')
|
contents = massages.get('write_result')
|
||||||
# Convert messages back to string for logging
|
# Convert messages back to string for logging
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
|
||||||
|
contents)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Ensure proper error handling and logging
|
# Ensure proper error handling and logging
|
||||||
error_msg = f"Write operation failed: {str(e)}"
|
error_msg = f"Write operation failed: {str(e)}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||||
|
success=False, duration=duration, error=error_msg)
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def read_memory(
|
async def read_memory(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
history: List[Dict],
|
history: List[Dict],
|
||||||
search_switch: str,
|
search_switch: str,
|
||||||
config_id: Optional[uuid.UUID]|int,
|
config_id: Optional[uuid.UUID] | int,
|
||||||
db: Session,
|
db: Session,
|
||||||
storage_type: str,
|
storage_type: str,
|
||||||
user_rag_memory_id: str) -> Dict:
|
user_rag_memory_id: str) -> Dict:
|
||||||
"""
|
"""
|
||||||
Process read operation with config_id
|
Process read operation with config_id
|
||||||
|
|
||||||
@@ -425,7 +423,7 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
ori_message= message
|
ori_message = message
|
||||||
|
|
||||||
# Resolve config_id and workspace_id
|
# Resolve config_id and workspace_id
|
||||||
# Always get workspace_id from end_user for fallback, even if config_id is provided
|
# Always get workspace_id from end_user for fallback, even if config_id is provided
|
||||||
@@ -437,7 +435,8 @@ class MemoryAgentService:
|
|||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||||
if config_id is None and workspace_id is None:
|
if config_id is None and workspace_id is None:
|
||||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
raise ValueError(
|
||||||
|
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "No memory configuration found" in str(e):
|
if "No memory configuration found" in str(e):
|
||||||
raise # Re-raise our specific error
|
raise # Re-raise our specific error
|
||||||
@@ -454,7 +453,6 @@ class MemoryAgentService:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
audit_logger = None
|
audit_logger = None
|
||||||
|
|
||||||
|
|
||||||
config_load_start = time.time()
|
config_load_start = time.time()
|
||||||
try:
|
try:
|
||||||
# Use a separate database session to avoid transaction failures
|
# Use a separate database session to avoid transaction failures
|
||||||
@@ -576,7 +574,8 @@ class MemoryAgentService:
|
|||||||
raw_results = intermediate.get('raw_results', {})
|
raw_results = intermediate.get('raw_results', {})
|
||||||
try:
|
try:
|
||||||
reranked_results = raw_results.get('reranked_results', [])
|
reranked_results = raw_results.get('reranked_results', [])
|
||||||
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
|
statements = [statement['statement'] for statement in
|
||||||
|
reranked_results.get('statements', [])]
|
||||||
except Exception:
|
except Exception:
|
||||||
statements = []
|
statements = []
|
||||||
|
|
||||||
@@ -602,7 +601,8 @@ class MemoryAgentService:
|
|||||||
)
|
)
|
||||||
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
|
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
logger.debug(
|
||||||
|
f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
||||||
|
|
||||||
except Exception as save_error:
|
except Exception as save_error:
|
||||||
# 保存失败不应该影响主流程,只记录错误
|
# 保存失败不应该影响主流程,只记录错误
|
||||||
@@ -610,7 +610,8 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
# Log successful operation
|
# Log successful operation
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
logger.info(
|
||||||
|
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
@@ -641,7 +642,6 @@ class MemoryAgentService:
|
|||||||
)
|
)
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Get standardized message list from user input.
|
Get standardized message list from user input.
|
||||||
@@ -665,7 +665,8 @@ class MemoryAgentService:
|
|||||||
for idx, msg in enumerate(user_input.messages):
|
for idx, msg in enumerate(user_input.messages):
|
||||||
if not isinstance(msg, dict):
|
if not isinstance(msg, dict):
|
||||||
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
|
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
|
||||||
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
raise ValueError(
|
||||||
|
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
||||||
|
|
||||||
if 'role' not in msg:
|
if 'role' not in msg:
|
||||||
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
|
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
|
||||||
@@ -673,7 +674,8 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
if 'content' not in msg:
|
if 'content' not in msg:
|
||||||
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
|
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
|
||||||
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
raise ValueError(
|
||||||
|
f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
||||||
|
|
||||||
if msg['role'] not in ['user', 'assistant']:
|
if msg['role'] not in ['user', 'assistant']:
|
||||||
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
|
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
|
||||||
@@ -687,11 +689,11 @@ class MemoryAgentService:
|
|||||||
return user_input.messages
|
return user_input.messages
|
||||||
|
|
||||||
async def classify_message_type(
|
async def classify_message_type(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
config_id: UUID,
|
config_id: UUID,
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: Optional[UUID] = None
|
workspace_id: Optional[UUID] = None
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Determine the type of user message (read or write)
|
Determine the type of user message (read or write)
|
||||||
@@ -719,14 +721,15 @@ class MemoryAgentService:
|
|||||||
status = await status_typle(message, memory_config.llm_model_id)
|
status = await status_typle(message, memory_config.llm_model_id)
|
||||||
logger.debug(f"Message type: {status}")
|
logger.debug(f"Message type: {status}")
|
||||||
return status
|
return status
|
||||||
|
|
||||||
async def generate_summary_from_retrieve(
|
async def generate_summary_from_retrieve(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
retrieve_info: str,
|
retrieve_info: str,
|
||||||
history: List[Dict],
|
history: List[Dict],
|
||||||
query: str,
|
query: str,
|
||||||
config_id: str,
|
config_id: str,
|
||||||
db: Session
|
db: Session
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
基于检索信息、历史对话和查询生成最终答案
|
基于检索信息、历史对话和查询生成最终答案
|
||||||
@@ -805,13 +808,12 @@ class MemoryAgentService:
|
|||||||
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
|
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
|
||||||
return "信息不足,无法回答。"
|
return "信息不足,无法回答。"
|
||||||
|
|
||||||
|
|
||||||
async def get_knowledge_type_stats(
|
async def get_knowledge_type_stats(
|
||||||
self,
|
self,
|
||||||
end_user_id: Optional[str] = None,
|
db: Session,
|
||||||
only_active: bool = True,
|
end_user_id: Optional[str] = None,
|
||||||
current_workspace_id: Optional[uuid.UUID] = None,
|
only_active: bool = True,
|
||||||
db: Session = None
|
current_workspace_id: Optional[uuid.UUID] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
统计知识库类型分布,包含:
|
统计知识库类型分布,包含:
|
||||||
@@ -837,11 +839,6 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
# 1. 统计 PostgreSQL 中的知识库类型
|
# 1. 统计 PostgreSQL 中的知识库类型
|
||||||
try:
|
try:
|
||||||
if db is None:
|
|
||||||
from app.db import get_db
|
|
||||||
db_gen = get_db()
|
|
||||||
db = next(db_gen)
|
|
||||||
|
|
||||||
# 初始化所有标准类型为 0
|
# 初始化所有标准类型为 0
|
||||||
for kb_type in KnowledgeType:
|
for kb_type in KnowledgeType:
|
||||||
result[kb_type.value] = 0
|
result[kb_type.value] = 0
|
||||||
@@ -881,21 +878,19 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
# 3. 计算知识库类型总和(不包括 memory)
|
# 3. 计算知识库类型总和(不包括 memory)
|
||||||
result["total"] = (
|
result["total"] = (
|
||||||
result.get("General", 0) +
|
result.get("General", 0) +
|
||||||
result.get("Web", 0) +
|
result.get("Web", 0) +
|
||||||
result.get("Third-party", 0) +
|
result.get("Third-party", 0) +
|
||||||
result.get("Folder", 0)
|
result.get("Folder", 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def get_interest_distribution_by_user(
|
async def get_interest_distribution_by_user(
|
||||||
self,
|
self,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
language: str = "zh"
|
language: str = "zh"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取指定用户的兴趣分布标签。
|
获取指定用户的兴趣分布标签。
|
||||||
@@ -921,13 +916,12 @@ class MemoryAgentService:
|
|||||||
logger.error(f"兴趣分布标签查询失败: {e}")
|
logger.error(f"兴趣分布标签查询失败: {e}")
|
||||||
raise Exception(f"兴趣分布标签查询失败: {e}")
|
raise Exception(f"兴趣分布标签查询失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def get_user_profile(
|
async def get_user_profile(
|
||||||
self,
|
self,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user_id: Optional[str] = None,
|
current_user_id: Optional[str] = None,
|
||||||
llm_id: Optional[str] = None,
|
llm_id: Optional[str] = None,
|
||||||
db: Session = None
|
db: Session = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取用户详情,包含:
|
获取用户详情,包含:
|
||||||
@@ -1017,7 +1011,8 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
# 定义标签提取的结构
|
# 定义标签提取的结构
|
||||||
class UserTags(BaseModel):
|
class UserTags(BaseModel):
|
||||||
tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")
|
tags: list[str] = Field(...,
|
||||||
|
description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
ValueError: 当终端用户不存在或应用未发布时
|
ValueError: 当终端用户不存在或应用未发布时
|
||||||
"""
|
"""
|
||||||
import json as json_module
|
import json as json_module
|
||||||
import uuid
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
@@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
"workspace_id": str(app.workspace_id)
|
"workspace_id": str(app.workspace_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
|
logger.info(
|
||||||
|
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,45 +1,42 @@
|
|||||||
# 修改 memory_konwledges_server.py 文件
|
# 修改 memory_konwledges_server.py 文件
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from fastapi import HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.rag.models.chunk import DocumentChunk
|
from app.core.rag.models.chunk import DocumentChunk
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db_context
|
||||||
from app.schemas import file_schema, document_schema
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
|
||||||
from app.models.document_model import Document
|
from app.models.document_model import Document
|
||||||
import uuid
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from fastapi import HTTPException, status
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
|
from app.schemas import file_schema, document_schema
|
||||||
from app.schemas.file_schema import CustomTextFileCreate
|
from app.schemas.file_schema import CustomTextFileCreate
|
||||||
from app.services import document_service, file_service, knowledge_service
|
from app.services import document_service, file_service, knowledge_service
|
||||||
from app.celery_app import celery_app
|
|
||||||
from app.core.logging_config import get_api_logger
|
|
||||||
from app.schemas.file_schema import CustomTextFileCreate
|
|
||||||
from app.db import get_db
|
|
||||||
# 创建一个简单的用户类用于测试
|
# 创建一个简单的用户类用于测试
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
|
||||||
class ChunkCreate(BaseModel):
|
class ChunkCreate(BaseModel):
|
||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
class SimpleUser:
|
class SimpleUser:
|
||||||
def __init__(self, user_id: str):
|
def __init__(self, user_id: str):
|
||||||
# 确保ID是UUID类型
|
# 确保ID是UUID类型
|
||||||
self.id = user_id
|
self.id = user_id
|
||||||
self.username = user_id
|
self.username = user_id
|
||||||
|
|
||||||
'''解析'''
|
|
||||||
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
|
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
|
||||||
"""
|
"""
|
||||||
解析指定文档
|
解析指定文档
|
||||||
@@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user
|
|||||||
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
|
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
'''获取块ID'''
|
|
||||||
async def get_document_chunks(
|
async def get_document_chunks(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
document_id: uuid.UUID,
|
document_id: uuid.UUID,
|
||||||
@@ -198,7 +195,7 @@ async def get_document_chunks(
|
|||||||
|
|
||||||
return success(data=result, msg="文档块列表查询成功")
|
return success(data=result, msg="文档块列表查询成功")
|
||||||
|
|
||||||
'''查找文档ID'''
|
|
||||||
def find_document_id_by_kb_and_filename(
|
def find_document_id_by_kb_and_filename(
|
||||||
db: Session,
|
db: Session,
|
||||||
kb_id: str,
|
kb_id: str,
|
||||||
@@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
'''获取知识库ID'''
|
|
||||||
def find_documents_by_kb_id(
|
def find_documents_by_kb_id(
|
||||||
db: Session,
|
db: Session,
|
||||||
kb_id: str,
|
kb_id: str,
|
||||||
@@ -268,18 +265,14 @@ def find_documents_by_kb_id(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
''''上传文件'''
|
|
||||||
async def memory_konwledges_up(
|
async def memory_konwledges_up(
|
||||||
kb_id: str,
|
kb_id: str,
|
||||||
parent_id: str,
|
parent_id: str,
|
||||||
create_data: file_schema.CustomTextFileCreate,
|
create_data: file_schema.CustomTextFileCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session,
|
||||||
current_user: SimpleUser = None, # 修改为SimpleUser
|
current_user: SimpleUser,
|
||||||
):
|
):
|
||||||
# 如果没有提供current_user,则创建一个默认的
|
|
||||||
if current_user is None:
|
|
||||||
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
|
|
||||||
|
|
||||||
content_bytes = create_data.content.encode('utf-8')
|
content_bytes = create_data.content.encode('utf-8')
|
||||||
file_size = len(content_bytes)
|
file_size = len(content_bytes)
|
||||||
print(f"file size: {file_size} byte")
|
print(f"file size: {file_size} byte")
|
||||||
@@ -350,8 +343,6 @@ async def memory_konwledges_up(
|
|||||||
|
|
||||||
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
||||||
|
|
||||||
'''添加新块'''
|
|
||||||
|
|
||||||
|
|
||||||
async def create_document_chunk(
|
async def create_document_chunk(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
@@ -450,6 +441,7 @@ async def create_document_chunk(
|
|||||||
|
|
||||||
return success(data=chunk, msg="文档块创建成功")
|
return success(data=chunk, msg="文档块创建成功")
|
||||||
|
|
||||||
|
|
||||||
async def write_rag(end_user_id, message, user_rag_memory_id):
|
async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||||
"""
|
"""
|
||||||
将消息写入 RAG 知识库
|
将消息写入 RAG 知识库
|
||||||
@@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
|||||||
detail=f"知识库ID格式无效: {user_rag_memory_id}"
|
detail=f"知识库ID格式无效: {user_rag_memory_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
db_gen = get_db()
|
with get_db_context() as db:
|
||||||
db = next(db_gen)
|
|
||||||
|
|
||||||
try:
|
|
||||||
create_data = CustomTextFileCreate(title=end_user_id, content=message)
|
create_data = CustomTextFileCreate(title=end_user_id, content=message)
|
||||||
current_user = SimpleUser(user_rag_memory_id)
|
current_user = SimpleUser(user_rag_memory_id)
|
||||||
# 检查文档是否已存在
|
# 检查文档是否已存在
|
||||||
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
||||||
print('======',document)
|
print('======', document)
|
||||||
api_logger.info(f"查找文档结果: document_id={document}")
|
api_logger.info(f"查找文档结果: document_id={document}")
|
||||||
if document is not None:
|
if document is not None:
|
||||||
# 文档已存在,直接添加新块
|
# 文档已存在,直接添加新块
|
||||||
@@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
|||||||
else:
|
else:
|
||||||
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
|
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
|
||||||
return result
|
return result
|
||||||
finally:
|
|
||||||
# 确保数据库会话被关闭
|
|
||||||
db.close()
|
|
||||||
@@ -115,6 +115,17 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
|
|
||||||
# --- Create ---
|
# --- Create ---
|
||||||
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
||||||
|
# 业务层检查同一工作空间下是否已存在同名配置
|
||||||
|
if params.workspace_id and params.config_name:
|
||||||
|
from app.models.memory_config_model import MemoryConfig
|
||||||
|
existing = (
|
||||||
|
self.db.query(MemoryConfig)
|
||||||
|
.filter_by(workspace_id=params.workspace_id, config_name=params.config_name)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise ValueError(f"DUPLICATE_CONFIG_NAME:{params.config_name}")
|
||||||
|
|
||||||
# 如果workspace_id存在且模型字段未全部指定,则自动获取
|
# 如果workspace_id存在且模型字段未全部指定,则自动获取
|
||||||
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
|
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
|
||||||
configs = self._get_workspace_configs(params.workspace_id)
|
configs = self._get_workspace_configs(params.workspace_id)
|
||||||
@@ -211,6 +222,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
"apply_id": config.apply_id,
|
"apply_id": config.apply_id,
|
||||||
"scene_id": str(config.scene_id) if config.scene_id else None,
|
"scene_id": str(config.scene_id) if config.scene_id else None,
|
||||||
"scene_name": scene_name, # 新增:场景名称
|
"scene_name": scene_name, # 新增:场景名称
|
||||||
|
"is_system_default": config.is_default, # 是否为系统默认配置
|
||||||
"llm_id": config.llm_id,
|
"llm_id": config.llm_id,
|
||||||
"embedding_id": config.embedding_id,
|
"embedding_id": config.embedding_id,
|
||||||
"rerank_id": config.rerank_id,
|
"rerank_id": config.rerank_id,
|
||||||
|
|||||||
@@ -116,27 +116,15 @@ class ModelConfigService:
|
|||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# dashscope 的 omni 模型需要使用 compatible-mode
|
model_config = RedBearModelConfig(
|
||||||
if provider.lower() == ModelProvider.DASHSCOPE and is_omni:
|
model_name=model_name,
|
||||||
if not api_base:
|
provider=provider,
|
||||||
api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
api_key=api_key,
|
||||||
model_config = RedBearModelConfig(
|
base_url=api_base,
|
||||||
model_name=model_name,
|
is_omni=is_omni,
|
||||||
provider=ModelProvider.OPENAI,
|
temperature=0.7,
|
||||||
api_key=api_key,
|
max_tokens=100
|
||||||
base_url=api_base,
|
)
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=100
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model_config = RedBearModelConfig(
|
|
||||||
model_name=model_name,
|
|
||||||
provider=provider,
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=api_base,
|
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=100
|
|
||||||
)
|
|
||||||
|
|
||||||
# 根据模型类型选择不同的验证方式
|
# 根据模型类型选择不同的验证方式
|
||||||
model_type_lower = model_type.lower()
|
model_type_lower = model_type.lower()
|
||||||
@@ -493,6 +481,9 @@ class ModelApiKeyService:
|
|||||||
if not model_config:
|
if not model_config:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
data.is_omni = model_config.is_omni
|
||||||
|
data.capability = model_config.capability
|
||||||
|
|
||||||
# 从ModelBase获取model_name
|
# 从ModelBase获取model_name
|
||||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||||
|
|
||||||
@@ -550,8 +541,8 @@ class ModelApiKeyService:
|
|||||||
provider=data.provider,
|
provider=data.provider,
|
||||||
api_key=data.api_key,
|
api_key=data.api_key,
|
||||||
api_base=data.api_base,
|
api_base=data.api_base,
|
||||||
capability=data.capability if data.capability is not None else model_config.capability,
|
capability=data.capability,
|
||||||
is_omni=data.is_omni if data.is_omni is not None else model_config.is_omni,
|
is_omni=data.is_omni,
|
||||||
config=data.config,
|
config=data.config,
|
||||||
is_active=data.is_active,
|
is_active=data.is_active,
|
||||||
priority=data.priority
|
priority=data.priority
|
||||||
@@ -574,6 +565,10 @@ class ModelApiKeyService:
|
|||||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
if api_key_data.is_omni is None:
|
||||||
|
api_key_data.is_omni = model_config.is_omni
|
||||||
|
if api_key_data.capability is None:
|
||||||
|
api_key_data.capability = model_config.capability
|
||||||
|
|
||||||
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
||||||
existing_key = db.query(ModelApiKey).join(
|
existing_key = db.query(ModelApiKey).join(
|
||||||
@@ -616,7 +611,7 @@ class ModelApiKeyService:
|
|||||||
api_base=api_key_data.api_base,
|
api_base=api_key_data.api_base,
|
||||||
model_type=model_config.type,
|
model_type=model_config.type,
|
||||||
test_message="Hello",
|
test_message="Hello",
|
||||||
is_omni=model_config.is_omni
|
is_omni=api_key_data.is_omni
|
||||||
)
|
)
|
||||||
if not validation_result["valid"]:
|
if not validation_result["valid"]:
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
|
|||||||
@@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository
|
|||||||
from app.repositories.neo4j.cypher_queries import Graph_Node_query
|
from app.repositories.neo4j.cypher_queries import Graph_Node_query
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
||||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
from app.services.memory_base_service import MemoryBaseService
|
||||||
from app.services.memory_base_service import MemoryBaseService, MemoryTransService
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||||
from app.services.memory_short_service import ShortService
|
from app.services.memory_short_service import ShortService
|
||||||
@@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
|||||||
|
|
||||||
from app.core.language_utils import validate_language
|
from app.core.language_utils import validate_language
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
||||||
from app.db import get_db
|
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
|
||||||
# 验证语言参数
|
# 验证语言参数
|
||||||
@@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
|||||||
if end_user_id:
|
if end_user_id:
|
||||||
try:
|
try:
|
||||||
# 获取数据库会话并查询用户信息
|
# 获取数据库会话并查询用户信息
|
||||||
db = next(get_db())
|
with get_db_context() as db:
|
||||||
try:
|
|
||||||
repo = EndUserRepository(db)
|
repo = EndUserRepository(db)
|
||||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||||
if end_user and end_user.other_name:
|
if end_user and end_user.other_name:
|
||||||
@@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
|||||||
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
|
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
|
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")
|
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")
|
||||||
|
|
||||||
|
|||||||
@@ -107,6 +107,7 @@ def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
|
|||||||
for workspace in workspaces:
|
for workspace in workspaces:
|
||||||
if workspace.storage_type == 'neo4j':
|
if workspace.storage_type == 'neo4j':
|
||||||
_ensure_default_memory_config(db, workspace)
|
_ensure_default_memory_config(db, workspace)
|
||||||
|
_ensure_default_ontology_scenes(db, workspace)
|
||||||
|
|
||||||
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
|
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
|
||||||
return workspaces
|
return workspaces
|
||||||
@@ -1104,6 +1105,52 @@ def _fill_workspace_configs_model_defaults(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
|
||||||
|
"""Ensure a workspace has default ontology scenes, creating them if missing.
|
||||||
|
|
||||||
|
Checks whether any is_system_default scene exists for the workspace.
|
||||||
|
If not, runs the DefaultOntologyInitializer to create them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
workspace: The workspace to check
|
||||||
|
"""
|
||||||
|
from app.models.ontology_scene import OntologyScene
|
||||||
|
|
||||||
|
# 幂等检查:是否已存在系统默认场景
|
||||||
|
existing = db.query(OntologyScene).filter(
|
||||||
|
OntologyScene.workspace_id == workspace.id,
|
||||||
|
OntologyScene.is_system_default.is_(True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
return
|
||||||
|
|
||||||
|
business_logger.info(
|
||||||
|
f"Workspace {workspace.id} missing default ontology scenes, creating them"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
initializer = DefaultOntologyInitializer(db)
|
||||||
|
success, error_msg = initializer.initialize_default_scenes(
|
||||||
|
workspace.id, language="zh"
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
db.commit()
|
||||||
|
business_logger.info(
|
||||||
|
f"为工作空间 {workspace.id} 补建默认本体场景成功"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
business_logger.warning(
|
||||||
|
f"为工作空间 {workspace.id} 补建默认本体场景失败: {error_msg}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
business_logger.error(
|
||||||
|
f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _create_default_memory_config(
|
def _create_default_memory_config(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
|
|||||||
@@ -29,10 +29,10 @@ REDIS_DB=
|
|||||||
REDIS_PASSWORD=password
|
REDIS_PASSWORD=password
|
||||||
|
|
||||||
#celery
|
#celery
|
||||||
BROKER_URL=
|
# NOTE: 不要使用 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||||
RESULT_BACKEND=
|
# 这些名称会被 Celery CLI 劫持,详见 docs/celery-env-bug-report.md
|
||||||
CELERY_BROKER=
|
REDIS_DB_CELERY_BROKER=
|
||||||
CELERY_BACKEND=
|
REDIS_DB_CELERY_BACKEND=
|
||||||
|
|
||||||
# Memory Cache Regeneration Configuration
|
# Memory Cache Regeneration Configuration
|
||||||
# Interval in hours for regenerating memory insight and user summary cache
|
# Interval in hours for regenerating memory insight and user summary cache
|
||||||
|
|||||||
@@ -440,7 +440,6 @@ export const en = {
|
|||||||
logoutApiCannotRefreshToken: 'Logout API cannot refresh token',
|
logoutApiCannotRefreshToken: 'Logout API cannot refresh token',
|
||||||
publicApiCannotRefreshToken: 'Public API cannot refresh token',
|
publicApiCannotRefreshToken: 'Public API cannot refresh token',
|
||||||
refreshTokenNotExist: 'Refresh token does not exist',
|
refreshTokenNotExist: 'Refresh token does not exist',
|
||||||
SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: 'This is a system preset scene and cannot be deleted',
|
|
||||||
reset: 'Reset',
|
reset: 'Reset',
|
||||||
refresh: 'Refresh',
|
refresh: 'Refresh',
|
||||||
return: 'Return',
|
return: 'Return',
|
||||||
@@ -454,6 +453,7 @@ export const en = {
|
|||||||
prevStep: 'Previous Step',
|
prevStep: 'Previous Step',
|
||||||
exportSuccess: 'Export successful',
|
exportSuccess: 'Export successful',
|
||||||
recommend: 'Recommend',
|
recommend: 'Recommend',
|
||||||
|
default: 'Default',
|
||||||
logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`,
|
logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`,
|
||||||
imageSquareRequired: 'Please upload a square image',
|
imageSquareRequired: 'Please upload a square image',
|
||||||
nameInvalid: 'Name cannot start or end with a space',
|
nameInvalid: 'Name cannot start or end with a space',
|
||||||
@@ -2616,6 +2616,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
|||||||
updated_at: 'Updated At',
|
updated_at: 'Updated At',
|
||||||
entityTypes: 'Entity Types',
|
entityTypes: 'Entity Types',
|
||||||
|
|
||||||
|
classSearchPlaceholder: 'Search types',
|
||||||
addClass: 'Add Type',
|
addClass: 'Add Type',
|
||||||
class_name: 'Type Name',
|
class_name: 'Type Name',
|
||||||
class_description: 'Type Definition',
|
class_description: 'Type Definition',
|
||||||
|
|||||||
@@ -1020,7 +1020,6 @@ export const zh = {
|
|||||||
logoutApiCannotRefreshToken: '退出登录接口不能刷新token',
|
logoutApiCannotRefreshToken: '退出登录接口不能刷新token',
|
||||||
publicApiCannotRefreshToken: '公共接口不能刷新token',
|
publicApiCannotRefreshToken: '公共接口不能刷新token',
|
||||||
refreshTokenNotExist: '刷新token不存在',
|
refreshTokenNotExist: '刷新token不存在',
|
||||||
SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: '该场景为系统预设场景,不允许删除',
|
|
||||||
reset: '重置',
|
reset: '重置',
|
||||||
refresh: '刷新',
|
refresh: '刷新',
|
||||||
return: '返回',
|
return: '返回',
|
||||||
@@ -1034,6 +1033,7 @@ export const zh = {
|
|||||||
prevStep: '上一步',
|
prevStep: '上一步',
|
||||||
exportSuccess: '导出成功',
|
exportSuccess: '导出成功',
|
||||||
recommend: '推荐',
|
recommend: '推荐',
|
||||||
|
default: '默认',
|
||||||
logoTip: `支持图片格式(JPG、PNG)\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`,
|
logoTip: `支持图片格式(JPG、PNG)\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`,
|
||||||
imageSquareRequired: '请上传正方形比例图片',
|
imageSquareRequired: '请上传正方形比例图片',
|
||||||
nameInvalid: '不能是空格开头或结尾',
|
nameInvalid: '不能是空格开头或结尾',
|
||||||
@@ -2617,6 +2617,7 @@ export const zh = {
|
|||||||
updated_at: '更新时间',
|
updated_at: '更新时间',
|
||||||
entityTypes: '实体类型',
|
entityTypes: '实体类型',
|
||||||
|
|
||||||
|
classSearchPlaceholder: '搜索类型',
|
||||||
addClass: '添加类型',
|
addClass: '添加类型',
|
||||||
class_name: '类型名称',
|
class_name: '类型名称',
|
||||||
class_description: '类型定义',
|
class_description: '类型定义',
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
/*
|
/*
|
||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-02 16:35:15
|
* @Date: 2026-02-02 16:35:15
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-02 16:35:15
|
* @Last Modified time: 2026-03-06 10:39:00
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* HTTP Request Utility Module
|
* HTTP Request Utility Module
|
||||||
@@ -183,7 +183,7 @@ service.interceptors.response.use(
|
|||||||
msg = msg || i18n.t('common.serverError');
|
msg = msg || i18n.t('common.serverError');
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
if (msg === 'SYSTEM_DEFAULT_SCENE_CANNOT_DELETE') {
|
if (['SYSTEM_DEFAULT_SCENE_CANNOT_DELETE', 'SYSTEM_DEFAULT_CLASS_CANNOT_DELETE', 'SYSTEM_DEFAULT_SCENE_CANNOT_UPDATE'].includes(msg)) {
|
||||||
msg = i18n.t(`common.${msg}`)
|
msg = i18n.t(`common.${msg}`)
|
||||||
} else if (!msg && Array.isArray(error.response?.data?.detail)) {
|
} else if (!msg && Array.isArray(error.response?.data?.detail)) {
|
||||||
msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';')
|
msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';')
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 16:27:39
|
* @Date: 2026-02-03 16:27:39
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-04 18:51:20
|
* @Last Modified time: 2026-03-05 17:03:46
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Chat debugging component for application testing
|
* Chat debugging component for application testing
|
||||||
@@ -171,6 +171,29 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
|
|||||||
.then(() => {
|
.then(() => {
|
||||||
const message = msg
|
const message = msg
|
||||||
if (!message?.trim()) return
|
if (!message?.trim()) return
|
||||||
|
// Validate required variables before sending
|
||||||
|
let isCanSend = true
|
||||||
|
const params: Record<string, any> = {}
|
||||||
|
if (chatVariables && chatVariables.length > 0) {
|
||||||
|
const needRequired: string[] = []
|
||||||
|
chatVariables.forEach(vo => {
|
||||||
|
params[vo.name] = vo.value
|
||||||
|
|
||||||
|
if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
|
||||||
|
isCanSend = false
|
||||||
|
needRequired.push(vo.name)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if (needRequired.length) {
|
||||||
|
messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!isCanSend) {
|
||||||
|
setLoading(false)
|
||||||
|
setCompareLoading(false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
addUserMessage(message, fileList)
|
addUserMessage(message, fileList)
|
||||||
setMessage(message)
|
setMessage(message)
|
||||||
@@ -198,29 +221,6 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
|
|||||||
};
|
};
|
||||||
|
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
// Validate required variables before sending
|
|
||||||
let isCanSend = true
|
|
||||||
const params: Record<string, any> = {}
|
|
||||||
if (chatVariables && chatVariables.length > 0) {
|
|
||||||
const needRequired: string[] = []
|
|
||||||
chatVariables.forEach(vo => {
|
|
||||||
params[vo.name] = vo.value
|
|
||||||
|
|
||||||
if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
|
|
||||||
isCanSend = false
|
|
||||||
needRequired.push(vo.name)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if (needRequired.length) {
|
|
||||||
messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!isCanSend) {
|
|
||||||
setLoading(false)
|
|
||||||
setCompareLoading(false)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
runCompare(data.app_id, {
|
runCompare(data.app_id, {
|
||||||
message,
|
message,
|
||||||
files: fileList.map(file => {
|
files: fileList.map(file => {
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-06 21:09:42
|
* @Date: 2026-02-06 21:09:42
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-04 18:54:47
|
* @Last Modified time: 2026-03-05 15:09:22
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* File Upload Component
|
* File Upload Component
|
||||||
@@ -206,7 +206,7 @@ const UploadFiles = forwardRef<UploadFilesRef, UploadFilesProps>(({
|
|||||||
*/
|
*/
|
||||||
const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => {
|
const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => {
|
||||||
newFileList.map(file => {
|
newFileList.map(file => {
|
||||||
const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type
|
const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document'
|
||||||
file.type = type
|
file.type = type
|
||||||
})
|
})
|
||||||
setFileList(newFileList);
|
setFileList(newFileList);
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import {
|
|||||||
} from '@/api/knowledgeBase'
|
} from '@/api/knowledgeBase'
|
||||||
import RbModal from '@/components/RbModal'
|
import RbModal from '@/components/RbModal'
|
||||||
import SliderInput from '@/components/SliderInput'
|
import SliderInput from '@/components/SliderInput'
|
||||||
|
import { stringRegExp } from '@/utils/validator'
|
||||||
const { TextArea } = Input;
|
const { TextArea } = Input;
|
||||||
const { confirm } = Modal
|
const { confirm } = Modal
|
||||||
|
|
||||||
@@ -519,12 +520,16 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
|
|||||||
<Form.Item
|
<Form.Item
|
||||||
name="name"
|
name="name"
|
||||||
label={t('knowledgeBase.createForm.name')}
|
label={t('knowledgeBase.createForm.name')}
|
||||||
rules={[{ required: true, message: t('knowledgeBase.createForm.nameRequired') }]}
|
rules={[
|
||||||
|
{ required: true, message: t('knowledgeBase.createForm.nameRequired') },
|
||||||
|
{ max: 50 },
|
||||||
|
{ pattern: stringRegExp, message: t('common.nameInvalid') },
|
||||||
|
]}
|
||||||
>
|
>
|
||||||
<Input placeholder={t('knowledgeBase.createForm.name')} />
|
<Input placeholder={t('knowledgeBase.createForm.name')} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
)}
|
)}
|
||||||
<Form.Item name="description" label={t('knowledgeBase.createForm.description')}>
|
<Form.Item name="description" label={t('knowledgeBase.createForm.description')} rules={[{ max: 500 }]}>
|
||||||
<TextArea rows={2} placeholder={t('knowledgeBase.createForm.description')} />
|
<TextArea rows={2} placeholder={t('knowledgeBase.createForm.description')} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
/*
|
/*
|
||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 17:33:15
|
* @Date: 2026-02-03 17:33:15
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-03 17:33:15
|
* @Last Modified time: 2026-03-05 16:28:58
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Memory Management Page
|
* Memory Management Page
|
||||||
@@ -110,9 +110,15 @@ const MemoryManagement: React.FC = () => {
|
|||||||
<List.Item key={item.config_id}>
|
<List.Item key={item.config_id}>
|
||||||
<RbCard
|
<RbCard
|
||||||
title={item.config_name}
|
title={item.config_name}
|
||||||
|
className="rb:relative"
|
||||||
>
|
>
|
||||||
|
{item.is_system_default &&
|
||||||
|
<div className="rb:absolute rb:-right-px rb:-top-px rb:bg-[#FF5D34] rb:rounded-[0px_7px_0px_8px] rb:text-[12px] rb:text-white rb:font-regular rb:leading-4 rb:py-0.5 rb:px-1">
|
||||||
|
{t('common.default')}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
<Tooltip title={item.config_desc}>
|
<Tooltip title={item.config_desc}>
|
||||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-[17px]">{item.config_desc}</div>
|
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-4.25">{item.config_desc}</div>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
<RbAlert className="rb:mt-3 ">
|
<RbAlert className="rb:mt-3 ">
|
||||||
<div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}>
|
<div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}>
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 17:33:01
|
* @Date: 2026-02-03 17:33:01
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-03 17:33:24
|
* @Last Modified time: 2026-03-05 16:33:53
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Memory management form data type
|
* Memory management form data type
|
||||||
@@ -42,6 +42,7 @@ export interface Memory {
|
|||||||
workspace_id: string;
|
workspace_id: string;
|
||||||
scene_id: string;
|
scene_id: string;
|
||||||
scene_name: string;
|
scene_name: string;
|
||||||
|
is_system_default: boolean;
|
||||||
[key: string]: string | number | boolean;
|
[key: string]: string | number | boolean;
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 14:10:24
|
* @Date: 2026-02-03 14:10:24
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-09 18:02:13
|
* @Last Modified time: 2026-03-06 11:25:59
|
||||||
*/
|
*/
|
||||||
import { type FC, type ReactNode } from 'react';
|
import { type FC, type ReactNode } from 'react';
|
||||||
import { useNavigate } from 'react-router-dom';
|
import { useNavigate } from 'react-router-dom';
|
||||||
@@ -17,7 +17,7 @@ const { Header } = Layout;
|
|||||||
*/
|
*/
|
||||||
interface ConfigHeaderProps {
|
interface ConfigHeaderProps {
|
||||||
/** Page title/name */
|
/** Page title/name */
|
||||||
name?: string;
|
name?: string | ReactNode;
|
||||||
/** Subtitle content displayed below the title */
|
/** Subtitle content displayed below the title */
|
||||||
subTitle?: ReactNode | string;
|
subTitle?: ReactNode | string;
|
||||||
/** Extra content displayed on the right side */
|
/** Extra content displayed on the right side */
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
/*
|
/*
|
||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 14:10:15
|
* @Date: 2026-02-03 14:10:15
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-05 10:57:53
|
* @Last Modified time: 2026-03-06 10:56:44
|
||||||
*/
|
*/
|
||||||
import { type FC, useState, useRef, type MouseEvent } from 'react';
|
import { type FC, useState, useRef, type MouseEvent } from 'react';
|
||||||
import { useNavigate } from 'react-router-dom';
|
import { useNavigate } from 'react-router-dom';
|
||||||
@@ -144,8 +144,13 @@ const Ontology: FC = () => {
|
|||||||
title={item.scene_name}
|
title={item.scene_name}
|
||||||
extra={<Tag>{item.type_num} {t('ontology.typeCount')}</Tag>}
|
extra={<Tag>{item.type_num} {t('ontology.typeCount')}</Tag>}
|
||||||
onClick={() => handleJump(item)}
|
onClick={() => handleJump(item)}
|
||||||
className="rb:cursor-pointer"
|
className="rb:cursor-pointer rb:relative"
|
||||||
>
|
>
|
||||||
|
{item.is_system_default &&
|
||||||
|
<div className="rb:absolute rb:-right-px rb:-top-px rb:bg-[#FF5D34] rb:rounded-[0px_7px_0px_8px] rb:text-[12px] rb:text-white rb:font-regular rb:leading-4 rb:py-0.5 rb:px-1">
|
||||||
|
{t('common.default')}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
<div
|
<div
|
||||||
className="rb:flex rb:gap-2 rb:justify-between rb:text-[#5B6167] rb:text-[14px] rb:leading-5 rb:mb-3"
|
className="rb:flex rb:gap-2 rb:justify-between rb:text-[#5B6167] rb:text-[14px] rb:leading-5 rb:mb-3"
|
||||||
>
|
>
|
||||||
@@ -176,8 +181,8 @@ const Ontology: FC = () => {
|
|||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<div className="rb:mt-4 rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167] rb:flex rb:items-center rb:justify-end">
|
<div className="rb:mt-4 rb:h-5 rb:text-[12px] rb:leading-4 rb:font-regular rb:text-[#5B6167] rb:flex rb:items-center rb:justify-end">
|
||||||
<Space size={16}>
|
{!item.is_system_default && <Space size={16}>
|
||||||
<div
|
<div
|
||||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]"
|
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/edit.svg')] rb:hover:bg-[url('@/assets/images/edit_hover.svg')]"
|
||||||
onClick={(e) => handleEdit(item, e)}
|
onClick={(e) => handleEdit(item, e)}
|
||||||
@@ -186,7 +191,7 @@ const Ontology: FC = () => {
|
|||||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
||||||
onClick={(e) => handleDelete(item, e)}
|
onClick={(e) => handleDelete(item, e)}
|
||||||
></div>
|
></div>
|
||||||
</Space>
|
</Space>}
|
||||||
</div>
|
</div>
|
||||||
</RbCard>
|
</RbCard>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 14:10:20
|
* @Date: 2026-02-03 14:10:20
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-09 17:56:35
|
* @Last Modified time: 2026-03-06 11:26:49
|
||||||
*/
|
*/
|
||||||
import { type FC, useEffect, useState, useRef } from 'react'
|
import { type FC, useEffect, useState, useRef } from 'react'
|
||||||
import { useParams } from 'react-router-dom';
|
import { useParams } from 'react-router-dom';
|
||||||
@@ -17,6 +17,7 @@ import OntologyClassModal from '../components/OntologyClassModal'
|
|||||||
import SearchInput from '@/components/SearchInput';
|
import SearchInput from '@/components/SearchInput';
|
||||||
import OntologyClassExtractModal from '../components/OntologyClassExtractModal'
|
import OntologyClassExtractModal from '../components/OntologyClassExtractModal'
|
||||||
import BodyWrapper from '@/components/Empty/BodyWrapper'
|
import BodyWrapper from '@/components/Empty/BodyWrapper'
|
||||||
|
import Tag from '@/components/Tag'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Ontology detail page component
|
* Ontology detail page component
|
||||||
@@ -99,19 +100,22 @@ const Detail: FC = () => {
|
|||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<PageHeader
|
<PageHeader
|
||||||
name={data.scene_name}
|
name={<Space>
|
||||||
|
{data.scene_name}
|
||||||
|
<Tag color="warning">{t('common.default')}</Tag>
|
||||||
|
</Space>}
|
||||||
subTitle={<Tooltip title={data.scene_description}><div className="rb:h-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{data.scene_description}</div></Tooltip>}
|
subTitle={<Tooltip title={data.scene_description}><div className="rb:h-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{data.scene_description}</div></Tooltip>}
|
||||||
extra={<Space>
|
extra={data.is_system_default ? undefined : (<Space>
|
||||||
<Button type="primary" ghost className="rb:h-6! rb:px-2! rb:leading-5.5!" onClick={handleAdd}>+ {t('ontology.addClass')}</Button>
|
<Button type="primary" ghost className="rb:h-6! rb:px-2! rb:leading-5.5!" onClick={handleAdd}>+ {t('ontology.addClass')}</Button>
|
||||||
<Button className="rb:h-6! rb:px-2! rb:leading-5.5!" type="primary" onClick={handleExtract}>+ {t('ontology.extract')}</Button>
|
<Button className="rb:h-6! rb:px-2! rb:leading-5.5!" type="primary" onClick={handleExtract}>+ {t('ontology.extract')}</Button>
|
||||||
</Space>}
|
</Space>)}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<div className="rb:h-[calc(100vh-64px)] rb:overflow-y-auto rb:py-3 rb:px-4">
|
<div className="rb:h-[calc(100vh-64px)] rb:overflow-y-auto rb:py-3 rb:px-4">
|
||||||
<Row gutter={16} className="rb:mb-4">
|
<Row gutter={16} className="rb:mb-4">
|
||||||
<Col span={6} offset={18}>
|
<Col span={6} offset={18}>
|
||||||
<SearchInput
|
<SearchInput
|
||||||
placeholder={t('ontology.searchPlaceholder')}
|
placeholder={t('ontology.classSearchPlaceholder')}
|
||||||
onSearch={(value) => setQuery({ class_name: value })}
|
onSearch={(value) => setQuery({ class_name: value })}
|
||||||
className="rb:w-full!"
|
className="rb:w-full!"
|
||||||
/>
|
/>
|
||||||
@@ -123,10 +127,10 @@ const Detail: FC = () => {
|
|||||||
<Col key={item.class_id} span={6}>
|
<Col key={item.class_id} span={6}>
|
||||||
<RbCard
|
<RbCard
|
||||||
title={item.class_name}
|
title={item.class_name}
|
||||||
extra={<div
|
extra={data.is_system_default ? undefined : (<div
|
||||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
||||||
onClick={() => handleDelete(item)}
|
onClick={() => handleDelete(item)}
|
||||||
></div>}
|
></div>)}
|
||||||
className="rb:bg-transparent!"
|
className="rb:bg-transparent!"
|
||||||
>
|
>
|
||||||
<Tooltip title={item.class_description}>
|
<Tooltip title={item.class_description}>
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
/*
|
/*
|
||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 14:10:10
|
* @Date: 2026-02-03 14:10:10
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-03 14:10:10
|
* @Last Modified time: 2026-03-06 10:55:23
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Query parameters for ontology list pagination and filtering
|
* Query parameters for ontology list pagination and filtering
|
||||||
@@ -38,6 +38,8 @@ export interface OntologyItem {
|
|||||||
updated_at: number;
|
updated_at: number;
|
||||||
/** Total count of classes in the scene */
|
/** Total count of classes in the scene */
|
||||||
classes_count: number;
|
classes_count: number;
|
||||||
|
/** Whether this is the system default configuration */
|
||||||
|
is_system_default: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -92,6 +94,7 @@ export interface OntologyClassData {
|
|||||||
scene_description: string;
|
scene_description: string;
|
||||||
/** Array of class items */
|
/** Array of class items */
|
||||||
items: OntologyClassItem[];
|
items: OntologyClassItem[];
|
||||||
|
is_system_default: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
Reference in New Issue
Block a user