Compare commits
187 Commits
revert-218
...
v0.2.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79ab929fb0 | ||
|
|
eab7225d83 | ||
|
|
1b853aa893 | ||
|
|
0159fdf149 | ||
|
|
364e01ec7a | ||
|
|
ffb7b0ba38 | ||
|
|
095dfc2879 | ||
|
|
17dea9433e | ||
|
|
c285444e2f | ||
|
|
8ba402d080 | ||
|
|
88ab86734d | ||
|
|
b0d5818351 | ||
|
|
8826a01d32 | ||
|
|
a651ae6ed4 | ||
|
|
ee50b25d06 | ||
|
|
a67be85858 | ||
|
|
59c5a3973a | ||
|
|
d76d7343ff | ||
|
|
2b9638e7d3 | ||
|
|
3459a73705 | ||
|
|
bd480a466b | ||
|
|
4c34cb55b6 | ||
|
|
e137e4a38a | ||
|
|
b5989bbc25 | ||
|
|
c31ff7ceef | ||
|
|
75066f2827 | ||
|
|
303f3aefef | ||
|
|
44fb5e0fd5 | ||
|
|
17a695120a | ||
|
|
6dc716eaf8 | ||
|
|
194be086d4 | ||
|
|
c49603c25b | ||
|
|
8de85a4041 | ||
|
|
58a2135fa4 | ||
|
|
ab9a97db22 | ||
|
|
d291c241d5 | ||
|
|
24d4cb9b94 | ||
|
|
5b9adb799f | ||
|
|
38b41df36b | ||
|
|
34a9befe5c | ||
|
|
67fd579074 | ||
|
|
e2714b942d | ||
|
|
6b2556f870 | ||
|
|
43e6e9d201 | ||
|
|
131e0cc4c7 | ||
|
|
537be81b8f | ||
|
|
765168db7f | ||
|
|
1e16b06a24 | ||
|
|
cd4c93a5cb | ||
|
|
808961243d | ||
|
|
4d80e119f7 | ||
|
|
10c87edae1 | ||
|
|
0eb335d112 | ||
|
|
b8b26ccfe5 | ||
|
|
e89c23da4d | ||
|
|
ced087f8ae | ||
|
|
0f1eed0b1e | ||
|
|
95f15b77a3 | ||
|
|
f9ccfd5ca0 | ||
|
|
7207d7c847 | ||
|
|
00c4a524b7 | ||
|
|
3127c382a4 | ||
|
|
1748a390ec | ||
|
|
a7c0837049 | ||
|
|
44bf1eeae2 | ||
|
|
762b7a8ef1 | ||
|
|
102712a16e | ||
|
|
40810c59d7 | ||
|
|
35a10e86b5 | ||
|
|
c0c985494d | ||
|
|
8984ba7aef | ||
|
|
179869d481 | ||
|
|
5f29956f2b | ||
|
|
dbc4ba84c2 | ||
|
|
9e4a527675 | ||
|
|
45833542a7 | ||
|
|
1be6de30d7 | ||
|
|
981d78c8ba | ||
|
|
fbc7bedb6c | ||
|
|
4786b0c5d4 | ||
|
|
17bed26096 | ||
|
|
511e16f1d3 | ||
|
|
18204bc1f7 | ||
|
|
b58d97fad3 | ||
|
|
d2a67a53b5 | ||
|
|
c0b556000c | ||
|
|
462c3b0696 | ||
|
|
d34ad73439 | ||
|
|
2c21712d58 | ||
|
|
ce01e588c9 | ||
|
|
2a23082203 | ||
|
|
d373f924f6 | ||
|
|
eaf46ee006 | ||
|
|
d51355a0ad | ||
|
|
1e481a311a | ||
|
|
46abb23ee8 | ||
|
|
8555bb697c | ||
|
|
f821893653 | ||
|
|
75b3ea1f05 | ||
|
|
74f0018962 | ||
|
|
3a0f07d36f | ||
|
|
a047cf2e91 | ||
|
|
a8ae16e321 | ||
|
|
a53be31765 | ||
|
|
4475be51cc | ||
|
|
d53cbe7868 | ||
|
|
722746c78b | ||
|
|
e1f5607836 | ||
|
|
7cd0d78424 | ||
|
|
d740559749 | ||
|
|
399357f752 | ||
|
|
9de6b4f151 | ||
|
|
94cced8323 | ||
|
|
9b8ed16e37 | ||
|
|
a5e44cd229 | ||
|
|
eccc208229 | ||
|
|
79cfabb45d | ||
|
|
af6e1e2b99 | ||
|
|
4ad51c1b24 | ||
|
|
c44712167f | ||
|
|
1aabaff1f2 | ||
|
|
21c0383efb | ||
|
|
ebe018347b | ||
|
|
86fe6fe5ab | ||
|
|
9e828b1750 | ||
|
|
940d3d4567 | ||
|
|
6bd7b2b8bb | ||
|
|
f2d6fd7b08 | ||
|
|
b84c82880c | ||
|
|
fcc418b4a0 | ||
|
|
15c0bb4c9e | ||
|
|
8db4f914d8 | ||
|
|
f3f9211c9c | ||
|
|
a2a69840f7 | ||
|
|
3a4a7590c2 | ||
|
|
bcc8b7ce3c | ||
|
|
1c7fe6d134 | ||
|
|
c4039f52bd | ||
|
|
bd851d5e86 | ||
|
|
00e448c5d6 | ||
|
|
4aeec8afbf | ||
|
|
f10432bf3f | ||
|
|
f0efed8aa1 | ||
|
|
4a4931bee2 | ||
|
|
afcf12ebc9 | ||
|
|
8f86d3417d | ||
|
|
92dfc54c4c | ||
|
|
c93bcb8678 | ||
|
|
98b2da9123 | ||
|
|
cd5f1a1b28 | ||
|
|
0e2e495d09 | ||
|
|
84c6c7e2a6 | ||
|
|
c8ebf9c75a | ||
|
|
29852ff0a5 | ||
|
|
f06ca62589 | ||
|
|
3f39a2be12 | ||
|
|
575190a96d | ||
|
|
78559d98eb | ||
|
|
398964c747 | ||
|
|
a634565296 | ||
|
|
a5ecbec9a6 | ||
|
|
fe79978f88 | ||
|
|
978ec8bc75 | ||
|
|
6e77f5b068 | ||
|
|
c9dbb64269 | ||
|
|
546d32e3eb | ||
|
|
616f6401b4 | ||
|
|
d047190453 | ||
|
|
17504b1b9c | ||
|
|
5a0d3df689 | ||
|
|
871304c89b | ||
|
|
8155150e45 | ||
|
|
d9fb8edaa9 | ||
|
|
dda61679bd | ||
|
|
6ac10a8297 | ||
|
|
0695c11739 | ||
|
|
7a4297c4f1 | ||
|
|
2c9e5df27d | ||
|
|
6db37d35ed | ||
|
|
ceee4fe5cf | ||
|
|
130b4a57de | ||
|
|
1cee27e830 | ||
|
|
ba2ff053f9 | ||
|
|
227665439f | ||
|
|
1a2e043ec2 | ||
|
|
89500df0ac | ||
|
|
cb4e80f1bc |
@@ -3,9 +3,14 @@ import platform
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from celery import Celery
|
from celery import Celery
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
# macOS fork() safety - must be set before any Celery initialization
|
||||||
|
if platform.system() == 'Darwin':
|
||||||
|
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||||
|
|
||||||
# 创建 Celery 应用实例
|
# 创建 Celery 应用实例
|
||||||
# broker: 任务队列(使用 Redis DB 0)
|
# broker: 任务队列(使用 Redis DB 0)
|
||||||
# backend: 结果存储(使用 Redis DB 10)
|
# backend: 结果存储(使用 Redis DB 10)
|
||||||
@@ -63,15 +68,20 @@ celery_app.conf.update(
|
|||||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||||
|
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||||
|
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||||
|
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
# Document tasks → document_tasks queue (prefork worker)
|
# Document tasks → document_tasks queue (prefork worker)
|
||||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||||
|
|
||||||
# Beat/periodic tasks → document_tasks queue (prefork worker)
|
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||||
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
|
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
|
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
|
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
|
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -79,40 +89,40 @@ celery_app.conf.update(
|
|||||||
celery_app.autodiscover_tasks(['app'])
|
celery_app.autodiscover_tasks(['app'])
|
||||||
|
|
||||||
# Celery Beat schedule for periodic tasks
|
# Celery Beat schedule for periodic tasks
|
||||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
# memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
# memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
# workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
# forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||||
|
|
||||||
# 构建定时任务配置
|
# 构建定时任务配置
|
||||||
beat_schedule_config = {
|
# beat_schedule_config = {
|
||||||
"run-workspace-reflection": {
|
# "run-workspace-reflection": {
|
||||||
"task": "app.tasks.workspace_reflection_task",
|
# "task": "app.tasks.workspace_reflection_task",
|
||||||
"schedule": workspace_reflection_schedule,
|
# "schedule": workspace_reflection_schedule,
|
||||||
"args": (),
|
# "args": (),
|
||||||
},
|
# },
|
||||||
"regenerate-memory-cache": {
|
# "regenerate-memory-cache": {
|
||||||
"task": "app.tasks.regenerate_memory_cache",
|
# "task": "app.tasks.regenerate_memory_cache",
|
||||||
"schedule": memory_cache_regeneration_schedule,
|
# "schedule": memory_cache_regeneration_schedule,
|
||||||
"args": (),
|
# "args": (),
|
||||||
},
|
# },
|
||||||
"run-forgetting-cycle": {
|
# "run-forgetting-cycle": {
|
||||||
"task": "app.tasks.run_forgetting_cycle_task",
|
# "task": "app.tasks.run_forgetting_cycle_task",
|
||||||
"schedule": forgetting_cycle_schedule,
|
# "schedule": forgetting_cycle_schedule,
|
||||||
"kwargs": {
|
# "kwargs": {
|
||||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
# "config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||||
},
|
# },
|
||||||
},
|
# },
|
||||||
}
|
# }
|
||||||
|
|
||||||
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||||
if settings.DEFAULT_WORKSPACE_ID:
|
# if settings.DEFAULT_WORKSPACE_ID:
|
||||||
beat_schedule_config["write-total-memory"] = {
|
# beat_schedule_config["write-total-memory"] = {
|
||||||
"task": "app.controllers.memory_storage_controller.search_all",
|
# "task": "app.controllers.memory_storage_controller.search_all",
|
||||||
"schedule": memory_increment_schedule,
|
# "schedule": memory_increment_schedule,
|
||||||
"kwargs": {
|
# "kwargs": {
|
||||||
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
# "workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||||
},
|
# },
|
||||||
}
|
# }
|
||||||
|
|
||||||
celery_app.conf.beat_schedule = beat_schedule_config
|
# celery_app.conf.beat_schedule = beat_schedule_config
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from . import (
|
|||||||
home_page_controller,
|
home_page_controller,
|
||||||
memory_perceptual_controller,
|
memory_perceptual_controller,
|
||||||
memory_working_controller,
|
memory_working_controller,
|
||||||
|
ontology_controller,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建管理端 API 路由器
|
# 创建管理端 API 路由器
|
||||||
@@ -90,5 +91,6 @@ manager_router.include_router(implicit_memory_controller.router)
|
|||||||
manager_router.include_router(memory_perceptual_controller.router)
|
manager_router.include_router(memory_perceptual_controller.router)
|
||||||
manager_router.include_router(memory_working_controller.router)
|
manager_router.include_router(memory_working_controller.router)
|
||||||
manager_router.include_router(file_storage_controller.router)
|
manager_router.include_router(file_storage_controller.router)
|
||||||
|
manager_router.include_router(ontology_controller.router)
|
||||||
|
|
||||||
__all__ = ["manager_router"]
|
__all__ = ["manager_router"]
|
||||||
|
|||||||
@@ -7,10 +7,11 @@ Routes:
|
|||||||
GET /memory/config/emotion - 获取情绪引擎配置
|
GET /memory/config/emotion - 获取情绪引擎配置
|
||||||
POST /memory/config/emotion - 更新情绪引擎配置
|
POST /memory/config/emotion - 更新情绪引擎配置
|
||||||
"""
|
"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ from app.schemas.response_schema import ApiResponse
|
|||||||
from app.services.emotion_config_service import EmotionConfigService
|
from app.services.emotion_config_service import EmotionConfigService
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -37,7 +39,7 @@ class EmotionConfigQuery(BaseModel):
|
|||||||
|
|
||||||
class EmotionConfigUpdate(BaseModel):
|
class EmotionConfigUpdate(BaseModel):
|
||||||
"""情绪配置更新请求模型"""
|
"""情绪配置更新请求模型"""
|
||||||
config_id: UUID = Field(..., description="配置ID")
|
config_id: Union[uuid.UUID, int, str]= Field(..., description="配置ID")
|
||||||
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
||||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
||||||
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
||||||
@@ -46,7 +48,7 @@ class EmotionConfigUpdate(BaseModel):
|
|||||||
|
|
||||||
@router.get("/read_config", response_model=ApiResponse)
|
@router.get("/read_config", response_model=ApiResponse)
|
||||||
def get_emotion_config(
|
def get_emotion_config(
|
||||||
config_id: UUID = Query(..., description="配置ID"),
|
config_id: UUID|int = Query(..., description="配置ID"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
@@ -79,7 +81,7 @@ def get_emotion_config(
|
|||||||
f"用户 {current_user.username} 请求获取情绪配置",
|
f"用户 {current_user.username} 请求获取情绪配置",
|
||||||
extra={"config_id": config_id}
|
extra={"config_id": config_id}
|
||||||
)
|
)
|
||||||
|
config_id=resolve_config_id(config_id, db)
|
||||||
# 初始化服务
|
# 初始化服务
|
||||||
config_service = EmotionConfigService(db)
|
config_service = EmotionConfigService(db)
|
||||||
|
|
||||||
@@ -158,6 +160,7 @@ def update_emotion_config(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
config.config_id=resolve_config_id(config.config_id, db)
|
||||||
try:
|
try:
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 请求更新情绪配置",
|
f"用户 {current_user.username} 请求更新情绪配置",
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from app.schemas.memory_storage_schema import (
|
|||||||
)
|
)
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.memory_forget_service import MemoryForgetService
|
from app.services.memory_forget_service import MemoryForgetService
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -84,7 +84,8 @@ async def trigger_forgetting_cycle(
|
|||||||
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
config_id = resolve_config_id((config_id), db)
|
||||||
|
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||||
@@ -129,7 +130,7 @@ async def trigger_forgetting_cycle(
|
|||||||
|
|
||||||
@router.get("/read_config", response_model=ApiResponse)
|
@router.get("/read_config", response_model=ApiResponse)
|
||||||
async def read_forgetting_config(
|
async def read_forgetting_config(
|
||||||
config_id: UUID,
|
config_id: UUID|int,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
@@ -158,6 +159,7 @@ async def read_forgetting_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
config_id=resolve_config_id(config_id, db)
|
||||||
# 调用服务层读取配置
|
# 调用服务层读取配置
|
||||||
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
||||||
|
|
||||||
@@ -195,6 +197,8 @@ async def update_forgetting_config(
|
|||||||
ApiResponse: 包含更新结果的响应
|
ApiResponse: 包含更新结果的响应
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
payload.config_id=resolve_config_id((payload.config_id), db)
|
||||||
|
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
@@ -255,12 +259,10 @@ async def get_forgetting_stats(
|
|||||||
ApiResponse: 包含统计信息的响应
|
ApiResponse: 包含统计信息的响应
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
# 如果提供了 end_user_id,通过它获取 config_id
|
# 如果提供了 end_user_id,通过它获取 config_id
|
||||||
config_id = None
|
config_id = None
|
||||||
if end_user_id:
|
if end_user_id:
|
||||||
@@ -269,6 +271,7 @@ async def get_forgetting_stats(
|
|||||||
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
|
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||||
@@ -325,7 +328,7 @@ async def get_forgetting_curve(
|
|||||||
ApiResponse: 包含遗忘曲线数据的响应
|
ApiResponse: 包含遗忘曲线数据的响应
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
request.config_id = resolve_config_id((request.config_id), db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ from fastapi import APIRouter, Depends, HTTPException, status,Header
|
|||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
@@ -43,12 +45,12 @@ async def save_reflection_config(
|
|||||||
"""Save reflection configuration to data_comfig table"""
|
"""Save reflection configuration to data_comfig table"""
|
||||||
try:
|
try:
|
||||||
config_id = request.config_id
|
config_id = request.config_id
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
if not config_id:
|
if not config_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="缺少必需参数: config_id"
|
detail="缺少必需参数: config_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||||
|
|
||||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||||
@@ -99,7 +101,7 @@ async def start_workspace_reflection(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Activate the reflection function for all matching applications in the workspace"""
|
"""启动工作空间中所有匹配应用的反思功能"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
reflection_service = MemoryReflectionService(db)
|
reflection_service = MemoryReflectionService(db)
|
||||||
|
|
||||||
@@ -108,42 +110,55 @@ async def start_workspace_reflection(
|
|||||||
|
|
||||||
service = WorkspaceAppService(db)
|
service = WorkspaceAppService(db)
|
||||||
result = service.get_workspace_apps_detailed(workspace_id)
|
result = service.get_workspace_apps_detailed(workspace_id)
|
||||||
|
|
||||||
reflection_results = []
|
reflection_results = []
|
||||||
|
|
||||||
for data in result['apps_detailed_info']:
|
for data in result['apps_detailed_info']:
|
||||||
if data['memory_configs'] == []:
|
# 跳过没有配置的应用
|
||||||
|
if not data['memory_configs']:
|
||||||
|
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
releases = data['releases']
|
releases = data['releases']
|
||||||
memory_configs = data['memory_configs']
|
memory_configs = data['memory_configs']
|
||||||
end_users = data['end_users']
|
end_users = data['end_users']
|
||||||
|
|
||||||
for base, config, user in zip(releases, memory_configs, end_users):
|
# 为每个配置和用户组合执行反思
|
||||||
# 安全地转换为整数,处理空字符串和None的情况
|
for config in memory_configs:
|
||||||
print(base['config'])
|
config_id_str = str(config['config_id'])
|
||||||
try:
|
|
||||||
base_config = int(base['config']) if base['config'] else 0
|
# 找到匹配此配置的所有release
|
||||||
config_id = int(config['config_id']) if config['config_id'] else 0
|
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||||
except (ValueError, TypeError):
|
|
||||||
api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}")
|
if not matching_releases:
|
||||||
|
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if base_config == config_id and base['app_id'] == user['app_id']:
|
# 为每个用户执行反思
|
||||||
# 调用反思服务
|
for user in end_users:
|
||||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||||
|
|
||||||
reflection_result = await reflection_service.start_text_reflection(
|
try:
|
||||||
config_data=config,
|
reflection_result = await reflection_service.start_text_reflection(
|
||||||
end_user_id=user['id']
|
config_data=config,
|
||||||
)
|
end_user_id=user['id']
|
||||||
|
)
|
||||||
reflection_results.append({
|
|
||||||
"app_id": base['app_id'],
|
reflection_results.append({
|
||||||
"config_id": config['config_id'],
|
"app_id": data['id'],
|
||||||
"end_user_id": user['id'],
|
"config_id": config_id_str,
|
||||||
"reflection_result": reflection_result
|
"end_user_id": user['id'],
|
||||||
})
|
"reflection_result": reflection_result
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
|
||||||
|
reflection_results.append({
|
||||||
|
"app_id": data['id'],
|
||||||
|
"config_id": config_id_str,
|
||||||
|
"end_user_id": user['id'],
|
||||||
|
"reflection_result": {
|
||||||
|
"status": "错误",
|
||||||
|
"message": f"反思失败: {str(e)}"
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
return success(data=reflection_results, msg="反思配置成功")
|
return success(data=reflection_results, msg="反思配置成功")
|
||||||
|
|
||||||
@@ -157,17 +172,20 @@ async def start_workspace_reflection(
|
|||||||
|
|
||||||
@router.get("/reflection/configs")
|
@router.get("/reflection/configs")
|
||||||
async def start_reflection_configs(
|
async def start_reflection_configs(
|
||||||
config_id: uuid.UUID,
|
config_id: uuid.UUID|int,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
try:
|
try:
|
||||||
|
config_id=resolve_config_id(config_id,db)
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
|
memory_config_id = resolve_config_id(result.config_id, db)
|
||||||
# 构建返回数据
|
# 构建返回数据
|
||||||
reflection_config = {
|
reflection_config = {
|
||||||
"config_id": result.config_id,
|
"config_id": memory_config_id,
|
||||||
"reflection_enabled": result.enable_self_reflexion,
|
"reflection_enabled": result.enable_self_reflexion,
|
||||||
"reflection_period_in_hours": result.iteration_period,
|
"reflection_period_in_hours": result.iteration_period,
|
||||||
"reflexion_range": result.reflexion_range,
|
"reflexion_range": result.reflexion_range,
|
||||||
@@ -192,7 +210,7 @@ async def start_reflection_configs(
|
|||||||
|
|
||||||
@router.get("/reflection/run")
|
@router.get("/reflection/run")
|
||||||
async def reflection_run(
|
async def reflection_run(
|
||||||
config_id: UUID,
|
config_id: UUID|int,
|
||||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -200,7 +218,7 @@ async def reflection_run(
|
|||||||
"""Activate the reflection function for all matching applications in the workspace"""
|
"""Activate the reflection function for all matching applications in the workspace"""
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
# 使用MemoryConfigRepository查询反思配置
|
# 使用MemoryConfigRepository查询反思配置
|
||||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
if not result:
|
if not result:
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ from fastapi import APIRouter, Depends
|
|||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
# Get API logger
|
# Get API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
@@ -141,7 +143,6 @@ def create_config(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
||||||
@@ -161,12 +162,12 @@ def create_config(
|
|||||||
|
|
||||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||||
def delete_config(
|
def delete_config(
|
||||||
config_id: UUID,
|
config_id: UUID|int,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
config_id=resolve_config_id(config_id, db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||||
@@ -188,12 +189,17 @@ def update_config(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
|
# 校验至少有一个字段需要更新
|
||||||
|
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||||
try:
|
try:
|
||||||
svc = DataConfigService(db)
|
svc = DataConfigService(db)
|
||||||
@@ -211,7 +217,7 @@ def update_config_extracted(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
||||||
@@ -233,12 +239,12 @@ def update_config_extracted(
|
|||||||
|
|
||||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||||
def read_config_extracted(
|
def read_config_extracted(
|
||||||
config_id: UUID,
|
config_id: UUID | int,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
||||||
@@ -286,6 +292,7 @@ async def pilot_run(
|
|||||||
f"Pilot run requested: config_id={payload.config_id}, "
|
f"Pilot run requested: config_id={payload.config_id}, "
|
||||||
f"dialogue_text_length={len(payload.dialogue_text)}"
|
f"dialogue_text_length={len(payload.dialogue_text)}"
|
||||||
)
|
)
|
||||||
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
svc = DataConfigService(db)
|
svc = DataConfigService(db)
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
svc.pilot_run_stream(payload),
|
svc.pilot_run_stream(payload),
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from app.core.error_codes import BizCode
|
|||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models.models_model import ModelProvider, ModelType
|
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.repositories.model_repository import ModelConfigRepository
|
from app.repositories.model_repository import ModelConfigRepository
|
||||||
from app.schemas import model_schema
|
from app.schemas import model_schema
|
||||||
@@ -31,7 +31,12 @@ def get_model_types():
|
|||||||
|
|
||||||
@router.get("/provider", response_model=ApiResponse)
|
@router.get("/provider", response_model=ApiResponse)
|
||||||
def get_model_providers():
|
def get_model_providers():
|
||||||
return success(msg="获取模型提供商成功", data=list(ModelProvider))
|
providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
||||||
|
return success(msg="获取模型提供商成功", data=providers)
|
||||||
|
|
||||||
|
@router.get("/strategy", response_model=ApiResponse)
|
||||||
|
def get_model_strategies():
|
||||||
|
return success(msg="获取模型策略成功", data=list(LoadBalanceStrategy))
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=ApiResponse)
|
@router.get("", response_model=ApiResponse)
|
||||||
@@ -91,7 +96,7 @@ def get_model_list(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/new", response_model=ApiResponse)
|
@router.get("/new", response_model=ApiResponse)
|
||||||
def get_model_list(
|
def get_model_list_new(
|
||||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
|
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
|
||||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||||
@@ -147,7 +152,7 @@ def get_model_plaza_list(
|
|||||||
type: Optional[ModelType] = Query(None, description="模型类型"),
|
type: Optional[ModelType] = Query(None, description="模型类型"),
|
||||||
provider: Optional[ModelProvider] = Query(None, description="供应商"),
|
provider: Optional[ModelProvider] = Query(None, description="供应商"),
|
||||||
is_official: Optional[bool] = Query(None, description="是否官方模型"),
|
is_official: Optional[bool] = Query(None, description="是否官方模型"),
|
||||||
is_deprecated: Optional[bool] = Query(False, description="是否弃用"),
|
is_deprecated: Optional[bool] = Query(None, description="是否弃用"),
|
||||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
@@ -198,6 +203,10 @@ def update_model_base(
|
|||||||
):
|
):
|
||||||
"""更新基础模型"""
|
"""更新基础模型"""
|
||||||
|
|
||||||
|
# 不允许更改type类型
|
||||||
|
if data.type is not None or data.provider is not None:
|
||||||
|
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
|
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
|
||||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
|
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
|
||||||
|
|
||||||
@@ -318,6 +327,8 @@ async def update_composite_model(
|
|||||||
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if model_data.type is not None:
|
||||||
|
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||||
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||||
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
||||||
|
|
||||||
@@ -457,11 +468,13 @@ async def create_model_api_key_by_provider(
|
|||||||
priority=api_key_data.priority,
|
priority=api_key_data.priority,
|
||||||
model_config_ids=model_config_ids
|
model_config_ids=model_config_ids
|
||||||
)
|
)
|
||||||
created_keys = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||||
|
|
||||||
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
|
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
|
||||||
result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
|
# result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
|
||||||
return success(data=result_list, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
|
result = "API Key已存在" if len(created_keys) == 0 and len(failed_models) == 0 else \
|
||||||
|
f"成功为 {len(created_keys)} 个模型创建API Key, 失败模型列表{failed_models}"
|
||||||
|
return success(data=result, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"创建API Key失败: {str(e)}")
|
api_logger.error(f"创建API Key失败: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|||||||
1005
api/app/controllers/ontology_controller.py
Normal file
1005
api/app/controllers/ontology_controller.py
Normal file
File diff suppressed because it is too large
Load Diff
611
api/app/controllers/ontology_secondary_routes.py
Normal file
611
api/app/controllers/ontology_secondary_routes.py
Normal file
@@ -0,0 +1,611 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""本体场景和类型路由(续)
|
||||||
|
|
||||||
|
由于主Controller文件较大,将剩余路由放在此文件中。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.response_utils import fail, success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.schemas.ontology_schemas import (
|
||||||
|
SceneResponse,
|
||||||
|
SceneListResponse,
|
||||||
|
PaginationInfo,
|
||||||
|
ClassCreateRequest,
|
||||||
|
ClassUpdateRequest,
|
||||||
|
ClassResponse,
|
||||||
|
ClassListResponse,
|
||||||
|
ClassBatchCreateResponse,
|
||||||
|
)
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.services.ontology_service import OntologyService
|
||||||
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
|
from app.core.models.base import RedBearModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dummy_ontology_service(db: Session) -> OntologyService:
    """Build an OntologyService wired to a placeholder LLM client.

    Per the original author's note, scene and class management does not need
    the LLM; the service is still constructed with an ``llm_client`` here, so
    a client built from dummy credentials is supplied.

    Args:
        db: Database session handed to the service.

    Returns:
        An OntologyService instance using the placeholder client.
    """
    placeholder_settings = {
        "model_name": "dummy",
        "provider": "openai",
        "api_key": "dummy",
        "base_url": "https://api.openai.com/v1",
    }
    return OntologyService(
        llm_client=OpenAIClient(model_config=RedBearModelConfig(**placeholder_settings)),
        db=db,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# 这些函数将被导入到主Controller中
|
||||||
|
|
||||||
|
def _scenes_pagination_error(page: Optional[int], page_size: Optional[int]) -> Optional[str]:
    """Return the error detail for invalid pagination params, or None when they are valid."""
    if page is not None and page < 1:
        api_logger.warning(f"Invalid page number: {page}")
        return "页码必须大于0"

    if page_size is not None and page_size < 1:
        api_logger.warning(f"Invalid page_size: {page_size}")
        return "每页数量必须大于0"

    # page and page_size are only meaningful as a pair.
    if (page is None) != (page_size is None):
        api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
        return "分页参数page和pagesize必须同时提供"

    return None


def _scenes_build_items(scenes) -> list:
    """Convert scene records into SceneResponse objects."""
    items = []
    for scene in scenes:
        # type_num is computed from the loaded classes rather than read from a column.
        type_num = len(scene.classes) if scene.classes else 0
        # Only the first three class names are exposed as entity_type.
        entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None

        items.append(SceneResponse(
            scene_id=scene.scene_id,
            scene_name=scene.scene_name,
            scene_description=scene.scene_description,
            type_num=type_num,
            entity_type=entity_type,
            workspace_id=scene.workspace_id,
            created_at=scene.created_at,
            updated_at=scene.updated_at,
            classes_count=type_num
        ))
    return items


async def scenes_handler(
    workspace_id: Optional[str] = None,
    scene_name: Optional[str] = None,
    page: Optional[int] = None,
    page_size: Optional[int] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List scenes, optionally filtered by a fuzzy name match.

    When ``scene_name`` contains a non-blank keyword the scenes are searched
    by name and, if pagination params are given, the result is sliced in
    memory. Otherwise all scenes of the workspace are listed and pagination
    is delegated to ``service.list_scenes``.

    Args:
        workspace_id: Workspace ID as a UUID string; falls back to the
            current user's workspace when omitted.
        scene_name: Keyword for fuzzy matching; blank values mean "list all".
        page: 1-based page number; must be supplied together with page_size.
        page_size: Page size; must be supplied together with page.
        db: Database session.
        current_user: Authenticated user.

    Returns:
        The result of ``success(...)`` carrying the serialized
        SceneListResponse, or of ``fail(...)`` describing the error.
    """
    # Normalise once so logging and branching agree on what counts as a search.
    keyword = scene_name.strip() if scene_name else ""
    operation = "search" if keyword else "list"
    api_logger.info(
        f"Scene {operation} requested by user {current_user.id}, "
        f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
    )

    try:
        # Resolve the workspace to query.
        if workspace_id:
            # NOTE(review): a caller-supplied workspace_id is used as-is; no membership
            # check is visible in this handler - confirm access control happens elsewhere.
            try:
                ws_uuid = UUID(workspace_id)
            except ValueError:
                api_logger.warning(f"Invalid workspace_id format: {workspace_id}")
                return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式")
        else:
            ws_uuid = current_user.current_workspace_id
            if not ws_uuid:
                api_logger.warning(f"User {current_user.id} has no current workspace")
                return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        service = _get_dummy_ontology_service(db)

        # The same pagination rules apply to both search and list.
        pagination_error = _scenes_pagination_error(page, page_size)
        if pagination_error is not None:
            return fail(BizCode.BAD_REQUEST, "请求参数无效", pagination_error)

        paginated = page is not None and page_size is not None

        if keyword:
            scenes = service.search_scenes_by_name(keyword, ws_uuid)
            total = len(scenes)
            if paginated:
                # Search returns the full match list, so the page is cut here.
                start_idx = (page - 1) * page_size
                scenes = scenes[start_idx:start_idx + page_size]
        else:
            scenes, total = service.list_scenes(ws_uuid, page, page_size)

        items = _scenes_build_items(scenes)

        if paginated:
            pagination_info = PaginationInfo(
                page=page,
                pagesize=page_size,
                total=total,
                hasnext=(page * page_size) < total
            )
            response = SceneListResponse(items=items, page=pagination_info)
        else:
            response = SceneListResponse(items=items)

        if keyword:
            api_logger.info(
                f"Scene search completed: found {len(items)} scenes matching '{scene_name}' "
                f"in workspace {ws_uuid}, total={total}"
            )
        else:
            api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}")

        return success(data=response.model_dump(mode='json'), msg="查询成功")

    except ValueError as e:
        api_logger.warning(f"Validation error in scene {operation}: {str(e)}")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))

    except RuntimeError as e:
        api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))

    except Exception as e:
        api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 本体类型管理接口 ====================
|
||||||
|
|
||||||
|
async def create_class_handler(
    request: ClassCreateRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create ontology classes from the list carried by the request.

    A list of exactly one entry goes through ``service.create_class`` and
    returns a serialized ClassResponse. Any other length goes through
    ``service.create_classes_batch`` and returns a serialized
    ClassBatchCreateResponse with success/failure counts.

    Args:
        request: Body holding ``scene_id`` and the ``classes`` list.
        db: Database session.
        current_user: Authenticated user; its current workspace is used.

    Returns:
        The result of ``success(...)`` or ``fail(...)``.
    """
    requested = len(request.classes)
    is_single = requested == 1
    mode = "single" if is_single else "batch"

    api_logger.info(
        f"Class creation ({mode}) requested by user {current_user.id}, "
        f"scene_id={request.scene_id}, count={requested}"
    )

    def as_response(entity):
        # Shared projection of a stored class onto the API schema.
        return ClassResponse(
            class_id=entity.class_id,
            class_name=entity.class_name,
            class_description=entity.class_description,
            scene_id=entity.scene_id,
            created_at=entity.created_at,
            updated_at=entity.updated_at
        )

    try:
        workspace_id = current_user.current_workspace_id
        if not workspace_id:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        service = _get_dummy_ontology_service(db)

        # Plain dicts are what the service layer is given.
        classes_data = []
        for entry in request.classes:
            classes_data.append({
                "class_name": entry.class_name,
                "class_description": entry.class_description
            })

        if not is_single:
            # Batch path: the service reports created rows and per-item errors.
            created_classes, errors = service.create_classes_batch(
                scene_id=request.scene_id,
                classes=classes_data,
                workspace_id=workspace_id
            )

            batch_response = ClassBatchCreateResponse(
                total=len(classes_data),
                success_count=len(created_classes),
                failed_count=len(errors),
                items=[as_response(created) for created in created_classes],
                errors=errors if errors else None
            )

            api_logger.info(
                f"Batch class creation completed: "
                f"success={len(created_classes)}, failed={len(errors)}"
            )

            return success(data=batch_response.model_dump(mode='json'), msg="批量创建完成")

        # Single path.
        only_entry = classes_data[0]
        ontology_class = service.create_class(
            scene_id=request.scene_id,
            class_name=only_entry["class_name"],
            class_description=only_entry["class_description"],
            workspace_id=workspace_id
        )

        single_response = as_response(ontology_class)

        api_logger.info(f"Class created successfully: {ontology_class.class_id}")

        return success(data=single_response.model_dump(mode='json'), msg="类型创建成功")

    except ValueError as e:
        api_logger.warning(f"Validation error in class creation: {str(e)}")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))

    except RuntimeError as e:
        api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))

    except Exception as e:
        api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
async def update_class_handler(
    class_id: str,
    request: ClassUpdateRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Update the name and/or description of an ontology class.

    Args:
        class_id: Class ID as a UUID string.
        request: Body carrying ``class_name`` and ``class_description``.
        db: Database session.
        current_user: Authenticated user; its current workspace is used.

    Returns:
        The result of ``success(...)`` with the serialized ClassResponse, or
        of ``fail(...)`` describing the error.
    """
    api_logger.info(f"Class update requested by user {current_user.id}, class_id={class_id}")

    try:
        # Reject malformed IDs before touching the service.
        try:
            target_id = UUID(class_id)
        except ValueError:
            api_logger.warning(f"Invalid class_id format: {class_id}")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")

        active_workspace = current_user.current_workspace_id
        if not active_workspace:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        updated = _get_dummy_ontology_service(db).update_class(
            class_id=target_id,
            class_name=request.class_name,
            class_description=request.class_description,
            workspace_id=active_workspace
        )

        response = ClassResponse(
            class_id=updated.class_id,
            class_name=updated.class_name,
            class_description=updated.class_description,
            scene_id=updated.scene_id,
            created_at=updated.created_at,
            updated_at=updated.updated_at
        )

        api_logger.info(f"Class updated successfully: {class_id}")

        return success(data=response.model_dump(mode='json'), msg="类型更新成功")

    except ValueError as e:
        api_logger.warning(f"Validation error in class update: {str(e)}")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))

    except RuntimeError as e:
        api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))

    except Exception as e:
        api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_class_handler(
    class_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete an ontology class in the current user's workspace.

    Args:
        class_id: Class ID as a UUID string.
        db: Database session.
        current_user: Authenticated user; its current workspace is used.

    Returns:
        The result of ``success(...)`` with ``{"deleted": <service result>}``,
        or of ``fail(...)`` describing the error.
    """
    api_logger.info(f"Class deletion requested by user {current_user.id}, class_id={class_id}")

    try:
        # Reject malformed IDs before touching the service.
        try:
            target_id = UUID(class_id)
        except ValueError:
            api_logger.warning(f"Invalid class_id format: {class_id}")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")

        active_workspace = current_user.current_workspace_id
        if not active_workspace:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        ontology_service = _get_dummy_ontology_service(db)
        deleted = ontology_service.delete_class(class_id=target_id, workspace_id=active_workspace)

        api_logger.info(f"Class deleted successfully: {class_id}")

        return success(data={"deleted": deleted}, msg="类型删除成功")

    except ValueError as e:
        api_logger.warning(f"Validation error in class deletion: {str(e)}")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))

    except RuntimeError as e:
        api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))

    except Exception as e:
        api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
async def get_class_handler(
    class_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Fetch a single ontology class by ID.

    Args:
        class_id: Class ID as a UUID string.
        db: Database session.
        current_user: Authenticated user; its current workspace is used.

    Returns:
        The result of ``success(...)`` with the serialized ClassResponse, or
        of ``fail(...)``: BAD_REQUEST for a malformed ID or missing
        workspace, NOT_FOUND when the service raises ValueError, and
        INTERNAL_ERROR for anything else.
    """
    api_logger.info(
        f"Get class requested by user {current_user.id}, "
        f"class_id={class_id}"
    )

    try:
        # Reject malformed IDs before touching the service.
        try:
            class_uuid = UUID(class_id)
        except ValueError:
            api_logger.warning(f"Invalid class_id format: {class_id}")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")

        workspace_id = current_user.current_workspace_id
        if not workspace_id:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        service = _get_dummy_ontology_service(db)

        # Per the original comment, the service raises ValueError when the
        # class does not exist.
        ontology_class = service.get_class_by_id(class_uuid, workspace_id)

        response = ClassResponse(
            class_id=ontology_class.class_id,
            class_name=ontology_class.class_name,
            class_description=ontology_class.class_description,
            scene_id=ontology_class.scene_id,
            created_at=ontology_class.created_at,
            updated_at=ontology_class.updated_at
        )

        api_logger.info(f"Class retrieved successfully: {class_id}")

        return success(data=response.model_dump(mode='json'), msg="查询成功")

    except ValueError as e:
        # Class is missing or not accessible from this workspace; the message
        # now matches the NOT_FOUND code instead of claiming invalid parameters.
        api_logger.warning(f"Validation error in get class: {str(e)}")
        return fail(BizCode.NOT_FOUND, "类型不存在", str(e))

    except RuntimeError as e:
        api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))

    except Exception as e:
        api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
async def classes_handler(
    scene_id: str,
    class_name: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List the classes of a scene, optionally filtered by a fuzzy name match.

    When ``class_name`` contains a non-blank keyword the classes are searched
    by name; otherwise every class of the scene is returned.

    Args:
        scene_id: Scene ID as a UUID string (required).
        class_name: Keyword for fuzzy matching; blank values mean "list all".
        db: Database session.
        current_user: Authenticated user; its current workspace is used.

    Returns:
        The result of ``success(...)`` with the serialized ClassListResponse,
        or of ``fail(...)`` describing the error.
    """
    # Normalise once so logging and branching agree on what counts as a search.
    keyword = class_name.strip() if class_name else ""
    operation = "search" if keyword else "list"
    api_logger.info(
        f"Class {operation} requested by user {current_user.id}, "
        f"keyword={class_name}, scene_id={scene_id}"
    )

    try:
        # Reject malformed IDs before touching the service.
        try:
            scene_uuid = UUID(scene_id)
        except ValueError:
            api_logger.warning(f"Invalid scene_id format: {scene_id}")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")

        workspace_id = current_user.current_workspace_id
        if not workspace_id:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        service = _get_dummy_ontology_service(db)

        # The scene is loaded first: its name and description are echoed in the response.
        scene = service.get_scene_by_id(scene_uuid, workspace_id)
        if not scene:
            api_logger.warning(f"Scene not found: {scene_id}")
            return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景")

        if keyword:
            classes = service.search_classes_by_name(keyword, scene_uuid, workspace_id)
        else:
            classes = service.list_classes_by_scene(scene_uuid, workspace_id)

        items = [
            ClassResponse(
                class_id=ontology_class.class_id,
                class_name=ontology_class.class_name,
                class_description=ontology_class.class_description,
                scene_id=ontology_class.scene_id,
                created_at=ontology_class.created_at,
                updated_at=ontology_class.updated_at
            )
            for ontology_class in classes
        ]

        response = ClassListResponse(
            total=len(items),
            scene_id=scene_uuid,
            scene_name=scene.scene_name,
            scene_description=scene.scene_description,
            items=items
        )

        if keyword:
            api_logger.info(
                f"Class search completed: found {len(items)} classes matching '{class_name}' "
                f"in scene {scene_id}"
            )
        else:
            api_logger.info(f"Class list retrieved successfully, count={len(items)}")

        return success(data=response.model_dump(mode='json'), msg="查询成功")

    except ValueError as e:
        api_logger.warning(f"Validation error in class {operation}: {str(e)}")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))

    except RuntimeError as e:
        api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))

    except Exception as e:
        api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import uuid
|
|
||||||
import json
|
import json
|
||||||
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Path
|
from fastapi import APIRouter, Depends, Path
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -8,9 +8,13 @@ from starlette.responses import StreamingResponse
|
|||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.dependencies import get_current_user, get_db
|
from app.dependencies import get_current_user, get_db
|
||||||
from app.models.prompt_optimizer_model import RoleType
|
from app.schemas.prompt_optimizer_schema import (
|
||||||
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
|
PromptOptMessage,
|
||||||
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
|
CreateSessionResponse,
|
||||||
|
SessionHistoryResponse,
|
||||||
|
SessionMessage,
|
||||||
|
PromptSaveRequest
|
||||||
|
)
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.prompt_optimizer_service import PromptOptimizerService
|
from app.services.prompt_optimizer_service import PromptOptimizerService
|
||||||
|
|
||||||
@@ -135,3 +139,109 @@ async def get_prompt_opt(
|
|||||||
"X-Accel-Buffering": "no"
|
"X-Accel-Buffering": "no"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/releases",
    # The summary previously read "Get prompt optimization" (copied from another
    # endpoint); it is what shows up in the generated OpenAPI docs for this route.
    summary="Save prompt release",
    response_model=ApiResponse
)
def save_prompt(
        data: PromptSaveRequest,
        db: Session = Depends(get_db),
        current_user=Depends(get_current_user),
):
    """
    Save a prompt release for the current user's tenant.

    Args:
        data (PromptSaveRequest): Request body; its session_id, title and prompt
            fields are forwarded to the service.
        db (Session): SQLAlchemy database session, injected via dependency.
        current_user: Authenticated user, injected via dependency; only its
            tenant_id is read here.

    Returns:
        ApiResponse: `success(...)` wrapping whatever
        `PromptOptimizerService.save_prompt` returns (described by the author as
        id, session_id, title, prompt and created_at of the release).

    Raises:
        Nothing is caught here; service or database exceptions propagate to the
        caller (presumably a global exception handler).
    """
    service = PromptOptimizerService(db)
    prompt_info = service.save_prompt(
        tenant_id=current_user.tenant_id,
        session_id=data.session_id,
        title=data.title,
        prompt=data.prompt
    )
    return success(data=prompt_info)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
    "/releases/{prompt_id}",
    summary="Delete prompt (soft delete)",
    response_model=ApiResponse
)
def delete_prompt(
        prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
        db: Session = Depends(get_db),
        current_user=Depends(get_current_user),
):
    """
    Delete a prompt release that belongs to the current user's tenant.

    The route summary describes this as a soft delete; the actual deletion
    strategy lives in `PromptOptimizerService.delete_prompt`.

    Args:
        prompt_id (uuid.UUID): Identifier of the prompt release, taken from the URL path.
        db (Session): Database session, injected via dependency.
        current_user: Authenticated user; only its tenant_id is read here.

    Returns:
        ApiResponse: Success response with a fixed confirmation message.
    """
    optimizer = PromptOptimizerService(db)
    # The service return value is intentionally ignored; failures surface as exceptions.
    optimizer.delete_prompt(tenant_id=current_user.tenant_id, prompt_id=prompt_id)
    return success(msg="Prompt deleted successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/releases/list",
    summary="Get paginated list of released prompts with optional filter",
    response_model=ApiResponse
)
def get_release_list(
        page: int = 1,
        page_size: int = 20,
        keyword: str | None = None,
        db: Session = Depends(get_db),
        current_user=Depends(get_current_user),
):
    """
    Return one page of released prompts for the current user's tenant.

    Args:
        page (int): 1-based page number; values below 1 are raised to 1.
        page_size (int): Items per page; clamped into the range 1..100.
        keyword (str | None): Optional filter, passed to the service as `filter_keyword`.
        db (Session): Database session, injected via dependency.
        current_user: Authenticated user; only its tenant_id is read here.

    Returns:
        ApiResponse: Success response wrapping the service's paginated result.
    """
    optimizer = PromptOptimizerService(db)

    # Clamp the paging arguments before they reach the service layer.
    safe_page = page if page > 1 else 1
    safe_size = page_size if page_size > 1 else 1
    if safe_size > 100:
        safe_size = 100

    listing = optimizer.get_release_list(
        tenant_id=current_user.tenant_id,
        page=safe_page,
        page_size=safe_size,
        filter_keyword=keyword
    )
    return success(data=listing)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -235,11 +235,11 @@ async def chat(
|
|||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=new_end_user.id, # 转换为字符串
|
user_id=end_user_id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
config=config,
|
config=config,
|
||||||
web_search=payload.web_search,
|
web_search=web_search,
|
||||||
memory=payload.memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
@@ -268,11 +268,11 @@ async def chat(
|
|||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=new_end_user.id, # 转换为字符串
|
user_id=end_user_id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
config=config,
|
config=config,
|
||||||
web_search=payload.web_search,
|
web_search=web_search,
|
||||||
memory=payload.memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
|
|||||||
@@ -7,27 +7,21 @@ LangChain Agent 封装
|
|||||||
- 支持流式输出
|
- 支持流式输出
|
||||||
- 使用 RedBearLLM 支持多提供商
|
- 使用 RedBearLLM 支持多提供商
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
|
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.memory.agent.utils.redis_tool import store
|
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
from app.models.models_model import ModelType
|
from app.models.models_model import ModelType
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
|
||||||
from app.services.memory_agent_service import (
|
from app.services.memory_agent_service import (
|
||||||
get_end_user_connected_config,
|
get_end_user_connected_config,
|
||||||
)
|
)
|
||||||
from app.services.memory_konwledges_server import write_rag
|
|
||||||
from app.services.task_service import get_task_memory_write_result
|
|
||||||
from app.tasks import write_message_task
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -104,7 +98,7 @@ class LangChainAgent:
|
|||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
"tool_count": len(self.tools),
|
"tool_count": len(self.tools),
|
||||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||||
"tool_count": len(self.tools)
|
# "tool_count": len(self.tools)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -143,100 +137,7 @@ class LangChainAgent:
|
|||||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||||
|
|
||||||
messages.append(HumanMessage(content=user_content))
|
messages.append(HumanMessage(content=user_content))
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
|
||||||
# async def term_memory_save(self,messages,end_user_end,aimessages):
|
|
||||||
# '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
|
||||||
# end_user_end=f"Term_{end_user_end}"
|
|
||||||
# print(messages)
|
|
||||||
# print(aimessages)
|
|
||||||
# session_id = store.save_session(
|
|
||||||
# userid=end_user_end,
|
|
||||||
# messages=messages,
|
|
||||||
# apply_id=end_user_end,
|
|
||||||
# end_user_id=end_user_end,
|
|
||||||
# aimessages=aimessages
|
|
||||||
# )
|
|
||||||
# store.delete_duplicate_sessions()
|
|
||||||
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
|
||||||
# return session_id
|
|
||||||
|
|
||||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
|
||||||
# async def term_memory_redis_read(self,end_user_end):
|
|
||||||
# end_user_end = f"Term_{end_user_end}"
|
|
||||||
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
|
||||||
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
|
|
||||||
# messagss_list=[]
|
|
||||||
# retrieved_content=[]
|
|
||||||
# for messages in history:
|
|
||||||
# query = messages.get("Query")
|
|
||||||
# aimessages = messages.get("Answer")
|
|
||||||
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
|
||||||
# retrieved_content.append({query: aimessages})
|
|
||||||
# return messagss_list,retrieved_content
|
|
||||||
|
|
||||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
|
||||||
"""
|
|
||||||
写入记忆(支持结构化消息)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_type: 存储类型 (neo4j/rag)
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
user_message: 用户消息内容
|
|
||||||
ai_message: AI 回复内容
|
|
||||||
user_rag_memory_id: RAG 记忆ID
|
|
||||||
actual_end_user_id: 实际用户ID
|
|
||||||
actual_config_id: 配置ID
|
|
||||||
|
|
||||||
逻辑说明:
|
|
||||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
|
||||||
- Neo4j 模式:使用结构化消息列表
|
|
||||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
|
||||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
|
||||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
|
||||||
"""
|
|
||||||
if storage_type == "rag":
|
|
||||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
|
||||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
|
||||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
|
||||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
|
||||||
else:
|
|
||||||
# Neo4j 模式:使用结构化消息列表
|
|
||||||
structured_messages = []
|
|
||||||
|
|
||||||
# 始终添加用户消息(如果不为空)
|
|
||||||
if user_message:
|
|
||||||
structured_messages.append({"role": "user", "content": user_message})
|
|
||||||
|
|
||||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
|
||||||
if ai_message:
|
|
||||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
|
||||||
|
|
||||||
# 如果没有消息,直接返回
|
|
||||||
if not structured_messages:
|
|
||||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 调用 Celery 任务,传递结构化消息列表
|
|
||||||
# 数据流:
|
|
||||||
# 1. structured_messages 传递给 write_message_task
|
|
||||||
# 2. write_message_task 调用 memory_agent_service.write_memory
|
|
||||||
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
|
||||||
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
|
||||||
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
|
||||||
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
|
||||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
|
||||||
write_id = write_message_task.delay(
|
|
||||||
actual_end_user_id, # end_user_id: 用户ID
|
|
||||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
|
||||||
actual_config_id, # config_id: 配置ID
|
|
||||||
storage_type, # storage_type: "neo4j"
|
|
||||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
|
||||||
)
|
|
||||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
|
||||||
write_status = get_task_memory_write_result(str(write_id))
|
|
||||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
@@ -281,30 +182,6 @@ class LangChainAgent:
|
|||||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
|
|
||||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
|
||||||
# history_term_memory = history_term_memory_result[0]
|
|
||||||
# db_for_memory = next(get_db())
|
|
||||||
# if memory_flag:
|
|
||||||
# if len(history_term_memory)>=4 and storage_type != "rag":
|
|
||||||
# history_term_memory = ';'.join(history_term_memory)
|
|
||||||
# retrieved_content = history_term_memory_result[1]
|
|
||||||
# print(retrieved_content)
|
|
||||||
# # 为长期记忆操作获取新的数据库连接
|
|
||||||
# try:
|
|
||||||
# repo = LongTermMemoryRepository(db_for_memory)
|
|
||||||
# repo.upsert(end_user_id, retrieved_content)
|
|
||||||
# logger.info(
|
|
||||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"Failed to write to LongTermMemory: {e}")
|
|
||||||
# raise
|
|
||||||
# finally:
|
|
||||||
# db_for_memory.close()
|
|
||||||
|
|
||||||
# # 长期记忆写入(
|
|
||||||
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
|
||||||
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
|
||||||
try:
|
try:
|
||||||
# 准备消息列表
|
# 准备消息列表
|
||||||
messages = self._prepare_messages(message, history, context)
|
messages = self._prepare_messages(message, history, context)
|
||||||
@@ -325,17 +202,17 @@ class LangChainAgent:
|
|||||||
# 获取最后的 AI 消息
|
# 获取最后的 AI 消息
|
||||||
output_messages = result.get("messages", [])
|
output_messages = result.get("messages", [])
|
||||||
content = ""
|
content = ""
|
||||||
|
total_tokens = 0
|
||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
content = msg.content
|
content = msg.content
|
||||||
|
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||||
|
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||||
break
|
break
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
||||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
|
||||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
|
||||||
# await self.term_memory_save(message_chat, end_user_id, content)
|
|
||||||
response = {
|
response = {
|
||||||
"content": content,
|
"content": content,
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
@@ -343,7 +220,7 @@ class LangChainAgent:
|
|||||||
"usage": {
|
"usage": {
|
||||||
"prompt_tokens": 0,
|
"prompt_tokens": 0,
|
||||||
"completion_tokens": 0,
|
"completion_tokens": 0,
|
||||||
"total_tokens": 0
|
"total_tokens": total_tokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -403,25 +280,7 @@ class LangChainAgent:
|
|||||||
db.close()
|
db.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get db session: {e}")
|
logger.warning(f"Failed to get db session: {e}")
|
||||||
# # TODO 乐力齐
|
|
||||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
|
||||||
# history_term_memory = history_term_memory_result[0]
|
|
||||||
# if memory_flag:
|
|
||||||
# if len(history_term_memory) >= 4 and storage_type != "rag":
|
|
||||||
# history_term_memory = ';'.join(history_term_memory)
|
|
||||||
# retrieved_content = history_term_memory_result[1]
|
|
||||||
# db_for_memory = next(get_db())
|
|
||||||
# try:
|
|
||||||
# repo = LongTermMemoryRepository(db_for_memory)
|
|
||||||
# repo.upsert(end_user_id, retrieved_content)
|
|
||||||
# logger.info(
|
|
||||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
|
||||||
# # 长期记忆写入
|
|
||||||
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"Failed to write to long term memory: {e}")
|
|
||||||
# finally:
|
|
||||||
# db_for_memory.close()
|
|
||||||
|
|
||||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||||
try:
|
try:
|
||||||
@@ -437,7 +296,7 @@ class LangChainAgent:
|
|||||||
|
|
||||||
# 统一使用 agent 的 astream_events 实现流式输出
|
# 统一使用 agent 的 astream_events 实现流式输出
|
||||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||||
full_content=''
|
full_content = ''
|
||||||
try:
|
try:
|
||||||
async for event in self.agent.astream_events(
|
async for event in self.agent.astream_events(
|
||||||
{"messages": messages},
|
{"messages": messages},
|
||||||
@@ -474,12 +333,17 @@ class LangChainAgent:
|
|||||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||||
|
|
||||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||||
|
# 统计token消耗
|
||||||
|
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||||
|
for msg in reversed(output_messages):
|
||||||
|
if isinstance(msg, AIMessage):
|
||||||
|
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||||
|
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||||
|
0) if response_meta else 0
|
||||||
|
yield total_tokens
|
||||||
|
break
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
||||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
|
||||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
|
||||||
# await self.term_memory_save(message_chat, end_user_id, full_content)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -157,6 +157,11 @@ class Settings:
|
|||||||
if origin.strip()
|
if origin.strip()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Language Configuration
|
||||||
|
# Supported values: "zh" (Chinese), "en" (English)
|
||||||
|
# This controls the language used for memory summary titles and other generated content
|
||||||
|
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||||
|
|
||||||
# Logging settings
|
# Logging settings
|
||||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
|||||||
@@ -0,0 +1,238 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||||
|
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||||
|
|
||||||
|
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||||
|
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||||
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
|
from app.core.memory.agent.utils.redis_tool import count_store
|
||||||
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
|
from app.db import get_db_context, get_db
|
||||||
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
|
from app.services.memory_konwledges_server import write_rag
|
||||||
|
from app.services.task_service import get_task_memory_write_result
|
||||||
|
from app.tasks import write_message_task
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
|
|
||||||
|
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
    """
    Persist one user/assistant exchange through the RAG writer.

    The two messages are joined into a single "user: ...\\nassistant: ..." string
    and handed to `write_rag` together with the user id and RAG memory id.

    Args:
        end_user_id: End-user identifier forwarded to `write_rag`.
        user_message: Text of the user turn.
        ai_message: Text of the assistant turn.
        user_rag_memory_id: RAG memory identifier forwarded to `write_rag`.
    """
    await write_rag(
        end_user_id,
        f"user: {user_message}\nassistant: {ai_message}",
        user_rag_memory_id,
    )
    logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||||
|
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
                actual_config_id, long_term_messages=None):
    """
    Submit a memory-write Celery task with a structured message list.

    Args:
        storage_type: Storage backend name, forwarded unchanged to the task.
        end_user_id: End-user ID (accepted for signature compatibility; the body
            only uses `actual_end_user_id`).
        user_message: User message text; added as a "user" entry when it is a
            non-blank string.
        ai_message: Assistant reply text; added as an "assistant" entry when it
            is a non-blank string.
        user_rag_memory_id: RAG memory ID; forwarded to the task ("" when falsy).
        actual_end_user_id: User ID passed to the task and used in log lines.
        actual_config_id: Config ID; resolved through `resolve_config_id` and
            passed to the task as a string.
        long_term_messages: Optional replacement for the message list. A
            non-empty list is used as-is; a non-empty string is parsed as JSON.
            When parsing fails the error is logged and the user/assistant pair
            built above is kept. Defaults to None (no override).

    Behaviour:
        1. Both user_message and ai_message present -> [user, assistant].
        2. Only user_message present -> [user].
        3. Nothing to write -> a warning is logged and the function returns
           without submitting a task.
    """
    db = next(get_db())
    try:
        actual_config_id = resolve_config_id(actual_config_id, db)

        # Build the structured message list from the plain-text turns.
        structured_messages = []
        if isinstance(user_message, str) and user_message.strip():
            structured_messages.append({"role": "user", "content": user_message})
        if isinstance(ai_message, str) and ai_message.strip():
            structured_messages.append({"role": "assistant", "content": ai_message})

        # An explicit long_term_messages payload overrides the pair built above.
        if long_term_messages and isinstance(long_term_messages, list):
            structured_messages = long_term_messages
        elif long_term_messages and isinstance(long_term_messages, str):
            # A string payload is expected to be JSON; decode it first.
            try:
                structured_messages = json.loads(long_term_messages)
            except json.JSONDecodeError:
                logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")

        # Nothing to write: bail out before touching Celery.
        if not structured_messages:
            logger.warning(f"No messages to write for user {actual_end_user_id}")
            return

        logger.info(
            f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
        write_id = write_message_task.delay(
            actual_end_user_id,  # end_user_id
            structured_messages,  # message: list of {"role": ..., "content": ...} entries
            str(actual_config_id),  # config_id as a string
            storage_type,  # storage_type
            user_rag_memory_id or ""  # user_rag_memory_id ("" when not provided)
        )
        logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
        # NOTE(review): presumably waits for / polls the task outcome - confirm in task_service.
        write_status = get_task_memory_write_result(str(write_id))
        logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
    finally:
        db.close()
|
||||||
|
|
||||||
|
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
    """
    Copy recent session data from the Redis store into the long-term memory table.

    Args:
        long_term_messages: Not used by this function (kept for call compatibility).
        actual_config_id: Not used by this function (kept for call compatibility).
        end_user_id: End-user ID used for the store lookups and for the upsert.
        type: Write strategy, compared against the AgentMemory_Long_Term
            STRATEGY_CHUNK / STRATEGY_AGGREGATE constants.
        scope: Window size; for the chunk/aggregate strategies the upsert only
            happens once exactly `scope` parsed entries are available.

    For any other strategy the sessions returned by
    `find_user_recent_sessions(end_user_id, 5)` are parsed and upserted.
    """
    # The original condition was `type == CHUNK or AGGREGATE`, which is always
    # true for a truthy constant and made the else-branch unreachable.
    windowed_strategies = (
        AgentMemory_Long_Term.STRATEGY_CHUNK,
        AgentMemory_Long_Term.STRATEGY_AGGREGATE,
    )
    with get_db_context() as db_session:
        repo = LongTermMemoryRepository(db_session)

        result = write_store.get_session_by_userid(end_user_id)
        if type in windowed_strategies:
            data = await format_parsing(result, "dict")
            chunk_data = data[:scope]
            # Only persist once a full window of `scope` entries has accumulated.
            if len(chunk_data) == scope:
                repo.upsert(end_user_id, chunk_data)
                logger.info('---------写入短长期-----------')
        else:
            long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
            long_messages = await messages_parse(long_time_data)
            repo.upsert(end_user_id, long_messages)
            logger.info('写入短长期:')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
'''根据窗口'''
|
||||||
|
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
    '''
    Window-based strategy: count dialogue rounds in Redis and flush to Neo4j
    once the window is full.

    Args:
        end_user_id: End-user ID.
        langchain_messages: Messages of the current round (a list).
        memory_config: Memory config object, or a bare config id; when it has a
            `config_id` attribute that attribute is used.
        scope: Window size (number of rounds before a flush).

    Flow:
        - No counter stored yet (`get_sessions_count` returned False): create it
          with count 1 and the current messages.
        - Counter is truthy and differs from `scope`: increment it and store the
          current messages followed by the previously stored ones.
        - Counter equals `scope`: write the stored messages to Neo4j through
          `write`, then reset the counter to 1 with the current messages.
    '''
    # Read the counter once; the original issued the same lookup three times,
    # which cost extra round trips and could observe different snapshots.
    session_state = count_store.get_sessions_count(end_user_id)
    if session_state is False:
        count_store.save_sessions_count(end_user_id, 1, langchain_messages)
        return

    is_end_user_id = session_state[0]
    redis_messages = session_state[1]

    if is_end_user_id and int(is_end_user_id) != int(scope):
        is_end_user_id += 1
        # NOTE: in-place extension - for a list this also grows the caller's object.
        langchain_messages += redis_messages
        count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
    elif int(is_end_user_id) == int(scope):
        logger.info('写入长期记忆NEO4J')
        formatted_messages = redis_messages
        # Accept either a config object exposing config_id or a raw config id.
        if hasattr(memory_config, 'config_id'):
            config_id = memory_config.config_id
        else:
            config_id = memory_config

        await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
                    config_id, formatted_messages)
        count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
|
|
||||||
|
|
||||||
|
"""根据时间"""
|
||||||
|
async def memory_long_term_storage(end_user_id,memory_config,time):
    '''
    Time-based strategy: collect recent sessions from Redis and write them to Neo4j.

    Args:
        end_user_id: End-user ID.
        memory_config: Memory config object (its `config_id` is used); a bare
            config id is accepted as well, matching `window_dialogue`.
        time: Forwarded as the second argument of
            `write_store.find_user_recent_sessions` (presumably the look-back
            window; the unit is not visible from here).

    Each returned session's 'Query' field is decoded as JSON and its entries are
    concatenated into one message list. Nothing is written when no sessions are
    returned.
    '''
    long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
    # Covers both an empty list and a None result; the original crashed on None.
    if not long_time_data:
        return

    config_id = getattr(memory_config, 'config_id', memory_config)

    messages = []
    for session in long_time_data:
        messages.extend(json.loads(session['Query']))

    await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
                config_id, messages)
|
||||||
|
'''聚合判断'''
|
||||||
|
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
    """
    Ask the LLM whether the new messages describe the same event as the stored history.

    Args:
        end_user_id: End-user ID.
        ori_messages: Raw message list, e.g.
            [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
        memory_config: Memory config object; `llm_model_id` and `config_id` are read.

    Returns:
        dict: On success {"is_same_event": ..., "output": ...}, where a list
        output is converted to plain {"role", "content"} dicts. When the LLM
        reports a different event, the output is also written to Neo4j via
        `write`. On any exception the error is logged and a fallback dict is
        returned with is_same_event=False, the original messages, the history
        gathered so far and the error text.
    """
    # Defined before the try block so the error path can always report it.
    history = []
    try:
        # 1. Load the stored sessions; parse them only when there is something to parse.
        result = write_store.get_all_sessions_by_end_user_id(end_user_id)
        if result:
            history = await format_parsing(result)

        # 2. Render the judgment prompt.
        json_schema = WriteAggregateModel.model_json_schema()
        template_service = TemplateService(template_root)
        system_prompt = await template_service.render_template(
            template_name='write_aggregate_judgment.jinja2',
            operation_name='aggregate_judgment',
            history=history,
            sentence=ori_messages,
            json_schema=json_schema
        )

        # 3. Resolve the LLM client for the configured model.
        with get_db_context() as db_session:
            factory = MemoryClientFactory(db_session)
            llm_client = factory.get_llm_client(memory_config.llm_model_id)

        messages = [
            {
                "role": "user",
                "content": system_prompt
            }
        ]
        structured = await llm_client.response_structured(
            messages=messages,
            response_model=WriteAggregateModel
        )

        # 4. Normalise message objects into plain dicts.
        output_value = structured.output
        if isinstance(output_value, list):
            output_value = [
                {"role": msg.role, "content": msg.content}
                for msg in output_value
            ]

        result_dict = {
            "is_same_event": structured.is_same_event,
            "output": output_value
        }

        # 5. A different event is flushed to long-term storage.
        if not structured.is_same_event:
            logger.info(result_dict)
            await write("neo4j", end_user_id, "", "", None, end_user_id,
                        memory_config.config_id, output_value)
        return result_dict

    except Exception as e:
        # Deliberate best-effort boundary: log with traceback, return a fallback.
        logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)

        return {
            "is_same_event": False,
            "output": ori_messages,
            "messages": ori_messages,
            "history": history,
            "error": str(e)
        }
|
||||||
@@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
清理后的数据
|
清理后的数据
|
||||||
"""
|
"""
|
||||||
# 需要过滤的字段列表
|
# 需要过滤的字段列表
|
||||||
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
fields_to_remove = {
|
fields_to_remove = {
|
||||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
|
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
|
|||||||
@@ -0,0 +1,72 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
|
async def format_parsing(messages: list, type: str = 'string'):
    """Flatten stored chat history into display strings or user/AI pairs.

    Every element of ``messages`` is indexed with ``['messages']``; that value
    must be a string holding one JSON document per line, and each document
    must be a list of ``{"role": ..., "content": ...}`` mappings.

    Args:
        messages: Iterable of records carrying a ``'messages'`` string.
        type: Output mode (the name shadows the builtin but is part of the
            public signature, so it is kept).
            ``'string'`` -> one string per message, prefixed with the user
            marker for roles ``'human'``/``'user'`` and ``'AI:'`` otherwise.
            ``'dict'`` -> ``{user_content: ai_content}`` pairs built by
            zipping all user contents with all non-user contents; unmatched
            trailing items are dropped by ``zip``.
            Any other value yields an empty list.

    Returns:
        list: The formatted messages as described above.

    Raises:
        KeyError: If a record, or a parsed message, lacks an expected key.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    result = []
    user = []
    ai = []

    for message in messages:
        history_text = message['messages']
        for line in history_text.strip().splitlines():
            # Blank separator lines between JSON documents used to crash
            # json.loads(); skip them instead.
            if not line.strip():
                continue
            for entry in json.loads(line):
                role = entry['role']
                text = entry['content']
                is_user = role in ('human', 'user')
                if type == "string":
                    result.append(('用户:' if is_user else 'AI:') + text)
                elif type == "dict":
                    (user if is_user else ai).append(text)

    if type == "dict":
        result.extend({question: answer} for question, answer in zip(user, ai))
    return result
|
||||||
|
|
||||||
|
async def messages_parse(messages: list | dict):
|
||||||
|
user=[]
|
||||||
|
ai=[]
|
||||||
|
database=[]
|
||||||
|
for message in messages:
|
||||||
|
Query = message['Query']
|
||||||
|
Query = json.loads(Query)
|
||||||
|
for data in Query:
|
||||||
|
role = data['role']
|
||||||
|
if role == "human":
|
||||||
|
user.append(data['content'])
|
||||||
|
if role == "ai":
|
||||||
|
ai.append(data['content'])
|
||||||
|
for key, values in zip(user, ai):
|
||||||
|
database.append({key, values})
|
||||||
|
return database
|
||||||
|
|
||||||
|
|
||||||
|
async def agent_chat_messages(user_content, ai_content):
    """Build a two-turn chat transcript from one user/assistant exchange.

    Both arguments are rendered through an f-string, so non-string values are
    converted to their formatted text.

    Args:
        user_content: What the user said.
        ai_content: What the assistant replied.

    Returns:
        list[dict]: ``[{"role": "user", ...}, {"role": "assistant", ...}]``,
        each with a ``"content"`` string.
    """
    user_turn = {"role": "user", "content": f"{user_content}"}
    assistant_turn = {"role": "assistant", "content": f"{ai_content}"}
    return [user_turn, assistant_turn]
|
||||||
@@ -1,22 +1,20 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from langgraph.constants import END, START
|
from langgraph.constants import END, START
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
|
from app.db import get_db, get_db_context
|
||||||
from app.db import get_db
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
@@ -34,14 +32,6 @@ async def make_write_graph():
|
|||||||
end_user_id: Group identifier
|
end_user_id: Group identifier
|
||||||
memory_config: MemoryConfig object containing all configuration
|
memory_config: MemoryConfig object containing all configuration
|
||||||
"""
|
"""
|
||||||
# workflow = StateGraph(WriteState)
|
|
||||||
# workflow.add_node("content_input", content_input_write)
|
|
||||||
# workflow.add_node("save_neo4j", write_node)
|
|
||||||
# workflow.add_edge(START, "content_input")
|
|
||||||
# workflow.add_edge("content_input", "save_neo4j")
|
|
||||||
# workflow.add_edge("save_neo4j", END)
|
|
||||||
#
|
|
||||||
# graph = workflow.compile()
|
|
||||||
workflow = StateGraph(WriteState)
|
workflow = StateGraph(WriteState)
|
||||||
workflow.add_node("save_neo4j", write_node)
|
workflow.add_node("save_neo4j", write_node)
|
||||||
workflow.add_edge(START, "save_neo4j")
|
workflow.add_edge(START, "save_neo4j")
|
||||||
@@ -51,43 +41,63 @@ async def make_write_graph():
|
|||||||
|
|
||||||
yield graph
|
yield graph
|
||||||
|
|
||||||
|
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||||
async def main():
|
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||||
"""主函数 - 运行工作流"""
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
message = "今天周一"
|
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||||
end_user_id = 'new_2025test1103' # 组ID
|
|
||||||
|
|
||||||
|
|
||||||
# 获取数据库会话
|
# 获取数据库会话
|
||||||
db_session = next(get_db())
|
with get_db_context() as db_session:
|
||||||
config_service = MemoryConfigService(db_session)
|
config_service = MemoryConfigService(db_session)
|
||||||
memory_config = config_service.load_memory_config(
|
memory_config = config_service.load_memory_config(
|
||||||
config_id=17, # 改为整数
|
config_id=memory_config, # 改为整数
|
||||||
service_name="MemoryAgentService"
|
service_name="MemoryAgentService"
|
||||||
)
|
)
|
||||||
try:
|
if long_term_type=='chunk':
|
||||||
async with make_write_graph() as graph:
|
'''方案一:对话窗口6轮对话'''
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||||
# 初始状态 - 包含所有必要字段
|
if long_term_type=='time':
|
||||||
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
|
"""时间"""
|
||||||
|
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||||
# 获取节点更新信息
|
if long_term_type=='aggregate':
|
||||||
async for update_event in graph.astream(
|
"""方案三:聚合判断"""
|
||||||
initial_state,
|
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||||
stream_mode="updates",
|
|
||||||
config=config
|
|
||||||
):
|
|
||||||
for node_name, node_data in update_event.items():
|
|
||||||
if 'save_neo4j'==node_name:
|
|
||||||
massages=node_data
|
|
||||||
massages=massages.get('write_result')['status']
|
|
||||||
print(massages) # | 更新数据: {node_data}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||||
asyncio.run(main())
|
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||||
|
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||||
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||||
|
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||||
|
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||||
|
else:
|
||||||
|
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||||
|
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||||
|
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||||
|
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||||
|
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||||
|
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||||
|
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||||
|
|
||||||
|
# async def main():
|
||||||
|
# """主函数 - 运行工作流"""
|
||||||
|
# langchain_messages = [
|
||||||
|
# {
|
||||||
|
# "role": "user",
|
||||||
|
# "content": "今天周五去爬山"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "role": "assistant",
|
||||||
|
# "content": "好耶"
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# ]
|
||||||
|
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||||
|
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||||
|
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# import asyncio
|
||||||
|
# asyncio.run(main())
|
||||||
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""Pydantic models for write aggregate judgment operations."""
|
||||||
|
|
||||||
|
from typing import List, Union
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class MessageItem(BaseModel):
    """One conversation message: the speaker's role and the message text."""

    # Speaker role; the schema description names "user" or "assistant".
    role: str = Field(
        ...,
        description="角色:user 或 assistant",
    )
    # Text of the message.
    content: str = Field(
        ...,
        description="消息内容",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class WriteAggregateResponse(BaseModel):
    """Structured result of the aggregate ("same event?") judgment.

    Carries the boolean verdict plus an ``output`` payload that is either a
    list of messages or a bare boolean, as spelled out in the field
    descriptions below.
    """

    # Verdict: True when the input belongs to the same event as the history.
    is_same_event: bool = Field(..., description="是否是同一事件。True表示是同一事件,False表示不同事件")
    # Per the description: False for the same event, otherwise the message list.
    output: Union[List[MessageItem], bool] = Field(..., description="如果is_same_event为True,返回False;如果is_same_event为False,返回消息列表")


# Backward-compatibility alias: keep the old class name importable.
WriteAggregateModel = WriteAggregateResponse
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
输入句子:{{sentence}}
|
||||||
|
历史消息:{{history}}
|
||||||
|
|
||||||
|
# 你的角色
|
||||||
|
你是一个擅长事件聚合与语义判断的专家。
|
||||||
|
|
||||||
|
# 你的任务
|
||||||
|
结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。
|
||||||
|
|
||||||
|
以下情况视为"同一事件"(需要返回 is_same_event=True, output=False):
|
||||||
|
- 描述的是同一个具体事件或事实
|
||||||
|
- 存在明显的因果关系、前后发展关系
|
||||||
|
- 是对同一事件的补充、解释、追问或延展
|
||||||
|
- 逻辑上属于同一语境下的连续讨论
|
||||||
|
|
||||||
|
以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表):
|
||||||
|
- 话题不同,事件主体不同
|
||||||
|
- 时间、地点、对象明显不同
|
||||||
|
- 只是语义相似,但并非同一具体事件
|
||||||
|
- 无直接事件、因果或逻辑关联
|
||||||
|
|
||||||
|
# 输出规则(非常重要)
|
||||||
|
你必须按照以下JSON格式输出:
|
||||||
|
|
||||||
|
**如果是同一事件:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"is_same_event": true,
|
||||||
|
"output": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**如果不是同一事件:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"is_same_event": false,
|
||||||
|
"output": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "输入句子的内容"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "对应的回复内容"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
# JSON Schema
|
||||||
|
{{json_schema}}
|
||||||
|
|
||||||
|
# 注意事项
|
||||||
|
- 必须严格按照上述格式输出
|
||||||
|
- output 字段:如果是同一事件返回 false,如果不是同一事件返回完整的消息列表
|
||||||
|
- 消息列表必须包含 role 和 content 字段
|
||||||
|
- 不要输出任何解释、分析或多余内容
|
||||||
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any, List, Dict, Optional
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_messages(messages: Any) -> str:
    """Turn a message payload into a string suitable for storage.

    Args:
        messages: A string (returned untouched), a dict (JSON-encoded), a
            list/tuple (each element converted, then JSON-encoded), or any
            other object (converted with ``str``).

    Returns:
        str: The serialised payload. JSON output keeps non-ASCII characters
        verbatim (``ensure_ascii=False``).

    List/tuple elements are converted as follows: objects exposing both
    ``type`` and ``content`` attributes (e.g. LangChain-style messages) become
    ``{"type", "content", "role"}`` dicts, with ``role`` falling back to
    ``type``; dicts pass through unchanged; anything else is stringified.
    """
    if isinstance(messages, str):
        return messages
    if isinstance(messages, dict):
        return json.dumps(messages, ensure_ascii=False)
    if not isinstance(messages, (list, tuple)):
        # Unknown payload type: fall back to its string form.
        return str(messages)

    def _to_plain(msg):
        """Convert one list element into a JSON-friendly value."""
        if hasattr(msg, 'type') and hasattr(msg, 'content'):
            return {
                'type': msg.type,
                'content': msg.content,
                'role': getattr(msg, 'role', msg.type),
            }
        if isinstance(msg, dict):
            return msg
        return str(msg)

    return json.dumps([_to_plain(msg) for msg in messages], ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_messages(messages_str: str) -> Any:
    """Decode a stored message payload back into Python objects.

    Args:
        messages_str: The stored value, normally a JSON string.

    Returns:
        An empty list when the input is falsy; the parsed JSON value when the
        input is valid JSON; otherwise the input itself, unchanged.
    """
    if messages_str:
        try:
            return json.loads(messages_str)
        except (json.JSONDecodeError, TypeError):
            # Not JSON (or not a str/bytes value): hand it back as-is.
            return messages_str
    return []
|
||||||
|
|
||||||
|
|
||||||
|
def fix_encoding(text: str) -> str:
    """Try to repair text whose UTF-8 bytes were mis-decoded as Latin-1.

    Args:
        text: The text to repair.

    Returns:
        The text re-encoded as Latin-1 and decoded as UTF-8 when that round
        trip succeeds; otherwise the input unchanged. Falsy or non-string
        inputs are returned as they are.
    """
    if not (text and isinstance(text, str)):
        return text
    try:
        raw_bytes = text.encode('latin-1')
        return raw_bytes.decode('utf-8')
    except (UnicodeDecodeError, UnicodeEncodeError):
        # Round trip impossible: the text was not mojibake, keep it.
        return text
|
||||||
|
|
||||||
|
|
||||||
|
def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
    """Project a raw session record onto the unified Query/Answer shape.

    Args:
        data: Raw session mapping; ``messages``, ``aimessages`` and
            ``starttime`` are read, each defaulting to ``''`` when absent.
        include_time: When true, also copy ``starttime`` into the result.

    Returns:
        Dict: ``{"Query": ..., "Answer": ...}`` with both values passed through
        ``fix_encoding``, plus ``"starttime"`` if requested.
    """
    query_text = fix_encoding(data.get('messages', ''))
    answer_text = fix_encoding(data.get('aimessages', ''))
    formatted = {"Query": query_text, "Answer": answer_text}
    if include_time:
        formatted["starttime"] = data.get('starttime', '')
    return formatted
|
||||||
|
|
||||||
|
|
||||||
|
def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
    """Keep only the items whose ``starttime`` falls within the last N minutes.

    The cut-off is ``datetime.now() - minutes`` rendered as
    ``"%Y-%m-%d %H:%M:%S"`` and compared against each item's ``starttime`` as
    plain strings.

    Args:
        items: Records that may carry a ``starttime`` field.
        minutes: Size of the look-back window, in minutes.

    Returns:
        List[Dict]: The matching items, in their original order. Items with a
        missing or empty ``starttime`` are dropped.
    """
    cutoff = (datetime.now() - timedelta(minutes=minutes)).strftime("%Y-%m-%d %H:%M:%S")

    def _is_recent(entry):
        """True when the entry has a starttime at or after the cut-off."""
        stamp = entry.get('starttime', '')
        return bool(stamp) and stamp >= cutoff

    return [entry for entry in items if _is_recent(entry)]
|
||||||
|
|
||||||
|
|
||||||
|
def sort_and_limit_results(items: List[Dict], limit: Optional[int] = 6,
                           remove_time: bool = True) -> List[Dict]:
    """Return the newest items first, capped at ``limit``, optionally without ``starttime``.

    Unlike the previous version this does not modify its arguments: the input
    list keeps its order and the input dicts keep their ``starttime`` key.

    Args:
        items: Records to order; a missing ``starttime`` sorts as ``''``.
        limit: Maximum number of records to return; ``None`` means no cap
            (callers in this file pass ``limit=None``).
        remove_time: When true, the returned dicts are copies without the
            ``starttime`` key; when false, the original dict objects are
            returned.

    Returns:
        List[Dict]: The selected records, newest first (empty if none).
    """
    # sorted() is stable exactly like list.sort(), so ties keep input order.
    newest_first = sorted(items, key=lambda entry: entry.get('starttime', ''), reverse=True)
    selected = newest_first[:limit]

    if remove_time:
        return [
            {field: value for field, value in entry.items() if field != 'starttime'}
            for entry in selected
        ]
    return selected
|
||||||
|
|
||||||
|
|
||||||
|
def generate_session_key(session_id: str, key_type: str = "session") -> str:
    """Build the Redis key for a session record.

    Args:
        session_id: Session identifier appended to the key.
        key_type: ``"count"`` and ``"write"`` get their own namespace
            (``session:count:`` / ``session:write:``); ``"session"``,
            ``"read"`` and any other value share the plain ``session:`` prefix.

    Returns:
        str: The Redis key.
    """
    if key_type in ("count", "write"):
        return f"session:{key_type}:{session_id}"
    # "session", "read" and unrecognised types all use the base namespace.
    return f"session:{session_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_timestamp() -> str:
    """Return the current local time (``datetime.now()``) as text.

    Returns:
        str: The time formatted as ``"YYYY-MM-DD HH:MM:SS"``.
    """
    return f"{datetime.now():%Y-%m-%d %H:%M:%S}"
|
||||||
@@ -1,11 +1,36 @@
|
|||||||
import redis
|
import redis
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from typing import List, Dict, Any, Optional, Union
|
||||||
|
|
||||||
|
from app.core.memory.agent.utils.redis_base import (
|
||||||
|
serialize_messages,
|
||||||
|
deserialize_messages,
|
||||||
|
fix_encoding,
|
||||||
|
format_session_data,
|
||||||
|
filter_by_time_range,
|
||||||
|
sort_and_limit_results,
|
||||||
|
generate_session_key,
|
||||||
|
get_current_timestamp
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RedisSessionStore:
|
|
||||||
|
|
||||||
|
class RedisWriteStore:
|
||||||
|
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||||
|
|
||||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||||
|
"""
|
||||||
|
初始化 Redis 连接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host: Redis 主机地址
|
||||||
|
port: Redis 端口
|
||||||
|
db: Redis 数据库编号
|
||||||
|
password: Redis 密码
|
||||||
|
session_id: 会话ID
|
||||||
|
"""
|
||||||
self.r = redis.Redis(
|
self.r = redis.Redis(
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
@@ -16,32 +41,437 @@ class RedisSessionStore:
|
|||||||
)
|
)
|
||||||
self.uudi = session_id
|
self.uudi = session_id
|
||||||
|
|
||||||
def _fix_encoding(self, text):
|
def save_session_write(self, userid: str, messages: str) -> str:
|
||||||
"""修复错误编码的文本"""
|
|
||||||
if not text or not isinstance(text, str):
|
|
||||||
return text
|
|
||||||
try:
|
|
||||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
|
||||||
return text.encode('latin-1').decode('utf-8')
|
|
||||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
|
||||||
# 如果修复失败,返回原文本
|
|
||||||
return text
|
|
||||||
|
|
||||||
# 修改后的 save_session 方法
|
|
||||||
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
|
|
||||||
"""
|
"""
|
||||||
写入一条会话数据,返回 session_id
|
写入一条会话数据,返回 session_id
|
||||||
优化版本:确保写入时间不超过1秒
|
|
||||||
|
Args:
|
||||||
|
userid: 用户ID
|
||||||
|
messages: 用户消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 新生成的 session_id
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
|
messages = serialize_messages(messages)
|
||||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
session_id = str(uuid.uuid4())
|
||||||
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
|
key = generate_session_key(session_id, key_type="write")
|
||||||
|
|
||||||
# 使用 pipeline 批量写入,减少网络往返
|
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
|
pipe.hset(key, mapping={
|
||||||
|
"id": self.uudi,
|
||||||
|
"sessionid": userid,
|
||||||
|
"messages": messages,
|
||||||
|
"starttime": get_current_timestamp()
|
||||||
|
})
|
||||||
|
result = pipe.execute()
|
||||||
|
|
||||||
# 直接写入数据,decode_responses=True 已经处理了编码
|
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||||
|
return session_id
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[save_session_write] 保存会话失败: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
    """Look up write-type sessions whose ``sessionid`` field equals ``userid``.

    All ``session:write:*`` keys are listed and their hashes fetched in one
    pipeline; matching records are reduced to the session id (last
    ``:``-separated part of the key) and the ``messages`` field passed through
    ``fix_encoding``.

    Args:
        userid: Value to match against each hash's ``sessionid`` field.

    Returns:
        ``[{"sessionid": ..., "messages": ...}, ...]`` when at least one record
        matches; ``False`` when nothing matches or when any exception occurs
        (the error is printed, not raised).
    """
    try:
        write_keys = self.r.keys('session:write:*')
        if not write_keys:
            return False

        # One round trip for all hashes.
        batch = self.r.pipeline()
        for write_key in write_keys:
            batch.hgetall(write_key)
        records = batch.execute()

        matches = [
            {
                "sessionid": write_key.split(':')[-1],
                "messages": fix_encoding(record.get('messages', '')),
            }
            for write_key, record in zip(write_keys, records)
            if record and record.get('sessionid') == userid
        ]
        if not matches:
            return False

        print(f"[get_session_by_userid] userid={userid}, 找到 {len(matches)} 条数据")
        return matches
    except Exception as e:
        print(f"[get_session_by_userid] 查询失败: {e}")
        return False
|
||||||
|
|
||||||
|
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
    """Fetch every write-type session whose ``sessionid`` field equals ``end_user_id``.

    All ``session:write:*`` keys are listed and their hashes fetched in one
    pipeline. Matching records are returned newest first (sorted by the
    ``starttime`` string, descending).

    Args:
        end_user_id: Value to match against each hash's ``sessionid`` field.

    Returns:
        A list of dicts with keys ``session_id`` (last ``:``-separated part of
        the Redis key), ``id``, ``sessionid``, ``messages`` (passed through
        ``fix_encoding``) and ``starttime``; or ``False`` when there are no
        keys, no matches, or an exception occurred (the error and traceback
        are printed, not raised).
    """
    try:
        write_keys = self.r.keys('session:write:*')
        if not write_keys:
            print("[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
            return False

        # One round trip for all hashes.
        batch = self.r.pipeline()
        for write_key in write_keys:
            batch.hgetall(write_key)
        records = batch.execute()

        sessions = [
            {
                "session_id": write_key.split(':')[-1],
                "id": record.get('id', ''),
                "sessionid": record.get('sessionid', ''),
                "messages": fix_encoding(record.get('messages', '')),
                "starttime": record.get('starttime', ''),
            }
            for write_key, record in zip(write_keys, records)
            if record and record.get('sessionid') == end_user_id
        ]
        if not sessions:
            print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
            return False

        # Newest first.
        sessions = sorted(sessions, key=lambda info: info.get('starttime', ''), reverse=True)

        print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(sessions)} 条数据")
        return sessions
    except Exception as e:
        print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
        import traceback
        traceback.print_exc()
        return False
|
||||||
|
|
||||||
|
def find_user_recent_sessions(self, userid: str,
                              minutes: int = 5) -> List[Dict[str, str]]:
    """Return a user's write-type sessions from the last ``minutes`` minutes.

    All ``session:write:*`` hashes are fetched in one pipeline. Records whose
    ``sessionid`` equals ``userid`` and that carry a ``starttime`` are turned
    into ``{"Query", "Answer", "starttime"}`` dicts (``Answer`` is always
    ``""`` here), filtered with ``filter_by_time_range`` and then passed to
    ``sort_and_limit_results`` with ``limit=None``.

    Args:
        userid: Value to match against each hash's ``sessionid`` field.
        minutes: Look-back window in minutes (default 5).

    Returns:
        List[Dict]: Whatever ``sort_and_limit_results`` returns for the
        filtered records; an empty list when no write keys exist.

    Timing information and the result list are printed to stdout.
    """
    import time
    started = time.time()

    write_keys = self.r.keys('session:write:*')
    if not write_keys:
        print(f"[find_user_recent_sessions] 查询耗时: {time.time() - started:.3f}秒, 结果数: 0")
        return []

    # One round trip for all hashes.
    batch = self.r.pipeline()
    for write_key in write_keys:
        batch.hgetall(write_key)
    records = batch.execute()

    candidates = [
        {
            "Query": fix_encoding(record.get('messages', '')),
            "Answer": "",  # write-type records store no AI reply
            "starttime": record.get('starttime', ''),
        }
        for record in records
        if record and record.get('sessionid') == userid and record.get('starttime')
    ]

    recent = filter_by_time_range(candidates, minutes)
    result_items = sort_and_limit_results(recent, limit=None)
    print(result_items)

    elapsed_time = time.time() - started
    print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
          f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")

    return result_items
|
||||||
|
|
||||||
|
def delete_all_write_sessions(self) -> int:
    """Delete every key matching ``session:write:*``.

    Returns:
        int: The value returned by the Redis ``delete`` call, or ``0`` when no
        matching keys exist.
    """
    write_keys = self.r.keys('session:write:*')
    return self.r.delete(*write_keys) if write_keys else 0
|
||||||
|
|
||||||
|
|
||||||
|
class RedisCountStore:
    """Redis store for per-end-user visit-count records.

    Each record is a hash under ``session:count:<session_id>``; a string key
    ``session:count:index:<end_user_id>`` maps an end user to the id of that
    hash.
    """

    def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
        """Open the Redis connection.

        Args:
            host: Redis host.
            port: Redis port.
            db: Redis database number.
            password: Redis password.
            session_id: Identifier stored in the ``id`` field of saved records.
        """
        connection_options = dict(
            host=host,
            port=port,
            db=db,
            password=password,
            decode_responses=True,
            encoding='utf-8',
        )
        self.r = redis.Redis(**connection_options)
        self.uudi = session_id

    def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
        """Create a new count record for an end user and point the index at it.

        The hash and the index key are both written with a 30-day expiry, in a
        single pipeline.

        Args:
            end_user_id: End-user identifier.
            count: Visit count (stored as ``int(count)``).
            messages: Message payload, stored via ``serialize_messages``.

        Returns:
            str: The newly generated session id.
        """
        session_id = str(uuid.uuid4())
        key = generate_session_key(session_id, key_type="count")
        index_key = f'session:count:index:{end_user_id}'
        ttl_seconds = 30 * 24 * 60 * 60  # 30 days

        record = {
            "id": self.uudi,
            "end_user_id": end_user_id,
            "count": int(count),
            "messages": serialize_messages(messages),
            "starttime": get_current_timestamp(),
        }

        pipe = self.r.pipeline()
        pipe.hset(key, mapping=record)
        pipe.expire(key, ttl_seconds)
        # Index: end_user_id -> session_id.
        pipe.set(index_key, session_id, ex=ttl_seconds)
        result = pipe.execute()

        print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
        return session_id

    def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
        """Read the count record of an end user through the index key.

        Args:
            end_user_id: End-user identifier.

        Returns:
            ``[count, messages]`` (count as ``int``, messages decoded with
            ``deserialize_messages``) when a record with a ``count`` field
            exists; otherwise ``False``. Exceptions are printed and reported
            as ``False``. A wrongly-typed index key, or an index pointing at
            a missing hash, is deleted.
        """
        try:
            index_key = f'session:count:index:{end_user_id}'

            # Guard against WRONGTYPE errors on the index key.
            try:
                key_type = self.r.type(index_key)
                if key_type not in ('string', 'none'):
                    self.r.delete(index_key)
                    return False
            except Exception as type_error:
                print(f"[get_sessions_count] 检查键类型失败: {type_error}")

            session_id = self.r.get(index_key)
            if not session_id:
                return False

            key = generate_session_key(session_id, key_type="count")
            data = self.r.hgetall(key)
            if not data:
                # Dangling index: the hash is gone, drop the index too.
                self.r.delete(index_key)
                return False

            count = data.get('count')
            messages_str = data.get('messages')
            if count is None:
                return False

            messages = deserialize_messages(messages_str)
            return [int(count), messages]
        except Exception as e:
            print(f"[get_sessions_count] 查询失败: {e}")
            return False

    def update_sessions_count(self, end_user_id: str, new_count: int,
                              messages: Any) -> bool:
        """Overwrite ``count`` and ``messages`` of an end user's count record.

        The record is located through the index key.

        Args:
            end_user_id: End-user identifier.
            new_count: New count value (stored as ``int(new_count)``).
            messages: Message payload, stored via ``serialize_messages``.

        Returns:
            bool: ``True`` once the update pipeline has executed; ``False``
            when the index is missing or wrongly typed, or on any exception
            (printed, not raised).
        """
        try:
            index_key = f'session:count:index:{end_user_id}'

            # Guard against WRONGTYPE errors on the index key.
            try:
                key_type = self.r.type(index_key)
                if key_type not in ('string', 'none'):
                    print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
                    self.r.delete(index_key)
                    print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
                    return False
            except Exception as type_error:
                print(f"[update_sessions_count] 检查键类型失败: {type_error}")

            session_id = self.r.get(index_key)
            if not session_id:
                print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
                return False

            key = generate_session_key(session_id, key_type="count")
            messages_str = serialize_messages(messages)

            pipe = self.r.pipeline()
            pipe.hset(key, 'count', int(new_count))
            pipe.hset(key, 'messages', messages_str)
            pipe.execute()

            print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
            return True
        except Exception as e:
            print(f"[update_sessions_count] 更新失败: {e}")
            return False

    def delete_all_count_sessions(self) -> int:
        """Delete every key matching ``session:count:*``.

        Returns:
            int: The value returned by the Redis ``delete`` call, or ``0`` when
            no matching keys exist.
        """
        count_keys = self.r.keys('session:count:*')
        return self.r.delete(*count_keys) if count_keys else 0
|
||||||
|
|
||||||
|
|
||||||
|
class RedisSessionStore:
|
||||||
|
"""Redis 会话存储类,用于管理会话数据"""
|
||||||
|
|
||||||
|
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
    """Create the Redis client used by this session store.

    Args:
        host: Redis server host.
        port: Redis server port.
        db: Redis database number.
        password: Redis password, or None.
        session_id: Identifier kept on the instance as ``self.uudi``
            (attribute spelling preserved because other methods read it).
    """
    # Responses are decoded to text using UTF-8.
    client_options = {
        'host': host,
        'port': port,
        'db': db,
        'password': password,
        'decode_responses': True,
        'encoding': 'utf-8',
    }
    self.r = redis.Redis(**client_options)
    self.uudi = session_id
|
||||||
|
|
||||||
|
# ==================== 写入操作 ====================
|
||||||
|
|
||||||
|
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||||
|
apply_id: str, end_user_id: str) -> str:
|
||||||
|
"""
|
||||||
|
写入一条会话数据,返回 session_id
|
||||||
|
|
||||||
|
Args:
|
||||||
|
userid: 用户ID
|
||||||
|
messages: 用户消息
|
||||||
|
aimessages: AI回复消息
|
||||||
|
apply_id: 应用ID
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 新生成的 session_id
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
key = generate_session_key(session_id, key_type="read")
|
||||||
|
|
||||||
|
pipe = self.r.pipeline()
|
||||||
pipe.hset(key, mapping={
|
pipe.hset(key, mapping={
|
||||||
"id": self.uudi,
|
"id": self.uudi,
|
||||||
"sessionid": userid,
|
"sessionid": userid,
|
||||||
@@ -49,177 +479,195 @@ class RedisSessionStore:
|
|||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"aimessages": aimessages,
|
"aimessages": aimessages,
|
||||||
"starttime": starttime
|
"starttime": get_current_timestamp()
|
||||||
})
|
})
|
||||||
|
|
||||||
# 可选:设置过期时间(例如30天),避免数据无限增长
|
|
||||||
# pipe.expire(key, 30 * 24 * 60 * 60)
|
|
||||||
|
|
||||||
# 执行批量操作
|
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"保存结果: {result[0]}, session_id: {session_id}")
|
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||||
return session_id # 返回新生成的 session_id
|
return session_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"保存会话失败: {e}")
|
print(f"[save_session] 保存会话失败: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def save_sessions_batch(self, sessions_data):
|
# ==================== 读取操作 ====================
|
||||||
"""
|
|
||||||
批量写入多条会话数据,返回 session_id 列表
|
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
|
|
||||||
优化版本:批量操作,大幅提升性能
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
session_ids = []
|
|
||||||
pipe = self.r.pipeline()
|
|
||||||
|
|
||||||
for session in sessions_data:
|
|
||||||
session_id = str(uuid.uuid4())
|
|
||||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
key = f"session:{session_id}"
|
|
||||||
|
|
||||||
pipe.hset(key, mapping={
|
|
||||||
"id": self.uudi,
|
|
||||||
"sessionid": session.get('userid'),
|
|
||||||
"apply_id": session.get('apply_id'),
|
|
||||||
"end_user_id": session.get('end_user_id'),
|
|
||||||
"messages": session.get('messages'),
|
|
||||||
"aimessages": session.get('aimessages'),
|
|
||||||
"starttime": starttime
|
|
||||||
})
|
|
||||||
|
|
||||||
session_ids.append(session_id)
|
|
||||||
|
|
||||||
# 一次性执行所有写入操作
|
|
||||||
results = pipe.execute()
|
|
||||||
print(f"批量保存完成: {len(session_ids)} 条记录")
|
|
||||||
return session_ids
|
|
||||||
except Exception as e:
|
|
||||||
print(f"批量保存会话失败: {e}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# ---------------- 读取 ----------------
|
|
||||||
def get_session(self, session_id):
|
|
||||||
"""
|
"""
|
||||||
读取一条会话数据
|
读取一条会话数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict 或 None: 会话数据
|
||||||
"""
|
"""
|
||||||
key = f"session:{session_id}"
|
key = generate_session_key(session_id)
|
||||||
data = self.r.hgetall(key)
|
data = self.r.hgetall(key)
|
||||||
return data if data else None
|
return data if data else None
|
||||||
|
|
||||||
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
|
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
|
获取所有会话数据(不包括 count 和 write 类型)
|
||||||
"""
|
|
||||||
result_items = []
|
Returns:
|
||||||
|
Dict: 所有会话数据,key 为 session_id
|
||||||
# 遍历所有会话数据
|
|
||||||
for key in self.r.keys('session:*'):
|
|
||||||
data = self.r.hgetall(key)
|
|
||||||
|
|
||||||
if not data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查三个条件是否都匹配
|
|
||||||
if (data.get('sessionid') == sessionid and
|
|
||||||
data.get('apply_id') == apply_id and
|
|
||||||
data.get('end_user_id') == end_user_id):
|
|
||||||
result_items.append(data)
|
|
||||||
|
|
||||||
return result_items
|
|
||||||
|
|
||||||
def get_all_sessions(self):
|
|
||||||
"""
|
|
||||||
获取所有会话数据
|
|
||||||
"""
|
"""
|
||||||
sessions = {}
|
sessions = {}
|
||||||
for key in self.r.keys('session:*'):
|
for key in self.r.keys('session:*'):
|
||||||
sid = key.split(':')[1]
|
# 排除 count 和 write 类型的 key
|
||||||
sessions[sid] = self.get_session(sid)
|
if ':count:' not in key and ':write:' not in key:
|
||||||
|
sid = key.split(':')[1]
|
||||||
|
sessions[sid] = self.get_session(sid)
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
# ---------------- 更新 ----------------
|
def find_user_apply_group(self, sessionid: str, apply_id: str,
                          end_user_id: str) -> List[Dict[str, str]]:
    """Look up stored sessions for one session / application / end user.

    All ``session:*`` keys except those containing ``:count:`` or ``:write:``
    are read through a single pipeline. A hash is kept when its ``apply_id``
    and ``end_user_id`` equal the arguments and ``sessionid`` is contained in
    (or equal to) its ``sessionid`` field.

    Args:
        sessionid: Session identifier; substring matches are accepted.
        apply_id: Application identifier, compared for equality.
        end_user_id: End-user identifier, compared for equality.

    Returns:
        List[Dict]: The result of ``sort_and_limit_results(matches, limit=6)``.
        The original documentation describes this as the six most recent
        entries shaped like ``{"Query": "...", "Answer": "..."}``.
    """
    import time
    began = time.time()

    session_keys = self.r.keys('session:*')
    if not session_keys:
        print(f"[find_user_apply_group] 查询耗时: {time.time() - began:.3f}秒, 结果数: 0")
        return []

    # Queue one HGETALL per plain session key; count/write keys are skipped.
    batch = self.r.pipeline()
    for name in session_keys:
        if ':count:' in name or ':write:' in name:
            continue
        batch.hgetall(name)
    rows = batch.execute()

    matches = []
    for row in rows:
        if not row:
            continue
        if row.get('apply_id') != apply_id or row.get('end_user_id') != end_user_id:
            continue
        # Accept either a substring match or an exact match on sessionid.
        if sessionid in row.get('sessionid', '') or row.get('sessionid') == sessionid:
            matches.append(format_session_data(row, include_time=True))

    # Ordering, truncation and time-field handling are delegated to the helper.
    result_items = sort_and_limit_results(matches, limit=6)

    elapsed = time.time() - began
    print(f"[find_user_apply_group] 查询耗时: {elapsed:.3f}秒, 结果数: {len(result_items)}")

    return result_items
|
||||||
|
|
||||||
|
# ==================== 更新操作 ====================
|
||||||
|
|
||||||
|
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||||
"""
|
"""
|
||||||
更新单个字段
|
更新单个字段
|
||||||
优化版本:使用 pipeline 减少网络往返
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
field: 字段名
|
||||||
|
value: 字段值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否更新成功
|
||||||
"""
|
"""
|
||||||
key = f"session:{session_id}"
|
key = generate_session_key(session_id)
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
pipe.exists(key)
|
pipe.exists(key)
|
||||||
pipe.hset(key, field, value)
|
pipe.hset(key, field, value)
|
||||||
results = pipe.execute()
|
results = pipe.execute()
|
||||||
return bool(results[0]) # 返回 key 是否存在
|
return bool(results[0])
|
||||||
|
|
||||||
# ---------------- 删除 ----------------
|
# ==================== 删除操作 ====================
|
||||||
def delete_session(self, session_id):
|
|
||||||
|
def delete_session(self, session_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
删除单条会话
|
删除单条会话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 删除的数量
|
||||||
"""
|
"""
|
||||||
key = f"session:{session_id}"
|
key = generate_session_key(session_id)
|
||||||
return self.r.delete(key)
|
return self.r.delete(key)
|
||||||
|
|
||||||
def delete_all_sessions(self):
|
def delete_all_sessions(self) -> int:
|
||||||
"""
|
"""
|
||||||
删除所有会话
|
删除所有会话(不包括 count 和 write 类型)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 删除的数量
|
||||||
"""
|
"""
|
||||||
keys = self.r.keys('session:*')
|
keys = self.r.keys('session:*')
|
||||||
if keys:
|
# 过滤掉 count 和 write 类型
|
||||||
return self.r.delete(*keys)
|
keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k]
|
||||||
|
if keys_to_delete:
|
||||||
|
return self.r.delete(*keys_to_delete)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def delete_duplicate_sessions(self):
|
def delete_duplicate_sessions(self) -> int:
|
||||||
"""
|
"""
|
||||||
删除重复会话数据,条件:
|
删除重复会话数据(不包括 count 和 write 类型)
|
||||||
"sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个
|
||||||
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
|
||||||
|
Returns:
|
||||||
|
int: 删除的数量
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 第一步:使用 pipeline 批量获取所有 key
|
|
||||||
keys = self.r.keys('session:*')
|
keys = self.r.keys('session:*')
|
||||||
|
|
||||||
if not keys:
|
if not keys:
|
||||||
print("[delete_duplicate_sessions] 没有会话数据")
|
print("[delete_duplicate_sessions] 没有会话数据")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# 第二步:使用 pipeline 批量获取所有数据
|
# 批量获取所有数据
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
for key in keys:
|
for key in keys:
|
||||||
pipe.hgetall(key)
|
# 排除 count 和 write 类型
|
||||||
|
if ':count:' not in key and ':write:' not in key:
|
||||||
|
pipe.hgetall(key)
|
||||||
all_data = pipe.execute()
|
all_data = pipe.execute()
|
||||||
|
|
||||||
# 第三步:在内存中识别重复数据
|
# 识别重复数据
|
||||||
seen = {} # 用字典记录:identifier -> key(保留第一个出现的 key)
|
seen = {}
|
||||||
keys_to_delete = [] # 需要删除的 key 列表
|
keys_to_delete = []
|
||||||
|
|
||||||
for key, data in zip(keys, all_data, strict=False):
|
for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False):
|
||||||
if not data:
|
if not data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 获取五个字段的值
|
|
||||||
sessionid = data.get('sessionid', '')
|
|
||||||
user_id = data.get('id', '')
|
|
||||||
end_user_id = data.get('end_user_id', '')
|
|
||||||
messages = data.get('messages', '')
|
|
||||||
aimessages = data.get('aimessages', '')
|
|
||||||
|
|
||||||
# 用五元组作为唯一标识
|
# 用五元组作为唯一标识
|
||||||
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
|
identifier = (
|
||||||
|
data.get('sessionid', ''),
|
||||||
|
data.get('id', ''),
|
||||||
|
data.get('end_user_id', ''),
|
||||||
|
data.get('messages', ''),
|
||||||
|
data.get('aimessages', '')
|
||||||
|
)
|
||||||
|
|
||||||
if identifier in seen:
|
if identifier in seen:
|
||||||
# 重复,标记为待删除
|
|
||||||
keys_to_delete.append(key)
|
keys_to_delete.append(key)
|
||||||
else:
|
else:
|
||||||
# 第一次出现,记录
|
|
||||||
seen[identifier] = key
|
seen[identifier] = key
|
||||||
|
|
||||||
# 第四步:使用 pipeline 批量删除重复的 key
|
# 批量删除重复的 key
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
if keys_to_delete:
|
if keys_to_delete:
|
||||||
# 分批删除,避免单次操作过大
|
|
||||||
batch_size = 1000
|
batch_size = 1000
|
||||||
for i in range(0, len(keys_to_delete), batch_size):
|
for i in range(0, len(keys_to_delete), batch_size):
|
||||||
batch = keys_to_delete[i:i + batch_size]
|
batch = keys_to_delete[i:i + batch_size]
|
||||||
@@ -233,79 +681,28 @@ class RedisSessionStore:
|
|||||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
def find_user_session(self, sessionid):
|
|
||||||
user_id = sessionid
|
|
||||||
|
|
||||||
result_items = []
|
|
||||||
for key, values in store.get_all_sessions().items():
|
|
||||||
history = {}
|
|
||||||
if user_id == str(values['sessionid']):
|
|
||||||
history["Query"] = values['messages']
|
|
||||||
history["Answer"] = values['aimessages']
|
|
||||||
result_items.append(history)
|
|
||||||
|
|
||||||
if len(result_items) <= 1:
|
|
||||||
result_items = []
|
|
||||||
return (result_items)
|
|
||||||
|
|
||||||
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
|
|
||||||
"""
|
|
||||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条
|
|
||||||
"""
|
|
||||||
import time
|
|
||||||
start_time = time.time()
|
|
||||||
# 使用 pipeline 批量获取数据,提高性能
|
|
||||||
keys = self.r.keys('session:*')
|
|
||||||
|
|
||||||
if not keys:
|
|
||||||
print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 使用 pipeline 批量获取所有 hash 数据
|
|
||||||
pipe = self.r.pipeline()
|
|
||||||
for key in keys:
|
|
||||||
pipe.hgetall(key)
|
|
||||||
all_data = pipe.execute()
|
|
||||||
|
|
||||||
# 解析并筛选符合条件的数据
|
|
||||||
matched_items = []
|
|
||||||
for data in all_data:
|
|
||||||
if not data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查是否符合三个条件
|
|
||||||
|
|
||||||
if (data.get('apply_id') == apply_id and
|
|
||||||
data.get('end_user_id') == end_user_id):
|
|
||||||
# 支持模糊匹配 sessionid 或者完全匹配
|
|
||||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
|
||||||
matched_items.append({
|
|
||||||
"Query": self._fix_encoding(data.get('messages')),
|
|
||||||
"Answer": self._fix_encoding(data.get('aimessages')),
|
|
||||||
"starttime": data.get('starttime', '')
|
|
||||||
})
|
|
||||||
# 按时间降序排序(最新的在前)
|
|
||||||
matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
|
||||||
# 只保留最新的6条
|
|
||||||
result_items = matched_items[:6]
|
|
||||||
# # 移除 starttime 字段
|
|
||||||
for item in result_items:
|
|
||||||
item.pop('starttime', None)
|
|
||||||
|
|
||||||
# 如果结果少于等于1条,返回空列表
|
|
||||||
if len(result_items) <= 1:
|
|
||||||
result_items = []
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
|
||||||
print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
|
||||||
|
|
||||||
return result_items
|
|
||||||
|
|
||||||
|
|
||||||
|
# 全局实例
|
||||||
store = RedisSessionStore(
|
store = RedisSessionStore(
|
||||||
host=settings.REDIS_HOST,
|
host=settings.REDIS_HOST,
|
||||||
port=settings.REDIS_PORT,
|
port=settings.REDIS_PORT,
|
||||||
db=settings.REDIS_DB,
|
db=settings.REDIS_DB,
|
||||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||||
session_id=str(uuid.uuid4())
|
session_id=str(uuid.uuid4())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Global RedisWriteStore instance configured from settings; a falsy password
# is passed as None.
write_store = RedisWriteStore(
    host=settings.REDIS_HOST,
    port=settings.REDIS_PORT,
    db=settings.REDIS_DB,
    password=settings.REDIS_PASSWORD or None,
    session_id=str(uuid.uuid4()),
)

# Global RedisCountStore instance configured from settings; it receives its
# own freshly generated session_id.
count_store = RedisCountStore(
    host=settings.REDIS_HOST,
    port=settings.REDIS_PORT,
    db=settings.REDIS_DB,
    password=settings.REDIS_PASSWORD or None,
    session_id=str(uuid.uuid4()),
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
|
|||||||
This module provides the main write function for executing the knowledge extraction
|
This module provides the main write function for executing the knowledge extraction
|
||||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -123,23 +124,48 @@ async def write(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# 添加死锁重试机制
|
||||||
|
max_retries = 3
|
||||||
|
retry_delay = 1 # 秒
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
success = await save_dialog_and_statements_to_neo4j(
|
||||||
|
dialogue_nodes=all_dialogue_nodes,
|
||||||
|
chunk_nodes=all_chunk_nodes,
|
||||||
|
statement_nodes=all_statement_nodes,
|
||||||
|
entity_nodes=all_entity_nodes,
|
||||||
|
statement_chunk_edges=all_statement_chunk_edges,
|
||||||
|
statement_entity_edges=all_statement_entity_edges,
|
||||||
|
entity_edges=all_entity_entity_edges,
|
||||||
|
connector=neo4j_connector
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.warning("Failed to save some data to Neo4j")
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
|
||||||
|
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
# 检查是否是死锁错误
|
||||||
|
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
|
||||||
|
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# 非死锁错误,直接抛出
|
||||||
|
raise
|
||||||
|
|
||||||
try:
|
try:
|
||||||
success = await save_dialog_and_statements_to_neo4j(
|
|
||||||
dialogue_nodes=all_dialogue_nodes,
|
|
||||||
chunk_nodes=all_chunk_nodes,
|
|
||||||
statement_nodes=all_statement_nodes,
|
|
||||||
entity_nodes=all_entity_nodes,
|
|
||||||
statement_chunk_edges=all_statement_chunk_edges,
|
|
||||||
statement_entity_edges=all_statement_entity_edges,
|
|
||||||
entity_edges=all_entity_entity_edges,
|
|
||||||
connector=neo4j_connector
|
|
||||||
)
|
|
||||||
if success:
|
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
|
||||||
else:
|
|
||||||
logger.warning("Failed to save some data to Neo4j")
|
|
||||||
finally:
|
|
||||||
await neo4j_connector.close()
|
await neo4j_connector.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing Neo4j connector: {e}")
|
||||||
|
|
||||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,12 @@ from app.core.memory.models.triplet_models import (
|
|||||||
TripletExtractionResponse,
|
TripletExtractionResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Ontology models
|
||||||
|
from app.core.memory.models.ontology_models import (
|
||||||
|
OntologyClass,
|
||||||
|
OntologyExtractionResponse,
|
||||||
|
)
|
||||||
|
|
||||||
# Variable configuration models
|
# Variable configuration models
|
||||||
from app.core.memory.models.variate_config import (
|
from app.core.memory.models.variate_config import (
|
||||||
StatementExtractionConfig,
|
StatementExtractionConfig,
|
||||||
@@ -105,6 +111,9 @@ __all__ = [
|
|||||||
"Entity",
|
"Entity",
|
||||||
"Triplet",
|
"Triplet",
|
||||||
"TripletExtractionResponse",
|
"TripletExtractionResponse",
|
||||||
|
# Ontology models
|
||||||
|
"OntologyClass",
|
||||||
|
"OntologyExtractionResponse",
|
||||||
# Variable configuration
|
# Variable configuration
|
||||||
"StatementExtractionConfig",
|
"StatementExtractionConfig",
|
||||||
"ForgettingEngineConfig",
|
"ForgettingEngineConfig",
|
||||||
|
|||||||
@@ -413,7 +413,8 @@ class ExtractedEntityNode(Node):
|
|||||||
description="Entity aliases - alternative names for this entity"
|
description="Entity aliases - alternative names for this entity"
|
||||||
)
|
)
|
||||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||||
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||||
|
|
||||||
|
|||||||
135
api/app/core/memory/models/ontology_models.py
Normal file
135
api/app/core/memory/models/ontology_models.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""Models for ontology classes and extraction responses.
|
||||||
|
|
||||||
|
This module contains Pydantic models for representing extracted ontology classes
|
||||||
|
from scenario descriptions, following OWL ontology engineering standards.
|
||||||
|
|
||||||
|
Classes:
|
||||||
|
OntologyClass: Represents an extracted ontology class
|
||||||
|
OntologyExtractionResponse: Response model containing extracted ontology classes
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyClass(BaseModel):
    """A single ontology class extracted from a scenario description.

    Attributes:
        id: Unique string identifier; defaults to a fresh ``uuid4`` hex value.
        name: Class name in PascalCase (e.g. 'MedicalProcedure'); validated
            by ``validate_pascal_case``.
        name_chinese: Optional Chinese translation of the class name.
        description: Textual description of the class.
        examples: Concrete instance examples; defaults to an empty list.
        parent_class: Optional name of the parent class in the hierarchy.
        entity_type: Type/category of the entity (e.g. 'Person', 'Concept').
        domain: Domain the class belongs to (e.g. 'Healthcare').

    Config:
        extra: Extra fields (for instance from LLM output) are ignored.
    """
    model_config = ConfigDict(extra='ignore')

    id: str = Field(
        default_factory=lambda: uuid4().hex,
        description="Unique identifier for the ontology class",
    )
    name: str = Field(..., description="Name of the class in PascalCase format")
    name_chinese: Optional[str] = Field(
        None, description="Chinese translation of the class name"
    )
    description: str = Field(..., description="Description of the class")
    examples: List[str] = Field(
        default_factory=list, description="List of concrete instance examples"
    )
    parent_class: Optional[str] = Field(
        None, description="Name of the parent class in the hierarchy"
    )
    entity_type: str = Field(..., description="Type/category of the entity")
    domain: str = Field(..., description="Domain this class belongs to")

    @field_validator('name')
    @classmethod
    def validate_pascal_case(cls, v: str) -> str:
        """Check that the class name follows the PascalCase rules used here.

        Rules, checked in this order:
        - the name must not be empty;
        - the first character must be uppercase;
        - the name must not contain spaces;
        - every character must be alphanumeric or an underscore.

        Args:
            v: The class name to validate.

        Returns:
            The unchanged class name when all checks pass.

        Raises:
            ValueError: If any rule is violated; the message names the first
                rule that failed.
        """
        problem = None
        if not v:
            problem = "Class name cannot be empty"
        elif not v[0].isupper():
            problem = f"Class name '{v}' must start with an uppercase letter (PascalCase)"
        elif ' ' in v:
            problem = f"Class name '{v}' cannot contain spaces (PascalCase)"
        elif any(not (ch.isalnum() or ch == '_') for ch in v):
            # Anything other than letters, digits and underscores is rejected.
            problem = (
                f"Class name '{v}' contains invalid characters. "
                "Only alphanumeric characters and underscores are allowed"
            )

        if problem is not None:
            raise ValueError(problem)
        return v
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyExtractionResponse(BaseModel):
    """Structured LLM output for ontology extraction from a scenario.

    Attributes:
        classes: Extracted ontology classes; defaults to an empty list.
        domain: Domain/field the scenario belongs to (required).

    Config:
        extra: Extra fields in the LLM output are ignored.
    """
    model_config = ConfigDict(extra='ignore')

    classes: List[OntologyClass] = Field(
        default_factory=list, description="List of extracted ontology classes"
    )
    domain: str = Field(..., description="Domain/field the scenario belongs to")
|
||||||
@@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
|||||||
if len(desc_b) > len(desc_a):
|
if len(desc_b) > len(desc_a):
|
||||||
canonical.description = desc_b
|
canonical.description = desc_b
|
||||||
# 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序
|
# 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序
|
||||||
fact_a = getattr(canonical, "fact_summary", "") or ""
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
fact_b = getattr(ent, "fact_summary", "") or ""
|
# fact_a = getattr(canonical, "fact_summary", "") or ""
|
||||||
def _extract_sources(txt: str) -> List[str]:
|
# fact_b = getattr(ent, "fact_summary", "") or ""
|
||||||
sources: List[str] = []
|
# def _extract_sources(txt: str) -> List[str]:
|
||||||
if not txt:
|
# sources: List[str] = []
|
||||||
return sources
|
# if not txt:
|
||||||
for line in str(txt).splitlines():
|
# return sources
|
||||||
ln = line.strip()
|
# for line in str(txt).splitlines():
|
||||||
|
# ln = line.strip()
|
||||||
# 支持“来源:”或“来源:”前缀
|
# 支持“来源:”或“来源:”前缀
|
||||||
m = re.match(r"^来源[::]\s*(.+)$", ln)
|
# m = re.match(r"^来源[::]\s*(.+)$", ln)
|
||||||
if m:
|
# if m:
|
||||||
content = m.group(1).strip()
|
# content = m.group(1).strip()
|
||||||
if content:
|
# if content:
|
||||||
sources.append(content)
|
# sources.append(content)
|
||||||
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
|
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
|
||||||
if not sources and txt.strip():
|
# if not sources and txt.strip():
|
||||||
sources.append(txt.strip())
|
# sources.append(txt.strip())
|
||||||
return sources
|
# return sources
|
||||||
try:
|
try:
|
||||||
src_a = _extract_sources(fact_a)
|
# src_a = _extract_sources(fact_a)
|
||||||
src_b = _extract_sources(fact_b)
|
# src_b = _extract_sources(fact_b)
|
||||||
seen = set()
|
# seen = set()
|
||||||
merged_sources: List[str] = []
|
# merged_sources: List[str] = []
|
||||||
for s in src_a + src_b:
|
# for s in src_a + src_b:
|
||||||
if s and s not in seen:
|
# if s and s not in seen:
|
||||||
seen.add(s)
|
# seen.add(s)
|
||||||
merged_sources.append(s)
|
# merged_sources.append(s)
|
||||||
if merged_sources:
|
# if merged_sources:
|
||||||
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
# name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
||||||
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
# canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
||||||
elif fact_b and not fact_a:
|
# elif fact_b and not fact_a:
|
||||||
canonical.fact_summary = fact_b
|
# canonical.fact_summary = fact_b
|
||||||
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
# 兜底:若解析失败,保留较长文本
|
# 兜底:若解析失败,保留较长文本
|
||||||
if len(fact_b) > len(fact_a):
|
# if len(fact_b) > len(fact_a):
|
||||||
canonical.fact_summary = fact_b
|
# canonical.fact_summary = fact_b
|
||||||
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: #
|
|||||||
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
|
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
|
||||||
desc_a = (getattr(a, "description", "") or "")
|
desc_a = (getattr(a, "description", "") or "")
|
||||||
desc_b = (getattr(b, "description", "") or "")
|
desc_b = (getattr(b, "description", "") or "")
|
||||||
fact_a = (getattr(a, "fact_summary", "") or "")
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
fact_b = (getattr(b, "fact_summary", "") or "")
|
# fact_a = (getattr(a, "fact_summary", "") or "")
|
||||||
score_a = len(desc_a) + len(fact_a)
|
# fact_b = (getattr(b, "fact_summary", "") or "")
|
||||||
score_b = len(desc_b) + len(fact_b)
|
# score_a = len(desc_a) + len(fact_a)
|
||||||
|
# score_b = len(desc_b) + len(fact_b)
|
||||||
|
score_a = len(desc_a)
|
||||||
|
score_b = len(desc_b)
|
||||||
if score_a != score_b:
|
if score_a != score_b:
|
||||||
return 0 if score_a >= score_b else 1
|
return 0 if score_a >= score_b else 1
|
||||||
return 0
|
return 0
|
||||||
@@ -189,7 +192,8 @@ async def _judge_pair(
|
|||||||
"entity_type": getattr(a, "entity_type", None),
|
"entity_type": getattr(a, "entity_type", None),
|
||||||
"description": getattr(a, "description", None),
|
"description": getattr(a, "description", None),
|
||||||
"aliases": getattr(a, "aliases", None) or [],
|
"aliases": getattr(a, "aliases", None) or [],
|
||||||
"fact_summary": getattr(a, "fact_summary", None),
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# "fact_summary": getattr(a, "fact_summary", None),
|
||||||
"connect_strength": getattr(a, "connect_strength", None),
|
"connect_strength": getattr(a, "connect_strength", None),
|
||||||
}
|
}
|
||||||
entity_b = {
|
entity_b = {
|
||||||
@@ -197,7 +201,8 @@ async def _judge_pair(
|
|||||||
"entity_type": getattr(b, "entity_type", None),
|
"entity_type": getattr(b, "entity_type", None),
|
||||||
"description": getattr(b, "description", None),
|
"description": getattr(b, "description", None),
|
||||||
"aliases": getattr(b, "aliases", None) or [],
|
"aliases": getattr(b, "aliases", None) or [],
|
||||||
"fact_summary": getattr(b, "fact_summary", None),
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# "fact_summary": getattr(b, "fact_summary", None),
|
||||||
"connect_strength": getattr(b, "connect_strength", None),
|
"connect_strength": getattr(b, "connect_strength", None),
|
||||||
}
|
}
|
||||||
# 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式)
|
# 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式)
|
||||||
@@ -248,7 +253,8 @@ async def _judge_pair_disamb(
|
|||||||
"entity_type": getattr(a, "entity_type", None),
|
"entity_type": getattr(a, "entity_type", None),
|
||||||
"description": getattr(a, "description", None),
|
"description": getattr(a, "description", None),
|
||||||
"aliases": getattr(a, "aliases", None) or [],
|
"aliases": getattr(a, "aliases", None) or [],
|
||||||
"fact_summary": getattr(a, "fact_summary", None),
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# "fact_summary": getattr(a, "fact_summary", None),
|
||||||
"connect_strength": getattr(a, "connect_strength", None),
|
"connect_strength": getattr(a, "connect_strength", None),
|
||||||
}
|
}
|
||||||
entity_b = {
|
entity_b = {
|
||||||
@@ -256,7 +262,8 @@ async def _judge_pair_disamb(
|
|||||||
"entity_type": getattr(b, "entity_type", None),
|
"entity_type": getattr(b, "entity_type", None),
|
||||||
"description": getattr(b, "description", None),
|
"description": getattr(b, "description", None),
|
||||||
"aliases": getattr(b, "aliases", None) or [],
|
"aliases": getattr(b, "aliases", None) or [],
|
||||||
"fact_summary": getattr(b, "fact_summary", None),
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# "fact_summary": getattr(b, "fact_summary", None),
|
||||||
"connect_strength": getattr(b, "connect_strength", None),
|
"connect_strength": getattr(b, "connect_strength", None),
|
||||||
}
|
}
|
||||||
prompt = render_entity_dedup_prompt(
|
prompt = render_entity_dedup_prompt(
|
||||||
|
|||||||
@@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
|||||||
description=row.get("description") or "",
|
description=row.get("description") or "",
|
||||||
aliases=row.get("aliases") or [],
|
aliases=row.get("aliases") or [],
|
||||||
name_embedding=row.get("name_embedding") or [],
|
name_embedding=row.get("name_embedding") or [],
|
||||||
fact_summary=row.get("fact_summary") or "",
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# fact_summary=row.get("fact_summary") or "",
|
||||||
connect_strength=row.get("connect_strength") or "",
|
connect_strength=row.get("connect_strength") or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1085,7 +1085,8 @@ class ExtractionOrchestrator:
|
|||||||
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
||||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||||
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||||
name_embedding=getattr(entity, 'name_embedding', None),
|
name_embedding=getattr(entity, 'name_embedding', None),
|
||||||
|
|||||||
@@ -8,4 +8,5 @@
|
|||||||
- TemporalExtractor: 时间信息提取
|
- TemporalExtractor: 时间信息提取
|
||||||
- EmbeddingGenerator: 嵌入向量生成
|
- EmbeddingGenerator: 嵌入向量生成
|
||||||
- MemorySummaryGenerator: 记忆摘要生成
|
- MemorySummaryGenerator: 记忆摘要生成
|
||||||
|
- OntologyExtractor: 本体类提取
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -14,6 +14,34 @@ from pydantic import Field
|
|||||||
|
|
||||||
logger = get_memory_logger(__name__)
|
logger = get_memory_logger(__name__)
|
||||||
|
|
||||||
|
# Supported language codes and the fallback used when validation fails.
SUPPORTED_LANGUAGES = {"zh", "en"}
FALLBACK_LANGUAGE = "en"


def validate_language(language: Optional[str]) -> str:
    """Normalize a language setting to one of the supported language codes.

    The value is converted with ``str()``, lower-cased and stripped. A
    locale-style tag such as ``"zh-CN"`` or ``"en_US"`` is reduced to its
    primary subtag, so a regional variant of a supported language is
    accepted instead of falling back to the default.

    Args:
        language: Language code to validate; may be ``None``.

    Returns:
        A member of ``SUPPORTED_LANGUAGES``. ``FALLBACK_LANGUAGE`` is
        returned when ``language`` is ``None`` or cannot be mapped to a
        supported code (a warning is logged in the latter case).
    """
    if language is None:
        return FALLBACK_LANGUAGE

    lang = str(language).lower().strip()
    if lang in SUPPORTED_LANGUAGES:
        return lang

    # Accept regional variants ("zh-cn", "en_us") through their primary subtag.
    primary_subtag = lang.replace("_", "-").split("-", 1)[0]
    if primary_subtag in SUPPORTED_LANGUAGES:
        return primary_subtag

    logger.warning(
        f"无效的语言参数 '{language}',已回退到默认值 '{FALLBACK_LANGUAGE}'。"
        f"支持的语言: {SUPPORTED_LANGUAGES}"
    )
    return FALLBACK_LANGUAGE
|
||||||
|
|
||||||
|
|
||||||
class MemorySummaryResponse(RobustLLMResponse):
|
class MemorySummaryResponse(RobustLLMResponse):
|
||||||
"""Structured response for summary generation per chunk.
|
"""Structured response for summary generation per chunk.
|
||||||
@@ -31,7 +59,8 @@ class MemorySummaryResponse(RobustLLMResponse):
|
|||||||
|
|
||||||
async def generate_title_and_type_for_summary(
|
async def generate_title_and_type_for_summary(
|
||||||
content: str,
|
content: str,
|
||||||
llm_client
|
llm_client,
|
||||||
|
language: str = None
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
为MemorySummary生成标题和类型
|
为MemorySummary生成标题和类型
|
||||||
@@ -41,11 +70,18 @@ async def generate_title_and_type_for_summary(
|
|||||||
Args:
|
Args:
|
||||||
content: Summary的内容文本
|
content: Summary的内容文本
|
||||||
llm_client: LLM客户端实例
|
llm_client: LLM客户端实例
|
||||||
|
language: 生成标题使用的语言 ("zh" 中文, "en" 英文),如果为None则从配置读取
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(标题, 类型)元组
|
(标题, 类型)元组
|
||||||
"""
|
"""
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
|
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
# 如果没有指定语言,从配置中读取,并校验有效性
|
||||||
|
if language is None:
|
||||||
|
language = settings.DEFAULT_LANGUAGE
|
||||||
|
language = validate_language(language)
|
||||||
|
|
||||||
# 定义有效的类型集合
|
# 定义有效的类型集合
|
||||||
VALID_TYPES = {
|
VALID_TYPES = {
|
||||||
@@ -57,13 +93,19 @@ async def generate_title_and_type_for_summary(
|
|||||||
}
|
}
|
||||||
DEFAULT_TYPE = "conversation" # 默认类型
|
DEFAULT_TYPE = "conversation" # 默认类型
|
||||||
|
|
||||||
|
# 根据语言设置默认标题
|
||||||
|
DEFAULT_TITLE = "空内容" if language == "zh" else "Empty Content"
|
||||||
|
PARSE_ERROR_TITLE = "解析失败" if language == "zh" else "Parse Failed"
|
||||||
|
ERROR_TITLE = "错误" if language == "zh" else "Error"
|
||||||
|
UNKNOWN_TITLE = "未知标题" if language == "zh" else "Unknown Title"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not content:
|
if not content:
|
||||||
logger.warning("content为空,无法生成标题和类型")
|
logger.warning(f"content为空,无法生成标题和类型 (language={language})")
|
||||||
return ("空内容", DEFAULT_TYPE)
|
return (DEFAULT_TITLE, DEFAULT_TYPE)
|
||||||
|
|
||||||
# 1. 渲染Jinja2提示词模板
|
# 1. 渲染Jinja2提示词模板,传递语言参数
|
||||||
prompt = await render_episodic_title_and_type_prompt(content)
|
prompt = await render_episodic_title_and_type_prompt(content, language=language)
|
||||||
|
|
||||||
# 2. 调用LLM生成标题和类型
|
# 2. 调用LLM生成标题和类型
|
||||||
messages = [
|
messages = [
|
||||||
@@ -102,7 +144,7 @@ async def generate_title_and_type_for_summary(
|
|||||||
json_str = json_str.strip()
|
json_str = json_str.strip()
|
||||||
|
|
||||||
result_data = json.loads(json_str)
|
result_data = json.loads(json_str)
|
||||||
title = result_data.get("title", "未知标题")
|
title = result_data.get("title", UNKNOWN_TITLE)
|
||||||
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
|
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
|
||||||
|
|
||||||
# 5. 校验和归一化类型
|
# 5. 校验和归一化类型
|
||||||
@@ -130,16 +172,16 @@ async def generate_title_and_type_for_summary(
|
|||||||
f"已归一化为 '{episodic_type}'"
|
f"已归一化为 '{episodic_type}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
|
logger.info(f"成功生成标题和类型 (language={language}): title={title}, type={episodic_type}")
|
||||||
return (title, episodic_type)
|
return (title, episodic_type)
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.error(f"无法解析LLM响应为JSON: {full_response}")
|
logger.error(f"无法解析LLM响应为JSON (language={language}): {full_response}")
|
||||||
return ("解析失败", DEFAULT_TYPE)
|
return (PARSE_ERROR_TITLE, DEFAULT_TYPE)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
|
logger.error(f"生成标题和类型时出错 (language={language}): {str(e)}", exc_info=True)
|
||||||
return ("错误", DEFAULT_TYPE)
|
return (ERROR_TITLE, DEFAULT_TYPE)
|
||||||
|
|
||||||
async def _process_chunk_summary(
|
async def _process_chunk_summary(
|
||||||
dialog: DialogData,
|
dialog: DialogData,
|
||||||
@@ -153,11 +195,16 @@ async def _process_chunk_summary(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 从配置中获取语言设置(只获取一次,复用),并校验有效性
|
||||||
|
from app.core.config import settings
|
||||||
|
language = validate_language(settings.DEFAULT_LANGUAGE)
|
||||||
|
|
||||||
# Render prompt via Jinja2 for a single chunk
|
# Render prompt via Jinja2 for a single chunk
|
||||||
prompt_content = await render_memory_summary_prompt(
|
prompt_content = await render_memory_summary_prompt(
|
||||||
chunk_texts=chunk.content,
|
chunk_texts=chunk.content,
|
||||||
json_schema=MemorySummaryResponse.model_json_schema(),
|
json_schema=MemorySummaryResponse.model_json_schema(),
|
||||||
max_words=200,
|
max_words=200,
|
||||||
|
language=language,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -178,9 +225,10 @@ async def _process_chunk_summary(
|
|||||||
try:
|
try:
|
||||||
title, episodic_type = await generate_title_and_type_for_summary(
|
title, episodic_type = await generate_title_and_type_for_summary(
|
||||||
content=summary_text,
|
content=summary_text,
|
||||||
llm_client=llm_client
|
llm_client=llm_client,
|
||||||
|
language=language
|
||||||
)
|
)
|
||||||
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
|
logger.info(f"Generated title and type for MemorySummary (language={language}): title={title}, type={episodic_type}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
|
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
|
||||||
# Continue without title and type
|
# Continue without title and type
|
||||||
|
|||||||
@@ -0,0 +1,482 @@
|
|||||||
|
"""Ontology class extraction from scenario descriptions using LLM.
|
||||||
|
|
||||||
|
This module provides the OntologyExtractor class for extracting ontology classes
|
||||||
|
from natural language scenario descriptions. It uses LLM-driven extraction combined
|
||||||
|
with two-layer validation (string validation + OWL semantic validation).
|
||||||
|
|
||||||
|
Classes:
|
||||||
|
OntologyExtractor: Extracts ontology classes from scenario descriptions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
|
from app.core.memory.models.ontology_models import (
|
||||||
|
OntologyClass,
|
||||||
|
OntologyExtractionResponse,
|
||||||
|
)
|
||||||
|
from app.core.memory.utils.validation.ontology_validator import OntologyValidator
|
||||||
|
from app.core.memory.utils.validation.owl_validator import OWLValidator
|
||||||
|
from app.core.memory.utils.prompt.prompt_utils import render_ontology_extraction_prompt
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyExtractor:
|
||||||
|
"""Extractor for ontology classes from scenario descriptions.
|
||||||
|
|
||||||
|
This extractor uses LLM to identify abstract classes and concepts from
|
||||||
|
natural language scenario descriptions, following OWL ontology engineering
|
||||||
|
standards. It performs two-layer validation:
|
||||||
|
1. String validation (naming conventions, reserved words, duplicates)
|
||||||
|
2. OWL semantic validation (consistency checking, circular inheritance)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
llm_client: OpenAI client for LLM calls
|
||||||
|
validator: String validator for class names and descriptions
|
||||||
|
owl_validator: OWL validator for semantic validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, llm_client: OpenAIClient):
    """Create an extractor bound to an LLM client, with fresh validators.

    Args:
        llm_client: OpenAIClient instance stored on the extractor and used
            for the extraction calls.
    """
    # The client is supplied by the caller; both validators are built here.
    self.llm_client = llm_client
    # First-layer validator: class-name / description string checks.
    self.validator = OntologyValidator()
    # Second-layer validator: OWL checks on the extracted classes.
    self.owl_validator = OWLValidator()

    logger.info("OntologyExtractor initialized")
|
||||||
|
|
||||||
|
async def extract_ontology_classes(
    self,
    scenario: str,
    domain: Optional[str] = None,
    max_classes: int = 15,
    min_classes: int = 5,
    enable_owl_validation: bool = True,
    llm_temperature: float = 0.3,
    llm_max_tokens: int = 2000,
    max_description_length: int = 500,
    timeout: Optional[float] = None,
) -> OntologyExtractionResponse:
    """Extract ontology classes from a scenario description.

    Orchestrates the whole pipeline:
    1. Call the LLM to extract ontology classes (optionally time-bounded).
    2. First-layer validation: string validation and cleaning.
    3. Second-layer validation: OWL validation (optional).
    4. Drop the classes named in OWL validation errors.
    5. Return the remaining classes.

    Args:
        scenario: Natural language scenario description.
        domain: Optional domain hint (e.g., "Healthcare", "Education").
        max_classes: Maximum number of classes to extract (default: 15).
        min_classes: Minimum number of classes expected after first-layer
            validation; falling below it only logs a warning (default: 5).
        enable_owl_validation: Whether to run OWL validation (default: True).
        llm_temperature: LLM temperature parameter (default: 0.3).
        llm_max_tokens: LLM max tokens parameter (default: 2000).
        max_description_length: Maximum description length (default: 500).
        timeout: Optional timeout in seconds for the LLM call
            (default: None, no timeout).

    Returns:
        OntologyExtractionResponse with the validated classes. An empty
        response (``classes=[]``, ``domain=domain or "Unknown"``) is
        returned when the time-bounded LLM call exceeds ``timeout`` or when
        any other exception occurs during extraction.

    Raises:
        ValueError: If ``scenario`` is empty or whitespace-only.
        asyncio.TimeoutError: Re-raised only when it escapes from a step
            other than the ``timeout``-bounded LLM wait; expiry of
            ``timeout`` itself yields an empty response instead.

    Examples:
        >>> extractor = OntologyExtractor(llm_client)
        >>> response = await extractor.extract_ontology_classes(
        ...     scenario="A hospital manages patient records...",
        ...     domain="Healthcare",
        ...     max_classes=10,
        ...     timeout=30.0
        ... )
        >>> len(response.classes)
        7
    """
    # perf_counter() is monotonic, so the reported durations cannot be
    # skewed by wall-clock adjustments the way time.time() deltas can.
    start_time = time.perf_counter()

    # Validate input
    if not scenario or not scenario.strip():
        logger.error("Scenario description is empty")
        raise ValueError("Scenario description cannot be empty")

    scenario = scenario.strip()

    logger.info(
        f"Starting ontology extraction - scenario_length={len(scenario)}, "
        f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, "
        f"timeout={timeout}"
    )

    try:
        # Step 1: Call LLM for extraction, bounded by `timeout` when given
        logger.info("Step 1: Calling LLM for ontology extraction")
        llm_start_time = time.perf_counter()

        # Build the call once so both paths use identical arguments.
        llm_call = self._call_llm_for_extraction(
            scenario=scenario,
            domain=domain,
            max_classes=max_classes,
            llm_temperature=llm_temperature,
            llm_max_tokens=llm_max_tokens,
        )

        if timeout is None:
            # No timeout specified, await directly
            response = await llm_call
        else:
            try:
                response = await asyncio.wait_for(llm_call, timeout=timeout)
            except asyncio.TimeoutError:
                llm_duration = time.perf_counter() - llm_start_time
                logger.error(
                    f"LLM extraction timed out after {timeout} seconds "
                    f"(actual duration: {llm_duration:.2f}s)"
                )
                # Return empty response on timeout
                return OntologyExtractionResponse(
                    classes=[],
                    domain=domain or "Unknown",
                )

        llm_duration = time.perf_counter() - llm_start_time
        logger.info(
            f"LLM returned {len(response.classes)} classes in {llm_duration:.2f}s"
        )

        # Step 2: First-layer validation (string validation and cleaning)
        logger.info("Step 2: Performing first-layer validation (string validation)")
        validation_start_time = time.perf_counter()

        response = self._validate_and_clean(
            response=response,
            max_description_length=max_description_length,
        )

        validation_duration = time.perf_counter() - validation_start_time
        logger.info(
            f"After first-layer validation: {len(response.classes)} classes remain "
            f"(validation took {validation_duration:.2f}s)"
        )

        # Falling below min_classes is reported but does not abort the run.
        if len(response.classes) < min_classes:
            logger.warning(
                f"Only {len(response.classes)} classes remain after validation, "
                f"which is below minimum of {min_classes}"
            )

        # Step 3: Second-layer validation (OWL semantic validation)
        if enable_owl_validation and response.classes:
            logger.info("Step 3: Performing second-layer validation (OWL validation)")
            owl_start_time = time.perf_counter()

            # The third element of the result is not used here.
            is_valid, errors, _world = self.owl_validator.validate_ontology_classes(
                classes=response.classes,
            )

            owl_duration = time.perf_counter() - owl_start_time

            if not is_valid:
                logger.warning(
                    f"OWL validation found {len(errors)} issues in {owl_duration:.2f}s: {errors}"
                )

                # Filter invalid classes based on errors
                response = self._filter_invalid_classes(
                    response=response,
                    errors=errors,
                )

                logger.info(
                    f"After second-layer validation: {len(response.classes)} classes remain"
                )
            else:
                logger.info(f"OWL validation passed successfully in {owl_duration:.2f}s")
        elif not enable_owl_validation:
            logger.info("Step 3: OWL validation disabled, skipping")
        else:
            logger.info("Step 3: No classes to validate, skipping OWL validation")

        total_duration = time.perf_counter() - start_time

        # Log extraction statistics
        logger.info(
            "Ontology extraction completed - "
            f"final_class_count={len(response.classes)}, "
            f"domain={response.domain}, "
            f"total_duration={total_duration:.2f}s, "
            f"llm_duration={llm_duration:.2f}s"
        )

        return response

    except asyncio.TimeoutError:
        # Reached only for timeout errors raised outside the bounded wait
        # above (that wait handles its own expiry); log and propagate.
        total_duration = time.perf_counter() - start_time
        logger.error(
            f"Ontology extraction timed out after {timeout} seconds "
            f"(total duration: {total_duration:.2f}s)",
            exc_info=True
        )
        raise
    except Exception as e:
        total_duration = time.perf_counter() - start_time
        logger.error(
            f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}",
            exc_info=True
        )
        # Return empty response on failure
        return OntologyExtractionResponse(
            classes=[],
            domain=domain or "Unknown",
        )
|
||||||
|
|
||||||
|
async def _call_llm_for_extraction(
    self,
    scenario: str,
    domain: Optional[str],
    max_classes: int,
    llm_temperature: float,
    llm_max_tokens: int,
) -> OntologyExtractionResponse:
    """Render the extraction prompt and request a structured LLM response.

    The user prompt comes from ``render_ontology_extraction_prompt`` and is
    sent together with a fixed system message through
    ``self.llm_client.response_structured``.

    Args:
        scenario: Scenario description text.
        domain: Optional domain hint.
        max_classes: Maximum number of classes to extract.
        llm_temperature: LLM temperature parameter.
        llm_max_tokens: LLM max tokens parameter.

    Returns:
        The OntologyExtractionResponse produced by the LLM client.

    Raises:
        Exception: Any error from prompt rendering or the LLM call is
            logged and re-raised unchanged.
    """
    # Fixed system instruction; only the user message varies per call.
    system_message = {
        "role": "system",
        "content": (
            "You are an expert ontology engineer specializing in knowledge "
            "representation and OWL standards. Extract ontology classes from "
            "scenario descriptions following the provided instructions. "
            "Return valid JSON conforming to the schema."
        ),
    }

    try:
        # Render prompt using template
        prompt_content = await render_ontology_extraction_prompt(
            scenario=scenario,
            domain=domain,
            max_classes=max_classes,
            json_schema=OntologyExtractionResponse.model_json_schema(),
        )
        logger.debug(f"Rendered prompt length: {len(prompt_content)}")

        user_message = {"role": "user", "content": prompt_content}
        messages = [system_message, user_message]

        # NOTE(review): llm_temperature and llm_max_tokens are only logged
        # here; they are not passed to response_structured() below.
        logger.debug(
            f"Calling LLM with temperature={llm_temperature}, "
            f"max_tokens={llm_max_tokens}"
        )

        # Call LLM with structured output
        response = await self.llm_client.response_structured(
            messages=messages,
            response_model=OntologyExtractionResponse,
        )
    except Exception as e:
        logger.error(
            f"LLM extraction failed: {str(e)}",
            exc_info=True
        )
        raise
    else:
        logger.info(
            f"LLM extraction successful - extracted {len(response.classes)} classes"
        )
        return response
|
||||||
|
|
||||||
|
def _validate_and_clean(
    self,
    response: OntologyExtractionResponse,
    max_description_length: int,
) -> OntologyExtractionResponse:
    """Perform first-layer validation: string validation and cleaning.

    For every extracted class:
    1. Validate the class name with ``self.validator``.
    2. If invalid, sanitize it and re-validate; the class is renamed only
       when the sanitized name passes, otherwise it is skipped unchanged.
    3. Truncate descriptions longer than ``max_description_length``.

    Finally duplicates are removed with ``self.validator.remove_duplicates``.

    Note that accepted classes are updated in place (name/description).

    Args:
        response: OntologyExtractionResponse from LLM.
        max_description_length: Maximum description length.

    Returns:
        ``response`` itself when it has no classes; otherwise a new
        OntologyExtractionResponse holding the cleaned classes and the
        original domain.
    """
    if not response.classes:
        logger.debug("No classes to validate")
        return response

    logger.debug(f"Validating {len(response.classes)} classes")

    validated_classes = []

    for ontology_class in response.classes:
        # Keep the LLM-provided name so every log line can refer to it,
        # even after a sanitized replacement has been computed.
        original_name = ontology_class.name

        is_valid, error_msg = self.validator.validate_class_name(original_name)

        if not is_valid:
            logger.warning(
                f"Invalid class name '{original_name}': {error_msg}"
            )

            # Attempt to sanitize
            sanitized_name = self.validator.sanitize_class_name(original_name)

            logger.info(
                f"Sanitized class name: '{original_name}' -> '{sanitized_name}'"
            )

            # Re-validate before touching the class, so a class that is
            # going to be skipped is left exactly as it was received.
            is_valid, error_msg = self.validator.validate_class_name(
                sanitized_name
            )

            if not is_valid:
                logger.error(
                    f"Failed to sanitize class name '{original_name}' "
                    f"(sanitized to '{sanitized_name}'): {error_msg}. "
                    "Skipping this class."
                )
                continue

            ontology_class.name = sanitized_name

        # Truncate description if too long
        if ontology_class.description:
            original_length = len(ontology_class.description)
            ontology_class.description = self.validator.truncate_description(
                ontology_class.description,
                max_length=max_description_length,
            )

            if len(ontology_class.description) < original_length:
                logger.debug(
                    f"Truncated description for '{ontology_class.name}': "
                    f"{original_length} -> {len(ontology_class.description)} chars"
                )

        validated_classes.append(ontology_class)

    # Remove duplicates
    original_count = len(validated_classes)
    validated_classes = self.validator.remove_duplicates(validated_classes)

    if len(validated_classes) < original_count:
        logger.info(
            f"Removed {original_count - len(validated_classes)} duplicate classes"
        )

    # Return cleaned response
    return OntologyExtractionResponse(
        classes=validated_classes,
        domain=response.domain,
    )
|
||||||
|
|
||||||
|
def _filter_invalid_classes(
    self,
    response: OntologyExtractionResponse,
    errors: List[str],
) -> OntologyExtractionResponse:
    """Drop the classes that OWL validation errors refer to by name.

    A class is considered invalid when its name occurs in an error message
    as a whole identifier, i.e. not embedded in a longer word. This keeps a
    class such as ``Person`` from being removed because an error mentions
    ``PersonRecord``.

    Args:
        response: OntologyExtractionResponse to filter.
        errors: List of error messages from OWL validation.

    Returns:
        ``response`` itself when ``errors`` is empty or no class name is
        mentioned; otherwise a new OntologyExtractionResponse without the
        mentioned classes, keeping the original domain.
    """
    import re  # function-scope import: only this method needs it

    if not errors:
        return response

    logger.debug(f"Filtering classes based on {len(errors)} OWL validation errors")

    # One compiled whole-identifier pattern per distinct class name. Empty
    # names are skipped because an empty pattern would match every message.
    name_patterns = {
        ontology_class.name: re.compile(
            rf"(?<!\w){re.escape(ontology_class.name)}(?!\w)"
        )
        for ontology_class in response.classes
        if ontology_class.name
    }

    # Collect the class names mentioned in any error message
    invalid_class_names = set()

    for error in errors:
        for class_name, pattern in name_patterns.items():
            if pattern.search(error):
                invalid_class_names.add(class_name)
                logger.debug(
                    f"Class '{class_name}' marked as invalid due to error: {error}"
                )

    if not invalid_class_names:
        return response

    # Filter out invalid classes
    original_count = len(response.classes)
    filtered_classes = [
        c for c in response.classes
        if c.name not in invalid_class_names
    ]

    logger.info(
        f"Filtered out {original_count - len(filtered_classes)} invalid classes: "
        f"{invalid_class_names}"
    )

    return OntologyExtractionResponse(
        classes=filtered_classes,
        domain=response.domain,
    )
|
||||||
@@ -25,6 +25,15 @@ class TripletExtractor:
|
|||||||
"""
|
"""
|
||||||
self.llm_client = llm_client
|
self.llm_client = llm_client
|
||||||
|
|
||||||
|
def _get_language(self) -> str:
    """Return the language configured for entity descriptions.

    Returns:
        The value of ``settings.DEFAULT_LANGUAGE``, passed through as-is.
        NOTE(review): presumably "zh" or "en", but this method performs no
        validation of the configured value - confirm against the settings.
    """
    # Function-scope import of the application settings object.
    from app.core.config import settings as app_settings

    configured_language = app_settings.DEFAULT_LANGUAGE
    return configured_language
|
||||||
|
|
||||||
async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse:
|
async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse:
|
||||||
"""Process a single statement and return extracted triplets and entities"""
|
"""Process a single statement and return extracted triplets and entities"""
|
||||||
# Render the prompt using helper function
|
# Render the prompt using helper function
|
||||||
@@ -40,7 +49,8 @@ class TripletExtractor:
|
|||||||
statement=statement.statement,
|
statement=statement.statement,
|
||||||
chunk_content=chunk_content,
|
chunk_content=chunk_content,
|
||||||
json_schema=TripletExtractionResponse.model_json_schema(),
|
json_schema=TripletExtractionResponse.model_json_schema(),
|
||||||
predicate_instructions=PREDICATE_DEFINITIONS
|
predicate_instructions=PREDICATE_DEFINITIONS,
|
||||||
|
language=self._get_language()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create messages for LLM
|
# Create messages for LLM
|
||||||
|
|||||||
@@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li
|
|||||||
key=lambda eid: (
|
key=lambda eid: (
|
||||||
_strength_rank(eid),
|
_strength_rank(eid),
|
||||||
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
|
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
|
||||||
len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
# len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
||||||
|
0 # 临时占位
|
||||||
),
|
),
|
||||||
reverse=True
|
reverse=True
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -177,7 +177,7 @@ def render_entity_dedup_prompt(
|
|||||||
|
|
||||||
# Args:
|
# Args:
|
||||||
# entity_a: Dict of entity A attributes
|
# entity_a: Dict of entity A attributes
|
||||||
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None) -> str:
|
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None, language: str = "zh") -> str:
|
||||||
"""
|
"""
|
||||||
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
|
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
|
||||||
|
|
||||||
@@ -186,6 +186,7 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
|||||||
chunk_content: The content of the chunk to process
|
chunk_content: The content of the chunk to process
|
||||||
json_schema: JSON schema for the expected output format
|
json_schema: JSON schema for the expected output format
|
||||||
predicate_instructions: Optional predicate instructions
|
predicate_instructions: Optional predicate instructions
|
||||||
|
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Rendered prompt content as string
|
Rendered prompt content as string
|
||||||
@@ -195,7 +196,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
|||||||
statement=statement,
|
statement=statement,
|
||||||
chunk_content=chunk_content,
|
chunk_content=chunk_content,
|
||||||
json_schema=json_schema,
|
json_schema=json_schema,
|
||||||
predicate_instructions=predicate_instructions
|
predicate_instructions=predicate_instructions,
|
||||||
|
language=language
|
||||||
)
|
)
|
||||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||||
log_prompt_rendering('triplet extraction', rendered_prompt)
|
log_prompt_rendering('triplet extraction', rendered_prompt)
|
||||||
@@ -204,7 +206,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
|||||||
'statement': 'str',
|
'statement': 'str',
|
||||||
'chunk_content': 'str',
|
'chunk_content': 'str',
|
||||||
'json_schema': 'TripletExtractionResponse.schema',
|
'json_schema': 'TripletExtractionResponse.schema',
|
||||||
'predicate_instructions': 'PREDICATE_DEFINITIONS'
|
'predicate_instructions': 'PREDICATE_DEFINITIONS',
|
||||||
|
'language': language
|
||||||
})
|
})
|
||||||
|
|
||||||
return rendered_prompt
|
return rendered_prompt
|
||||||
@@ -213,6 +216,7 @@ async def render_memory_summary_prompt(
|
|||||||
chunk_texts: str,
|
chunk_texts: str,
|
||||||
json_schema: dict,
|
json_schema: dict,
|
||||||
max_words: int = 200,
|
max_words: int = 200,
|
||||||
|
language: str = "zh",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Renders the memory summary prompt using the memory_summary.jinja2 template.
|
Renders the memory summary prompt using the memory_summary.jinja2 template.
|
||||||
@@ -221,6 +225,7 @@ async def render_memory_summary_prompt(
|
|||||||
chunk_texts: Concatenated text of conversation chunks
|
chunk_texts: Concatenated text of conversation chunks
|
||||||
json_schema: JSON schema for the expected output format
|
json_schema: JSON schema for the expected output format
|
||||||
max_words: Maximum words for the summary
|
max_words: Maximum words for the summary
|
||||||
|
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Rendered prompt content as string.
|
Rendered prompt content as string.
|
||||||
@@ -230,12 +235,14 @@ async def render_memory_summary_prompt(
|
|||||||
chunk_texts=chunk_texts,
|
chunk_texts=chunk_texts,
|
||||||
json_schema=json_schema,
|
json_schema=json_schema,
|
||||||
max_words=max_words,
|
max_words=max_words,
|
||||||
|
language=language,
|
||||||
)
|
)
|
||||||
log_prompt_rendering('memory summary', rendered_prompt)
|
log_prompt_rendering('memory summary', rendered_prompt)
|
||||||
log_template_rendering('memory_summary.jinja2', {
|
log_template_rendering('memory_summary.jinja2', {
|
||||||
'chunk_texts_len': len(chunk_texts or ""),
|
'chunk_texts_len': len(chunk_texts or ""),
|
||||||
'max_words': max_words,
|
'max_words': max_words,
|
||||||
'json_schema': 'MemorySummaryResponse.schema'
|
'json_schema': 'MemorySummaryResponse.schema',
|
||||||
|
'language': language
|
||||||
})
|
})
|
||||||
return rendered_prompt
|
return rendered_prompt
|
||||||
|
|
||||||
@@ -388,24 +395,65 @@ async def render_memory_insight_prompt(
|
|||||||
return rendered_prompt
|
return rendered_prompt
|
||||||
|
|
||||||
|
|
||||||
async def render_episodic_title_and_type_prompt(content: str) -> str:
|
async def render_episodic_title_and_type_prompt(content: str, language: str = "zh") -> str:
|
||||||
"""
|
"""
|
||||||
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
|
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: The content of the episodic memory summary to analyze
|
content: The content of the episodic memory summary to analyze
|
||||||
|
language: The language to use for title generation ("zh" for Chinese, "en" for English)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Rendered prompt content as string
|
Rendered prompt content as string
|
||||||
"""
|
"""
|
||||||
template = prompt_env.get_template("episodic_type_classification.jinja2")
|
template = prompt_env.get_template("episodic_type_classification.jinja2")
|
||||||
rendered_prompt = template.render(content=content)
|
rendered_prompt = template.render(content=content, language=language)
|
||||||
|
|
||||||
# 记录渲染结果到提示日志
|
# 记录渲染结果到提示日志
|
||||||
log_prompt_rendering('episodic title and type classification', rendered_prompt)
|
log_prompt_rendering('episodic title and type classification', rendered_prompt)
|
||||||
# 可选:记录模板渲染信息
|
# 可选:记录模板渲染信息
|
||||||
log_template_rendering('episodic_type_classification.jinja2', {
|
log_template_rendering('episodic_type_classification.jinja2', {
|
||||||
'content_len': len(content) if content else 0
|
'content_len': len(content) if content else 0,
|
||||||
|
'language': language
|
||||||
|
})
|
||||||
|
|
||||||
|
return rendered_prompt
|
||||||
|
|
||||||
|
|
||||||
|
async def render_ontology_extraction_prompt(
|
||||||
|
scenario: str,
|
||||||
|
domain: str | None = None,
|
||||||
|
max_classes: int = 15,
|
||||||
|
json_schema: dict | None = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Renders the ontology extraction prompt using the extract_ontology.jinja2 template.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scenario: The scenario description text to extract ontology classes from
|
||||||
|
domain: Optional domain hint for the scenario (e.g., "Healthcare", "Education")
|
||||||
|
max_classes: Maximum number of classes to extract (default: 15)
|
||||||
|
json_schema: JSON schema for the expected output format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rendered prompt content as string
|
||||||
|
"""
|
||||||
|
template = prompt_env.get_template("extract_ontology.jinja2")
|
||||||
|
rendered_prompt = template.render(
|
||||||
|
scenario=scenario,
|
||||||
|
domain=domain,
|
||||||
|
max_classes=max_classes,
|
||||||
|
json_schema=json_schema
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录渲染结果到提示日志
|
||||||
|
log_prompt_rendering('ontology extraction', rendered_prompt)
|
||||||
|
# 可选:记录模板渲染信息
|
||||||
|
log_template_rendering('extract_ontology.jinja2', {
|
||||||
|
'scenario_len': len(scenario) if scenario else 0,
|
||||||
|
'domain': domain,
|
||||||
|
'max_classes': max_classes,
|
||||||
|
'json_schema': 'OntologyExtractionResponse.schema'
|
||||||
})
|
})
|
||||||
|
|
||||||
return rendered_prompt
|
return rendered_prompt
|
||||||
|
|||||||
@@ -9,7 +9,8 @@
|
|||||||
- 类型: "{{ entity_a.entity_type | default('') }}"
|
- 类型: "{{ entity_a.entity_type | default('') }}"
|
||||||
- 描述: "{{ entity_a.description | default('') }}"
|
- 描述: "{{ entity_a.description | default('') }}"
|
||||||
- 别名: {{ entity_a.aliases | default([]) }}
|
- 别名: {{ entity_a.aliases | default([]) }}
|
||||||
- 摘要: "{{ entity_a.fact_summary | default('') }}"
|
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
|
||||||
|
{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #}
|
||||||
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
|
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
|
||||||
|
|
||||||
实体B:
|
实体B:
|
||||||
@@ -17,7 +18,8 @@
|
|||||||
- 类型: "{{ entity_b.entity_type | default('') }}"
|
- 类型: "{{ entity_b.entity_type | default('') }}"
|
||||||
- 描述: "{{ entity_b.description | default('') }}"
|
- 描述: "{{ entity_b.description | default('') }}"
|
||||||
- 别名: {{ entity_b.aliases | default([]) }}
|
- 别名: {{ entity_b.aliases | default([]) }}
|
||||||
- 摘要: "{{ entity_b.fact_summary | default('') }}"
|
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
|
||||||
|
{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #}
|
||||||
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
|
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
|
||||||
|
|
||||||
上下文:
|
上下文:
|
||||||
|
|||||||
@@ -1,8 +1,19 @@
|
|||||||
=== Task ===
|
=== Task ===
|
||||||
Generate a concise title and classify the episodic memory into the most appropriate category.
|
Generate a concise title and classify the episodic memory into the most appropriate category.
|
||||||
|
|
||||||
|
{% if language == "zh" %}
|
||||||
|
**重要:请使用中文生成标题和分类。**
|
||||||
|
{% else %}
|
||||||
|
**Important: Please generate the title and classification in English.**
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
=== Requirements ===
|
=== Requirements ===
|
||||||
- Extract a clear, concise title (10-20 characters) that captures the core content
|
- Extract a clear, concise title (10-20 characters) that captures the core content
|
||||||
|
{% if language == "zh" %}
|
||||||
|
- 标题必须使用中文
|
||||||
|
{% else %}
|
||||||
|
- Title must be in English
|
||||||
|
{% endif %}
|
||||||
- Classify into exactly one category based on the primary theme
|
- Classify into exactly one category based on the primary theme
|
||||||
- Be specific and avoid ambiguity
|
- Be specific and avoid ambiguity
|
||||||
- Output must be valid JSON conforming to the schema below
|
- Output must be valid JSON conforming to the schema below
|
||||||
|
|||||||
210
api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2
Normal file
210
api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
===Task===
|
||||||
|
Extract ontology classes from the given scenario description following ontology engineering standards.
|
||||||
|
|
||||||
|
===Role===
|
||||||
|
You are a professional ontology engineer with expertise in knowledge representation and OWL (Web Ontology Language) standards. Your task is to identify abstract classes and concepts from scenario descriptions, not concrete instances.
|
||||||
|
|
||||||
|
===Scenario Description===
|
||||||
|
{{ scenario }}
|
||||||
|
|
||||||
|
{% if domain -%}
|
||||||
|
===Domain Hint===
|
||||||
|
This scenario belongs to the **{{ domain }}** domain. Consider domain-specific concepts and terminology when extracting classes.
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
===Extraction Rules===
|
||||||
|
|
||||||
|
**1. Abstract Classes, Not Instances:**
|
||||||
|
- Extract abstract categories and concepts (e.g., "MedicalProcedure", "Patient", "Diagnosis")
|
||||||
|
- Do NOT extract concrete instances (e.g., "John Smith", "Room 301", "2024-01-15")
|
||||||
|
- Think in terms of "types of things" rather than "specific things"
|
||||||
|
|
||||||
|
**2. Naming Convention (PascalCase):**
|
||||||
|
- Use PascalCase format for the "name" field: start with uppercase letter, capitalize each word, no spaces
|
||||||
|
- Examples: "MedicalProcedure", "HealthcareProvider", "DiagnosticTest"
|
||||||
|
- Avoid: "medical procedure", "healthcare_provider", "diagnostic-test"
|
||||||
|
- Use clear, descriptive names in English
|
||||||
|
- Avoid abbreviations unless they are standard in the domain (e.g., "API", "DNA")
|
||||||
|
- Provide Chinese translation in the "name_chinese" field (e.g., "医疗程序", "医疗服务提供者", "诊断测试")
|
||||||
|
|
||||||
|
**3. Domain Relevance:**
|
||||||
|
- Focus on classes that are central to the scenario's domain
|
||||||
|
- Prioritize classes that represent key concepts, entities, or relationships
|
||||||
|
- Avoid overly generic classes (e.g., "Thing", "Object") unless they have specific domain meaning
|
||||||
|
|
||||||
|
**4. Class Quantity:**
|
||||||
|
- Extract between 5 and {{ max_classes }} classes
|
||||||
|
- Aim for a balanced set covering the main concepts in the scenario
|
||||||
|
- Quality over quantity: prefer well-defined classes over exhaustive lists
|
||||||
|
|
||||||
|
**5. Clear Descriptions:**
|
||||||
|
- Provide concise, informative descriptions in Chinese (max 500 characters)
|
||||||
|
- Describe what the class represents, not specific instances
|
||||||
|
- Use clear, natural Chinese language that explains the class's role in the domain
|
||||||
|
|
||||||
|
**6. Concrete Examples:**
|
||||||
|
- Provide 2-5 concrete instance examples in Chinese for each class
|
||||||
|
- Examples should be specific, realistic instances of the class
|
||||||
|
- Examples help clarify the class's scope and meaning
|
||||||
|
- Use natural Chinese language for examples
|
||||||
|
- Example format: ["示例1", "示例2", "示例3"]
|
||||||
|
|
||||||
|
**7. Class Hierarchy:**
|
||||||
|
- Identify parent-child relationships where applicable
|
||||||
|
- Use the parent_class field to specify inheritance
|
||||||
|
- Parent class must be one of the extracted classes or a standard OWL class
|
||||||
|
- Leave parent_class as null for top-level classes
|
||||||
|
|
||||||
|
**8. Entity Types:**
|
||||||
|
- Classify each class with an appropriate entity_type
|
||||||
|
- Common types: "Person", "Organization", "Location", "Event", "Concept", "Process", "Object", "Role"
|
||||||
|
- Choose the most specific type that applies
|
||||||
|
|
||||||
|
**9. OWL Reserved Words:**
|
||||||
|
- Do NOT use OWL reserved words as class names
|
||||||
|
- Reserved words include: "Thing", "Nothing", "Class", "Property", "ObjectProperty", "DatatypeProperty", "AnnotationProperty", "Ontology", "Individual", "Literal"
|
||||||
|
- If a reserved word is needed, add a domain-specific prefix (e.g., "MedicalClass" instead of "Class")
|
||||||
|
|
||||||
|
**10. Language Consistency:**
|
||||||
|
- Extract all class names in English (PascalCase format) for the "name" field
|
||||||
|
- Provide Chinese translation for class names in the "name_chinese" field
|
||||||
|
- Descriptions MUST be in Chinese (中文)
|
||||||
|
- Examples MUST be in Chinese (中文)
|
||||||
|
- Use clear, natural Chinese language for descriptions and examples
|
||||||
|
|
||||||
|
===Examples===
|
||||||
|
|
||||||
|
**Example 1 (Healthcare Domain):**
|
||||||
|
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
|
||||||
|
|
||||||
|
Output:
|
||||||
|
{
|
||||||
|
"classes": [
|
||||||
|
{
|
||||||
|
"name": "Patient",
|
||||||
|
"name_chinese": "患者",
|
||||||
|
"description": "在医疗机构接受医疗护理或治疗的人",
|
||||||
|
"examples": ["张三", "李四", "患有糖尿病的老年患者"],
|
||||||
|
"parent_class": null,
|
||||||
|
"entity_type": "Person",
|
||||||
|
"domain": "Healthcare"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "MedicalProcedure",
|
||||||
|
"name_chinese": "医疗程序",
|
||||||
|
"description": "为医疗诊断或治疗而执行的系统性操作流程",
|
||||||
|
"examples": ["手术", "血液检查", "X光检查", "疫苗接种"],
|
||||||
|
"parent_class": null,
|
||||||
|
"entity_type": "Process",
|
||||||
|
"domain": "Healthcare"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Diagnosis",
|
||||||
|
"name_chinese": "诊断",
|
||||||
|
"description": "基于症状和检查结果对疾病或状况的识别",
|
||||||
|
"examples": ["糖尿病诊断", "癌症诊断", "流感诊断"],
|
||||||
|
"parent_class": null,
|
||||||
|
"entity_type": "Concept",
|
||||||
|
"domain": "Healthcare"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Doctor",
|
||||||
|
"name_chinese": "医生",
|
||||||
|
"description": "诊断和治疗患者的持证医疗专业人员",
|
||||||
|
"examples": ["全科医生", "外科医生", "心脏病专家"],
|
||||||
|
"parent_class": null,
|
||||||
|
"entity_type": "Role",
|
||||||
|
"domain": "Healthcare"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Treatment",
|
||||||
|
"name_chinese": "治疗",
|
||||||
|
"description": "为治愈或管理疾病状况而提供的医疗护理或疗法",
|
||||||
|
"examples": ["药物治疗", "物理治疗", "化疗", "手术治疗"],
|
||||||
|
"parent_class": null,
|
||||||
|
"entity_type": "Process",
|
||||||
|
"domain": "Healthcare"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"domain": "Healthcare",
|
||||||
|
"namespace": "http://example.org/healthcare#"
|
||||||
|
}
|
||||||
|
|
||||||
|
**Example 2 (Education Domain):**
|
||||||
|
Scenario: "A university offers courses taught by professors. Students enroll in programs, attend lectures, and complete assignments to earn degrees."
|
||||||
|
|
||||||
|
Output:
|
||||||
|
{
|
||||||
|
"classes": [
|
||||||
|
{
|
||||||
|
"name": "Student",
|
||||||
|
"name_chinese": "学生",
|
||||||
|
"description": "在教育机构注册学习的人",
|
||||||
|
"examples": ["本科生", "研究生", "在职学生"],
|
||||||
|
"parent_class": null,
|
||||||
|
"entity_type": "Role",
|
||||||
|
"domain": "Education"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Course",
|
||||||
|
"name_chinese": "课程",
|
||||||
|
"description": "涵盖特定学科或主题的结构化教育课程",
|
||||||
|
"examples": ["计算机科学导论", "微积分I", "世界历史"],
|
||||||
|
"parent_class": null,
|
||||||
|
"entity_type": "Concept",
|
||||||
|
"domain": "Education"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Professor",
|
||||||
|
"name_chinese": "教授",
|
||||||
|
"description": "教授课程并进行研究的学术教师",
|
||||||
|
"examples": ["助理教授", "副教授", "正教授"],
|
||||||
|
"parent_class": null,
|
||||||
|
"entity_type": "Role",
|
||||||
|
"domain": "Education"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "AcademicProgram",
|
||||||
|
"name_chinese": "学术项目",
|
||||||
|
"description": "通向学位或证书的结构化课程体系",
|
||||||
|
"examples": ["理学学士", "文学硕士", "博士项目"],
|
||||||
|
"parent_class": null,
|
||||||
|
"entity_type": "Concept",
|
||||||
|
"domain": "Education"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Assignment",
|
||||||
|
"name_chinese": "作业",
|
||||||
|
"description": "分配给学生以评估学习成果的任务或项目",
|
||||||
|
"examples": ["论文", "习题集", "研究报告", "实验报告"],
|
||||||
|
"parent_class": null,
|
||||||
|
"entity_type": "Object",
|
||||||
|
"domain": "Education"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Lecture",
|
||||||
|
"name_chinese": "讲座",
|
||||||
|
"description": "由教师进行的教育性演讲或讲座",
|
||||||
|
"examples": ["入门讲座", "客座讲座", "在线讲座"],
|
||||||
|
"parent_class": null,
|
||||||
|
"entity_type": "Event",
|
||||||
|
"domain": "Education"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"domain": "Education",
|
||||||
|
"namespace": "http://example.org/education#"
|
||||||
|
}
|
||||||
|
|
||||||
|
===Output Format===
|
||||||
|
|
||||||
|
**JSON Requirements:**
|
||||||
|
- Use only ASCII double quotes (") for JSON structure
|
||||||
|
- Never use Chinese quotation marks (“ ”) or other Unicode quotes
|
||||||
|
- Escape quotation marks in text with backslashes (\")
|
||||||
|
- Ensure proper string closure and comma separation
|
||||||
|
- No line breaks within JSON string values
|
||||||
|
- All class names must be in PascalCase format
|
||||||
|
- All class names must be unique (case-insensitive)
|
||||||
|
- Extract between 5 and {{ max_classes }} classes
|
||||||
|
|
||||||
|
{{ json_schema }}
|
||||||
@@ -5,6 +5,12 @@
|
|||||||
===Task===
|
===Task===
|
||||||
Extract entities and knowledge triplets from the given statement.
|
Extract entities and knowledge triplets from the given statement.
|
||||||
|
|
||||||
|
{% if language == "zh" %}
|
||||||
|
**重要:请使用中文生成实体描述(description)和示例(example)。**
|
||||||
|
{% else %}
|
||||||
|
**Important: Please generate entity descriptions and examples in English.**
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
===Inputs===
|
===Inputs===
|
||||||
**Chunk Content:** "{{ chunk_content }}"
|
**Chunk Content:** "{{ chunk_content }}"
|
||||||
**Statement:** "{{ statement }}"
|
**Statement:** "{{ statement }}"
|
||||||
@@ -13,6 +19,13 @@ Extract entities and knowledge triplets from the given statement.
|
|||||||
|
|
||||||
**Entity Extraction:**
|
**Entity Extraction:**
|
||||||
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
|
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
|
||||||
|
{% if language == "zh" %}
|
||||||
|
- **实体描述(description)必须使用中文**
|
||||||
|
- **示例(example)必须使用中文**
|
||||||
|
{% else %}
|
||||||
|
- **Entity descriptions must be in English**
|
||||||
|
- **Examples must be in English**
|
||||||
|
{% endif %}
|
||||||
- **Semantic Memory Classification (is_explicit_memory):**
|
- **Semantic Memory Classification (is_explicit_memory):**
|
||||||
* Set to `true` if the entity represents **explicit/semantic memory**:
|
* Set to `true` if the entity represents **explicit/semantic memory**:
|
||||||
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
|
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
|
||||||
@@ -334,9 +347,11 @@ Output:
|
|||||||
- Escape quotation marks in text with backslashes (\")
|
- Escape quotation marks in text with backslashes (\")
|
||||||
- Ensure proper string closure and comma separation
|
- Ensure proper string closure and comma separation
|
||||||
- No line breaks within JSON string values
|
- No line breaks within JSON string values
|
||||||
- The output language should ALWAYS match the input language
|
{% if language == "zh" %}
|
||||||
- If input is in English, extract statements in English
|
- **语言要求:实体描述(description)和示例(example)必须使用中文**
|
||||||
- If input is in Chinese, extract statements in Chinese
|
{% else %}
|
||||||
|
- **Language Requirement: Entity descriptions and examples must be in English**
|
||||||
|
{% endif %}
|
||||||
- Preserve the original language and do not translate
|
- Preserve the original language and do not translate
|
||||||
|
|
||||||
{{ json_schema }}
|
{{ json_schema }}
|
||||||
@@ -5,10 +5,21 @@
|
|||||||
=== Task ===
|
=== Task ===
|
||||||
Summarize the provided conversation chunks into a concise Memory summary.
|
Summarize the provided conversation chunks into a concise Memory summary.
|
||||||
|
|
||||||
|
{% if language == "zh" %}
|
||||||
|
**重要:请使用中文生成摘要内容。**
|
||||||
|
{% else %}
|
||||||
|
**Important: Please generate the summary content in English.**
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
=== Requirements ===
|
=== Requirements ===
|
||||||
- Focus on factual statements, user preferences, relationships, and salient temporal context.
|
- Focus on factual statements, user preferences, relationships, and salient temporal context.
|
||||||
- Avoid repetition and filler; be specific.
|
- Avoid repetition and filler; be specific.
|
||||||
- Keep it under {{ max_words or 200 }} words.
|
- Keep it under {{ max_words or 200 }} words.
|
||||||
|
{% if language == "zh" %}
|
||||||
|
- 摘要内容必须使用中文
|
||||||
|
{% else %}
|
||||||
|
- Summary content must be in English
|
||||||
|
{% endif %}
|
||||||
- Output must be valid JSON conforming to the schema below.
|
- Output must be valid JSON conforming to the schema below.
|
||||||
|
|
||||||
=== Input ===
|
=== Input ===
|
||||||
@@ -24,6 +35,11 @@ Summarize the provided conversation chunks into a concise Memory summary.
|
|||||||
4. Do not include line breaks within JSON string values
|
4. Do not include line breaks within JSON string values
|
||||||
5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\""
|
5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\""
|
||||||
|
|
||||||
The output language should always be the same as the input language.
|
{% if language == "zh" %}
|
||||||
|
**语言要求:输出内容必须使用中文。**
|
||||||
|
{% else %}
|
||||||
|
**Language Requirement: The output content must be in English.**
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
|
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
|
||||||
{{ json_schema }}
|
{{ json_schema }}
|
||||||
10
api/app/core/memory/utils/validation/__init__.py
Normal file
10
api/app/core/memory/utils/validation/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""Validation utilities for ontology extraction.
|
||||||
|
|
||||||
|
This module provides validation classes for ontology class names,
|
||||||
|
descriptions, and OWL compliance checking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .ontology_validator import OntologyValidator
|
||||||
|
from .owl_validator import OWLValidator
|
||||||
|
|
||||||
|
__all__ = ['OntologyValidator', 'OWLValidator']
|
||||||
268
api/app/core/memory/utils/validation/ontology_validator.py
Normal file
268
api/app/core/memory/utils/validation/ontology_validator.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
"""String validation for ontology class names and descriptions.
|
||||||
|
|
||||||
|
This module provides the OntologyValidator class for validating and sanitizing
|
||||||
|
ontology class names according to OWL standards and naming conventions.
|
||||||
|
|
||||||
|
Classes:
|
||||||
|
OntologyValidator: Validates class names, removes duplicates, and truncates descriptions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from app.core.memory.models.ontology_models import OntologyClass
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyValidator:
|
||||||
|
"""Validator for ontology class names and descriptions.
|
||||||
|
|
||||||
|
This validator performs string-level validation including:
|
||||||
|
- PascalCase naming convention validation
|
||||||
|
- OWL reserved word checking
|
||||||
|
- Duplicate class name removal
|
||||||
|
- Description length truncation
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
OWL_RESERVED_WORDS: Set of OWL reserved words that cannot be used as class names
|
||||||
|
"""
|
||||||
|
|
||||||
|
# OWL reserved words that cannot be used as class names
|
||||||
|
OWL_RESERVED_WORDS = {
|
||||||
|
'Thing', 'Nothing', 'Class', 'Property',
|
||||||
|
'ObjectProperty', 'DatatypeProperty', 'FunctionalProperty',
|
||||||
|
'InverseFunctionalProperty', 'TransitiveProperty', 'SymmetricProperty',
|
||||||
|
'AsymmetricProperty', 'ReflexiveProperty', 'IrreflexiveProperty',
|
||||||
|
'Restriction', 'Ontology', 'Individual', 'NamedIndividual',
|
||||||
|
'Annotation', 'AnnotationProperty', 'Axiom',
|
||||||
|
'AllDifferent', 'AllDisjointClasses', 'AllDisjointProperties',
|
||||||
|
'Datatype', 'DataRange', 'Literal',
|
||||||
|
'DeprecatedClass', 'DeprecatedProperty',
|
||||||
|
'Imports', 'IncompatibleWith', 'PriorVersion', 'VersionInfo',
|
||||||
|
'BackwardCompatibleWith', 'OntologyProperty',
|
||||||
|
}
|
||||||
|
|
||||||
|
def validate_class_name(self, name: str) -> Tuple[bool, str]:
|
||||||
|
"""Validate that a class name follows OWL naming conventions.
|
||||||
|
|
||||||
|
Validation rules:
|
||||||
|
1. Must not be empty
|
||||||
|
2. Must start with an uppercase letter (PascalCase)
|
||||||
|
3. Cannot contain spaces
|
||||||
|
4. Can only contain alphanumeric characters and underscores
|
||||||
|
5. Cannot be an OWL reserved word
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The class name to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message)
|
||||||
|
- is_valid: True if the name is valid, False otherwise
|
||||||
|
- error_message: Empty string if valid, error description if invalid
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> validator = OntologyValidator()
|
||||||
|
>>> validator.validate_class_name("MedicalProcedure")
|
||||||
|
(True, "")
|
||||||
|
>>> validator.validate_class_name("medical procedure")
|
||||||
|
(False, "Class name 'medical procedure' cannot contain spaces")
|
||||||
|
>>> validator.validate_class_name("Thing")
|
||||||
|
(False, "Class name 'Thing' is an OWL reserved word")
|
||||||
|
"""
|
||||||
|
logger.debug(f"Validating class name: '{name}'")
|
||||||
|
|
||||||
|
# Check if empty
|
||||||
|
if not name or not name.strip():
|
||||||
|
error_msg = "Class name cannot be empty"
|
||||||
|
logger.warning(f"Validation failed: {error_msg}")
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
name = name.strip()
|
||||||
|
|
||||||
|
# Check if it's an OWL reserved word
|
||||||
|
if name in self.OWL_RESERVED_WORDS:
|
||||||
|
error_msg = f"Class name '{name}' is an OWL reserved word"
|
||||||
|
logger.warning(f"Validation failed: {error_msg}")
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
# Check if starts with uppercase letter
|
||||||
|
if not name[0].isupper():
|
||||||
|
error_msg = f"Class name '{name}' must start with an uppercase letter (PascalCase)"
|
||||||
|
logger.warning(f"Validation failed: {error_msg}")
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
# Check for spaces
|
||||||
|
if ' ' in name:
|
||||||
|
error_msg = f"Class name '{name}' cannot contain spaces"
|
||||||
|
logger.warning(f"Validation failed: {error_msg}")
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
# Check for invalid characters (only alphanumeric and underscore allowed)
|
||||||
|
if not re.match(r'^[A-Za-z0-9_]+$', name):
|
||||||
|
error_msg = f"Class name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed"
|
||||||
|
logger.warning(f"Validation failed: {error_msg}")
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
logger.debug(f"Class name '{name}' is valid")
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
def sanitize_class_name(self, name: str) -> str:
|
||||||
|
"""Attempt to sanitize an invalid class name into a valid format.
|
||||||
|
|
||||||
|
Sanitization steps:
|
||||||
|
1. Strip whitespace
|
||||||
|
2. Remove invalid characters
|
||||||
|
3. Replace spaces with empty string (PascalCase)
|
||||||
|
4. Capitalize first letter of each word
|
||||||
|
5. If result is empty or starts with number, prefix with 'Class'
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The class name to sanitize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized class name that should pass validation
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> validator = OntologyValidator()
|
||||||
|
>>> validator.sanitize_class_name("medical procedure")
|
||||||
|
'MedicalProcedure'
|
||||||
|
>>> validator.sanitize_class_name("patient-record")
|
||||||
|
'PatientRecord'
|
||||||
|
>>> validator.sanitize_class_name("123invalid")
|
||||||
|
'Class123Invalid'
|
||||||
|
"""
|
||||||
|
logger.debug(f"Sanitizing class name: '{name}'")
|
||||||
|
|
||||||
|
if not name or not name.strip():
|
||||||
|
logger.warning("Empty class name provided for sanitization, returning 'UnnamedClass'")
|
||||||
|
return "UnnamedClass"
|
||||||
|
|
||||||
|
# Strip whitespace
|
||||||
|
name = name.strip()
|
||||||
|
original_name = name
|
||||||
|
|
||||||
|
# Split on spaces, hyphens, and underscores, then capitalize each word
|
||||||
|
words = re.split(r'[\s\-_]+', name)
|
||||||
|
|
||||||
|
# Capitalize first letter of each word and keep rest as is
|
||||||
|
sanitized_words = []
|
||||||
|
for word in words:
|
||||||
|
if word:
|
||||||
|
# Remove non-alphanumeric characters except underscore
|
||||||
|
clean_word = re.sub(r'[^A-Za-z0-9_]', '', word)
|
||||||
|
if clean_word:
|
||||||
|
# Capitalize first letter
|
||||||
|
sanitized_words.append(clean_word[0].upper() + clean_word[1:])
|
||||||
|
|
||||||
|
# Join words
|
||||||
|
sanitized = ''.join(sanitized_words)
|
||||||
|
|
||||||
|
# If empty or starts with number, prefix with 'Class'
|
||||||
|
if not sanitized or sanitized[0].isdigit():
|
||||||
|
sanitized = 'Class' + sanitized
|
||||||
|
logger.info(f"Prefixed class name with 'Class': '{original_name}' -> '{sanitized}'")
|
||||||
|
|
||||||
|
# If it's a reserved word, append 'Class' suffix
|
||||||
|
if sanitized in self.OWL_RESERVED_WORDS:
|
||||||
|
sanitized = sanitized + 'Class'
|
||||||
|
logger.info(f"Appended 'Class' suffix to reserved word: '{original_name}' -> '{sanitized}'")
|
||||||
|
|
||||||
|
logger.info(f"Sanitized class name: '{original_name}' -> '{sanitized}'")
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
def remove_duplicates(self, classes: List[OntologyClass]) -> List[OntologyClass]:
    """Drop ontology classes whose names repeat, ignoring letter case.

    The first class seen for a given lower-cased name is kept; every later
    class with the same lower-cased name is discarded, so 'Patient' and
    'patient' count as the same class. Input order is preserved.

    Args:
        classes: List of OntologyClass objects

    Returns:
        A new list without the duplicates. When ``classes`` is empty (or
        otherwise falsy) the argument itself is returned unchanged.

    Examples:
        >>> validator = OntologyValidator()
        >>> classes = [
        ...     OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
        ...     OntologyClass(name="patient", description="Another patient", entity_type="Person", domain="Healthcare"),
        ...     OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
        ... ]
        >>> unique = validator.remove_duplicates(classes)
        >>> len(unique)
        2
        >>> [c.name for c in unique]
        ['Patient', 'Doctor']
    """
    if not classes:
        logger.debug("No classes to check for duplicates")
        return classes

    logger.debug(f"Checking {len(classes)} classes for duplicates")

    kept = []
    dropped_names = []
    lowered_seen = set()

    for candidate in classes:
        # Names are compared through their lower-cased form only
        key = candidate.name.lower()

        if key in lowered_seen:
            dropped_names.append(candidate.name)
            logger.debug(f"Duplicate class found and removed: '{candidate.name}'")
            continue

        lowered_seen.add(key)
        kept.append(candidate)

    if not dropped_names:
        logger.debug("No duplicate classes found")
    else:
        logger.info(
            f"Removed {len(dropped_names)} duplicate classes: {dropped_names}"
        )

    return kept
|
||||||
|
|
||||||
|
def truncate_description(self, description: str, max_length: int = 500) -> str:
    """Truncate a description to a maximum length.

    If the description exceeds max_length, it is cut and an ellipsis (...)
    is appended so that the result is exactly max_length characters long.
    When max_length is too small to hold any text plus the ellipsis
    (less than 3), the text is hard-cut to max_length characters without an
    ellipsis; a zero or negative max_length yields an empty string.

    Args:
        description: The description text to truncate
        max_length: Maximum allowed length (default: 500)

    Returns:
        Truncated description string; "" when description is empty or None.
        The result is never longer than max(max_length, 0).

    Examples:
        >>> validator = OntologyValidator()
        >>> long_desc = "A" * 600
        >>> truncated = validator.truncate_description(long_desc, max_length=500)
        >>> len(truncated)
        500
        >>> truncated.endswith("...")
        True
        >>> validator.truncate_description("abcdef", max_length=2)
        'ab'
    """
    if not description:
        return ""

    if len(description) <= max_length:
        return description

    ellipsis = "..."

    if max_length < len(ellipsis):
        # No room for the ellipsis. Subtracting its length would give a
        # negative slice bound, which slices from the END of the string and
        # produces a result longer than max_length, so hard-cut instead.
        truncated = description[:max(max_length, 0)]
    else:
        # Reserve room for the ellipsis so the total length is max_length
        truncated = description[:max_length - len(ellipsis)] + ellipsis

    logger.debug(
        f"Truncated description from {len(description)} to {len(truncated)} characters"
    )

    return truncated
|
||||||
585
api/app/core/memory/utils/validation/owl_validator.py
Normal file
585
api/app/core/memory/utils/validation/owl_validator.py
Normal file
@@ -0,0 +1,585 @@
|
|||||||
|
"""OWL semantic validation for ontology classes using Owlready2.
|
||||||
|
|
||||||
|
This module provides the OWLValidator class for validating ontology classes
|
||||||
|
against OWL standards using the Owlready2 library. It performs semantic
|
||||||
|
validation including consistency checking, circular inheritance detection,
|
||||||
|
and OWL file export.
|
||||||
|
|
||||||
|
Classes:
|
||||||
|
OWLValidator: Validates ontology classes using OWL reasoning and exports to OWL formats
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from owlready2 import (
|
||||||
|
World,
|
||||||
|
Thing,
|
||||||
|
get_ontology,
|
||||||
|
sync_reasoner_pellet,
|
||||||
|
OwlReadyInconsistentOntologyError,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.memory.models.ontology_models import OntologyClass
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OWLValidator:
|
||||||
|
"""Validator for OWL semantic validation of ontology classes.
|
||||||
|
|
||||||
|
This validator performs semantic-level validation using Owlready2 including:
|
||||||
|
- Creating OWL classes from ontology class definitions
|
||||||
|
- Running consistency checking with Pellet reasoner
|
||||||
|
- Detecting circular inheritance
|
||||||
|
- Validating Protégé compatibility
|
||||||
|
- Exporting ontologies to various OWL formats (RDF/XML, Turtle, N-Triples)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
base_namespace: Base URI for the ontology namespace
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, base_namespace: str = "http://example.org/ontology#"):
    """Create a validator bound to a single ontology namespace.

    Args:
        base_namespace: Base URI stored on the instance and used for the
            ontology namespace (default: http://example.org/ontology#)
    """
    self.base_namespace = base_namespace
|
||||||
|
|
||||||
|
def validate_ontology_classes(
    self,
    classes: List[OntologyClass],
) -> Tuple[bool, List[str], Optional[World]]:
    """Validate extracted ontology classes against OWL standards.

    Builds an Owlready2 ontology in a fresh World from the given classes,
    wires up parent relationships, checks each class for circular
    inheritance and finally runs the Pellet reasoner for a consistency
    check.

    Every problem found while creating classes, linking parents or checking
    for cycles is appended to the returned message list, and any entry in
    that list makes the result invalid. This includes the
    "Parent class ... not found" message. Reasoner failures other than an
    inconsistent ontology are only logged and do not affect the result.

    Args:
        classes: List of OntologyClass objects to validate

    Returns:
        Tuple of (is_valid, error_messages, world):
        - is_valid: True only when no message was collected and the
          reasoner did not report an inconsistency
        - error_messages: List of collected error/warning messages
        - world: Owlready2 World object containing the ontology. It is
          None when no classes were given or an unexpected exception
          aborted validation; otherwise the World is returned even when
          is_valid is False.

    Examples:
        >>> validator = OWLValidator()
        >>> classes = [
        ...     OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
        ...     OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
        ... ]
        >>> is_valid, errors, world = validator.validate_ontology_classes(classes)
        >>> is_valid
        True
        >>> len(errors)
        0
    """
    if not classes:
        return False, ["No classes provided for validation"], None

    errors = []

    try:
        # Create a new world (isolated ontology environment)
        world = World()

        # Derive the ontology IRI from the namespace: drop trailing '#'/'/'
        # characters and make sure it ends with '.owl'
        onto_iri = self.base_namespace.rstrip('#/')
        if not onto_iri.endswith('.owl'):
            onto_iri = onto_iri + '.owl'

        # Create ontology
        onto = world.get_ontology(onto_iri)

        with onto:
            # Created OWL classes keyed by name, used for parent lookup
            owl_classes = {}

            # First pass: create all classes without parent relationships
            for ontology_class in classes:
                try:
                    # Create the OWL class dynamically with Thing as base.
                    # No namespace is put in the class dict; Owlready2 is
                    # left to handle it.
                    owl_class = type(
                        ontology_class.name,  # Class name
                        (Thing,),  # Base classes
                        {}  # Class dict (empty, let Owlready2 manage)
                    )

                    # rdfs:label: the name, plus the Chinese name when set
                    labels = [ontology_class.name]
                    if ontology_class.name_chinese:
                        labels.append(ontology_class.name_chinese)
                    owl_class.label = labels

                    # rdfs:comment: the description, when there is one
                    if ontology_class.description:
                        owl_class.comment = [ontology_class.description]

                    # Store for parent relationship setup
                    owl_classes[ontology_class.name] = owl_class

                    logger.debug(
                        f"Created OWL class: {ontology_class.name} "
                        f"(Chinese: {ontology_class.name_chinese}) "
                        f"IRI: {owl_class.iri if hasattr(owl_class, 'iri') else 'N/A'}"
                    )

                except Exception as e:
                    # A class that cannot be created is recorded and
                    # skipped; the remaining classes are still processed
                    error_msg = f"Failed to create OWL class '{ontology_class.name}': {str(e)}"
                    errors.append(error_msg)
                    logger.error(error_msg, exc_info=True)

            # Second pass: set up parent relationships, only for classes
            # that were created successfully in the first pass
            for ontology_class in classes:
                if ontology_class.parent_class and ontology_class.name in owl_classes:
                    parent_name = ontology_class.parent_class

                    # Check if parent exists
                    if parent_name in owl_classes:
                        try:
                            child_class = owl_classes[ontology_class.name]
                            parent_class = owl_classes[parent_name]

                            # Replace is_a so the parent is the only superclass
                            child_class.is_a = [parent_class]

                            logger.debug(
                                f"Set parent relationship: {ontology_class.name} -> {parent_name}"
                            )

                        except Exception as e:
                            error_msg = (
                                f"Failed to set parent relationship "
                                f"'{ontology_class.name}' -> '{parent_name}': {str(e)}"
                            )
                            errors.append(error_msg)
                            logger.warning(error_msg)
                    else:
                        # Logged as a warning, but still appended to errors
                        # and therefore makes the result invalid
                        warning_msg = (
                            f"Parent class '{parent_name}' not found for '{ontology_class.name}'"
                        )
                        errors.append(warning_msg)
                        logger.warning(warning_msg)

        # Check for circular inheritance
        for class_name, owl_class in owl_classes.items():
            if self._has_circular_inheritance(owl_class):
                error_msg = f"Circular inheritance detected for class '{class_name}'"
                errors.append(error_msg)
                logger.error(error_msg)

        # Run consistency checking with Pellet reasoner
        try:
            logger.info("Running Pellet reasoner for consistency checking...")
            sync_reasoner_pellet(world, infer_property_values=True, infer_data_property_values=True)
            logger.info("Consistency check passed")

        except OwlReadyInconsistentOntologyError as e:
            # An inconsistent ontology is a hard failure; the world is
            # still handed back to the caller
            error_msg = f"Ontology is inconsistent: {str(e)}"
            errors.append(error_msg)
            logger.error(error_msg)
            return False, errors, world

        except Exception as e:
            # Reasoner errors are often due to Java not being installed or configured
            # Log as warning but don't fail validation - ontology structure is still valid
            warning_msg = f"Reasoner check skipped: {str(e)}"
            if str(e).strip():  # Only log the exception text when it is non-empty
                logger.warning(warning_msg)
            else:
                logger.warning("Reasoner check skipped: Java may not be installed or configured")
            # Continue - ontology structure is valid even without reasoner check

        # Any collected message (including parent-not-found warnings)
        # makes the ontology invalid
        is_valid = len(errors) == 0

        return is_valid, errors, world

    except Exception as e:
        # Anything unexpected aborts validation and no world is returned
        error_msg = f"OWL validation failed: {str(e)}"
        errors.append(error_msg)
        logger.error(error_msg, exc_info=True)
        return False, errors, None
|
||||||
|
|
||||||
|
def _has_circular_inheritance(self, owl_class) -> bool:
    """Check if an OWL class has circular inheritance.

    Circular inheritance occurs when a class is reached again while walking
    up its chain of parent relationships (e.g., A -> B -> C -> A).

    All parents listed in ``is_a`` are followed, not just the first one, so
    a cycle that is only reachable through a second or later parent is
    detected as well. A class that is merely reachable through two
    different paths (diamond inheritance) is not reported as a cycle.

    Args:
        owl_class: Owlready2 class object to check

    Returns:
        True if a cycle is found among the ancestors reachable from
        owl_class, False otherwise
    """
    def class_key(node) -> str:
        # Identify a class by its IRI when it has one, else by its string form
        return str(node.iri) if hasattr(node, 'iri') else str(node)

    def named_parents(node) -> list:
        # Skip Thing and any is_a entry that has no is_a of its own
        return [
            parent for parent in getattr(node, 'is_a', [])
            if parent != Thing and hasattr(parent, 'is_a')
        ]

    on_path = {class_key(owl_class)}  # classes on the current walk upwards
    finished = set()                  # classes fully explored without a cycle

    # Iterative depth-first search; each entry is (class, remaining parents)
    stack = [(owl_class, iter(named_parents(owl_class)))]

    while stack:
        node, remaining = stack[-1]
        parent = next(remaining, None)

        if parent is None:
            # Every parent of this class has been explored
            stack.pop()
            key = class_key(node)
            on_path.discard(key)
            finished.add(key)
            continue

        parent_key = class_key(parent)

        if parent_key in on_path:
            # The walk came back to a class it is currently below: a cycle
            return True

        if parent_key in finished:
            # Already explored via another path (diamond), known cycle-free
            continue

        on_path.add(parent_key)
        stack.append((parent, iter(named_parents(parent))))

    return False
|
||||||
|
|
||||||
|
def export_to_owl(
    self,
    world: World,
    output_path: Optional[str] = None,
    format: str = "rdfxml",
    classes: Optional[List] = None
) -> str:
    """Export ontology to OWL file in specified format.

    Supported formats:
    - rdfxml: RDF/XML format (default, most compatible)
    - turtle: Turtle format (more readable)
    - ntriples: N-Triples format (simplest)
    - json: JSON format (simplified, human-readable)

    Turtle is produced by saving RDF/XML and converting it with
    ``_convert_to_turtle``. When ``output_path`` is given and the conversion
    succeeds, the file is rewritten with the Turtle text so that the file
    on disk matches the requested format. If the conversion falls back to
    RDF/XML, the file keeps the RDF/XML that was saved.

    The JSON format is built from ``classes`` only; ``world`` and
    ``output_path`` are not used for it.

    Args:
        world: Owlready2 World object containing the ontology
        output_path: Optional file path to save the ontology (if None, a
            temporary file is used and removed again)
        format: Export format - "rdfxml", "turtle", "ntriples", or "json" (default: "rdfxml")
        classes: Optional list of OntologyClass objects (required for json format)

    Returns:
        String representation of the exported ontology, passed through
        ``_format_owl_content`` for the OWL formats

    Raises:
        ValueError: If format is not supported, if classes is missing for
            the json format, or if world is missing for an OWL format
        RuntimeError: If export fails

    Examples:
        >>> validator = OWLValidator()
        >>> is_valid, errors, world = validator.validate_ontology_classes(classes)
        >>> owl_content = validator.export_to_owl(world, "ontology.owl", format="rdfxml")
    """
    # Validate format
    valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
    if format not in valid_formats:
        raise ValueError(
            f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}"
        )

    # JSON format doesn't need OWL processing
    if format == "json":
        if not classes:
            raise ValueError("Classes list is required for JSON format export")
        return self._export_to_json(classes)

    # For OWL formats, world is required
    if not world:
        raise ValueError("World object is None. Cannot export ontology.")

    # Turtle is not saved directly by Owlready2 here; RDF/XML is saved
    # and converted afterwards
    use_conversion = (format == "turtle")
    export_format = "rdfxml" if use_conversion else format

    try:
        # Get all ontologies in the world
        ontologies = list(world.ontologies.values())

        if not ontologies:
            raise RuntimeError("No ontologies found in world")

        # Find the ontology with classes (skip anonymous/empty ontologies)
        onto = None
        for ont in ontologies:
            classes_count = len(list(ont.classes()))
            logger.debug(f"Checking ontology {ont.base_iri}: {classes_count} classes")
            if classes_count > 0:
                onto = ont
                break

        # If no ontology with classes found, use the last non-anonymous one
        if onto is None:
            for ont in reversed(ontologies):
                if ont.base_iri != "http://anonymous/":
                    onto = ont
                    break

        # If still no ontology, use the first one
        if onto is None:
            onto = ontologies[0]

        # Log ontology contents for debugging
        all_classes = list(onto.classes())
        logger.info(f"Ontology IRI: {onto.base_iri}")
        logger.info(f"Ontology contains {len(all_classes)} classes")

        for owl_cls in all_classes:
            logger.info(f"Class in ontology: {owl_cls.name} (IRI: {owl_cls.iri})")
            if hasattr(owl_cls, 'label'):
                logger.debug(f" Labels: {owl_cls.label}")
            if hasattr(owl_cls, 'comment'):
                logger.debug(f" Comments: {owl_cls.comment}")

        if len(all_classes) == 0:
            logger.warning("No classes found in ontology! This may indicate a problem with class creation.")

        if output_path:
            # Save to file
            logger.info(f"Exporting ontology to {output_path} in {export_format} format")
            onto.save(file=output_path, format=export_format)

            # Read back the file content to return
            with open(output_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # Convert to turtle if needed
            if use_conversion:
                converted = self._convert_to_turtle(content)
                if converted != content:
                    # Conversion succeeded: replace the RDF/XML on disk so
                    # the saved file is in the requested format as well
                    with open(output_path, 'w', encoding='utf-8') as f:
                        f.write(converted)
                content = converted

            logger.info(f"Successfully exported ontology to {output_path}")

            # Format the content for better readability
            return self._format_owl_content(content, format)

        # Export to string (save to temporary location and read)
        import tempfile
        import os

        # Only the name is needed; the handle is closed before saving
        with tempfile.NamedTemporaryFile(mode='w', suffix='.owl', delete=False) as tmp:
            tmp_path = tmp.name

        try:
            onto.save(file=tmp_path, format=export_format)

            with open(tmp_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # Convert to turtle if needed
            if use_conversion:
                content = self._convert_to_turtle(content)

            # Format the content for better readability
            return self._format_owl_content(content, format)

        finally:
            # Clean up temporary file
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    except Exception as e:
        error_msg = f"Failed to export ontology: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise RuntimeError(error_msg) from e
|
||||||
|
|
||||||
|
def _export_to_json(self, classes: List) -> str:
    """Serialize ontology classes into a compact JSON document.

    The document has a single top-level "ontology" object holding the
    validator's namespace and one entry per class, in input order.

    Args:
        classes: List of OntologyClass objects

    Returns:
        JSON string without indentation or padding spaces; non-ASCII
        characters are emitted as-is rather than escaped
    """
    import json

    serialized_classes = [
        {
            "name": item.name,
            "name_chinese": item.name_chinese,
            "description": item.description,
            "entity_type": item.entity_type,
            "domain": item.domain,
            "parent_class": item.parent_class,
            # Objects without an `examples` attribute get an empty list
            "examples": getattr(item, 'examples', []),
        }
        for item in classes
    ]

    payload = {
        "ontology": {
            "namespace": self.base_namespace,
            "classes": serialized_classes,
        }
    }

    # Compact output: no indentation, separators without padding spaces
    return json.dumps(payload, ensure_ascii=False, separators=(',', ':'))
|
||||||
|
|
||||||
|
def _convert_to_turtle(self, rdfxml_content: str) -> str:
    """Re-serialize RDF/XML text as Turtle with rdflib.

    This is best-effort: when rdflib cannot be imported, or parsing or
    serializing fails, the problem is logged and the RDF/XML input is
    returned unchanged.

    Args:
        rdfxml_content: RDF/XML format content

    Returns:
        Turtle format content, or the original RDF/XML on failure
    """
    try:
        from rdflib import Graph

        graph = Graph()
        graph.parse(data=rdfxml_content, format="xml")
        serialized = graph.serialize(format="turtle")

        # serialize() may hand back bytes or str; always return str
        return serialized.decode('utf-8') if isinstance(serialized, bytes) else serialized

    except ImportError:
        logger.warning(
            "rdflib is not installed. Cannot convert to Turtle format. "
            "Install with: pip install rdflib"
        )
    except Exception as e:
        logger.error(f"Failed to convert to Turtle format: {e}")

    # Fallback shared by both failure paths
    return rdfxml_content
|
||||||
|
|
||||||
|
def _format_owl_content(self, content: str, format: str) -> str:
    """Format OWL content for better readability.

    - rdfxml: the XML is re-parsed and pretty-printed with 2-space
      indentation, and runs of consecutive blank lines are collapsed to a
      single one. If the XML cannot be parsed, a warning is logged and the
      content is returned unchanged.
    - turtle / ntriples: surrounding whitespace is stripped and exactly one
      trailing newline is added; empty or whitespace-only content is
      returned unchanged (with a warning for turtle).
    - any other format: the content is returned unchanged.

    Args:
        content: Raw OWL content string
        format: Format type (rdfxml, turtle, ntriples)

    Returns:
        Formatted OWL content string
    """
    if format == "rdfxml":
        try:
            import xml.dom.minidom as minidom

            dom = minidom.parseString(content)
            # Pretty print with 2-space indentation
            formatted = dom.toprettyxml(indent="  ", encoding="utf-8").decode("utf-8")

            # Keep only the first line of every run of blank lines
            lines = []
            prev_blank = False
            for line in formatted.split('\n'):
                is_blank = not line.strip()
                if not (is_blank and prev_blank):
                    lines.append(line)
                prev_blank = is_blank

            return '\n'.join(lines)

        except Exception as e:
            # Formatting is cosmetic; fall back to the raw content
            logger.warning(f"Failed to format XML content: {e}")
            return content

    if format == "turtle":
        stripped = content.strip() if content else ""
        if not stripped:
            logger.warning("Turtle content is empty, this may indicate an export issue")
            return content
        return stripped + '\n'

    if format == "ntriples":
        # N-Triples is line-based; normalize to a single trailing newline
        stripped = content.strip()
        return stripped + '\n' if stripped else content

    return content
|
||||||
|
|
||||||
|
def validate_with_protege_compatibility(
    self,
    classes: List[OntologyClass]
) -> Tuple[bool, List[str]]:
    """Check ontology classes for issues that may matter in the Protégé editor.

    Checks performed:
    - The validator's namespace starts with http:// or https://
    - The namespace ends with # or /
    - Class names contain none of the characters < > " { } | ^ `
    - Descriptions are at most 1000 characters long
    - Class names are pure ASCII

    Each failed check adds one warning; any warning at all makes the
    result incompatible.

    Args:
        classes: List of OntologyClass objects to validate

    Returns:
        Tuple of (is_compatible, warnings):
        - is_compatible: True when no warning was produced, False otherwise
        - warnings: List of compatibility warning messages

    Examples:
        >>> validator = OWLValidator()
        >>> classes = [OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare")]
        >>> is_compatible, warnings = validator.validate_with_protege_compatibility(classes)
        >>> is_compatible
        True
    """
    problem_chars = ('<', '>', '"', '{', '}', '|', '^', '`')
    findings = []

    # Namespace checks come first so their warnings lead the list
    namespace = self.base_namespace

    if not namespace.startswith(('http://', 'https://')):
        findings.append(
            f"Namespace '{namespace}' should start with http:// or https:// "
            "for Protégé compatibility"
        )

    if not namespace.endswith(('#', '/')):
        findings.append(
            f"Namespace '{namespace}' should end with # or / "
            "for Protégé compatibility"
        )

    # Per-class checks, in input order
    for entry in classes:
        if any(symbol in entry.name for symbol in problem_chars):
            findings.append(
                f"Class name '{entry.name}' contains special characters "
                "that may cause issues in Protégé"
            )

        # Long descriptions are allowed but flagged
        if entry.description and len(entry.description) > 1000:
            findings.append(
                f"Class '{entry.name}' has a very long description ({len(entry.description)} chars) "
                "which may display poorly in Protégé"
            )

        if not entry.name.isascii():
            findings.append(
                f"Class name '{entry.name}' contains non-ASCII characters "
                "which may cause encoding issues in some Protégé versions"
            )

    return len(findings) == 0, findings
|
||||||
1
api/app/core/models/scripts/__init__.py
Normal file
1
api/app/core/models/scripts/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""模型配置脚本模块"""
|
||||||
174
api/app/core/models/scripts/bedrock_models.yaml
Normal file
174
api/app/core/models/scripts/bedrock_models.yaml
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
provider: bedrock
|
||||||
|
enabled: true
|
||||||
|
models:
|
||||||
|
- name: ai21
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: bedrock
|
||||||
|
- name: amazon nova
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: Amazon Nova大语言模型,支持智能体思考、工具调用、流式工具调用、视觉能力,300000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
logo: bedrock
|
||||||
|
- name: anthropic claude
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: Anthropic Claude大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- vision
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
- document
|
||||||
|
logo: bedrock
|
||||||
|
- name: cohere
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
logo: bedrock
|
||||||
|
- name: deepseek
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: DeepSeek大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- vision
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
logo: bedrock
|
||||||
|
- name: meta
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
logo: bedrock
|
||||||
|
- name: mistral
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
logo: bedrock
|
||||||
|
- name: openai
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
logo: bedrock
|
||||||
|
- name: qwen
|
||||||
|
type: llm
|
||||||
|
provider: bedrock
|
||||||
|
description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
logo: bedrock
|
||||||
|
- name: amazon.rerank-v1:0
|
||||||
|
type: rerank
|
||||||
|
provider: bedrock
|
||||||
|
description: amazon.rerank-v1:0重排序模型,5120上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 重排序模型
|
||||||
|
logo: bedrock
|
||||||
|
- name: cohere.rerank-v3-5:0
|
||||||
|
type: rerank
|
||||||
|
provider: bedrock
|
||||||
|
description: cohere.rerank-v3-5:0重排序模型,5120上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 重排序模型
|
||||||
|
logo: bedrock
|
||||||
|
- name: amazon.nova-2-multimodal-embeddings-v1:0
|
||||||
|
type: embedding
|
||||||
|
provider: bedrock
|
||||||
|
description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型,支持视觉能力,8192上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本嵌入模型
|
||||||
|
- vision
|
||||||
|
logo: bedrock
|
||||||
|
- name: amazon.titan-embed-text-v1
|
||||||
|
type: embedding
|
||||||
|
provider: bedrock
|
||||||
|
description: amazon.titan-embed-text-v1文本嵌入模型,8192上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本嵌入模型
|
||||||
|
logo: bedrock
|
||||||
|
- name: amazon.titan-embed-text-v2:0
|
||||||
|
type: embedding
|
||||||
|
provider: bedrock
|
||||||
|
description: amazon.titan-embed-text-v2:0文本嵌入模型,8192上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本嵌入模型
|
||||||
|
logo: bedrock
|
||||||
|
- name: cohere.embed-english-v3
|
||||||
|
type: embedding
|
||||||
|
provider: bedrock
|
||||||
|
description: Cohere Embed 3 English文本嵌入模型,512上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本嵌入模型
|
||||||
|
logo: bedrock
|
||||||
|
- name: cohere.embed-multilingual-v3
|
||||||
|
type: embedding
|
||||||
|
provider: bedrock
|
||||||
|
description: Cohere Embed 3 Multilingual文本嵌入模型,512上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本嵌入模型
|
||||||
|
logo: bedrock
|
||||||
820
api/app/core/models/scripts/dashscope_models.yaml
Normal file
820
api/app/core/models/scripts/dashscope_models.yaml
Normal file
@@ -0,0 +1,820 @@
|
|||||||
|
provider: dashscope
|
||||||
|
enabled: true
|
||||||
|
models:
|
||||||
|
- name: deepseek-r1-distill-qwen-14b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-R1-Distill-Qwen-14B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: deepseek-r1-distill-qwen-32b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-R1-Distill-Qwen-32B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: deepseek-r1
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-R1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: deepseek-v3.1
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: deepseek-v3.2-exp
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: deepseek-v3.2
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: deepseek-v3
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: farui-plus
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: glm-4.7
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qvq-max-latest
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qvq-max-latest大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qvq-max
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qvq-max大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-coder-turbo-0919
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-coder-turbo-0919代码专用大语言模型,支持智能体思考,131072上下文窗口,对话模式,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-max-latest
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-max-latest大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-max-longcontext
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-max-longcontext长上下文大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-max
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-max大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,支持联网搜索
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-mt-plus
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-mt-plus多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 翻译模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-mt-turbo
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-mt-turbo轻量化多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 翻译模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-0112
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-0112大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-0125
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-0723
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-0723大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-0806
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-0806大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-0919
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-0919大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-1125
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-1125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-1127
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-1127大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-plus-1220
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-plus-1220大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-vl-max
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-vl-max多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-vl-plus-0809
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-vl-plus-0809多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-vl-plus-2025-01-02
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-vl-plus-2025-01-02多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,未废弃
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-vl-plus-2025-01-25
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-vl-plus-2025-01-25多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-vl-plus-latest
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-vl-plus-latest多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen-vl-plus
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen-vl-plus多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen2.5-0.5b-instruct
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-14b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-14b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-235b-a22b-instruct-2507
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-235b-a22b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-235b-a22b-thinking-2507
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-235b-a22b-thinking-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-235b-a22b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-235b-a22b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-30b-a3b-instruct-2507
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-30b-a3b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-30b-a3b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-30b-a3b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-32b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-32b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-4b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-4b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-8b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-8b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-coder-30b-a3b-instruct
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-coder-30b-a3b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-coder-480b-a35b-instruct
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-coder-480b-a35b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-coder-plus-2025-09-23
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-coder-plus-2025-09-23大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-coder-plus
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-coder-plus大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
- agent-thought
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-max-2025-09-23
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-max-2025-09-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- 联网搜索
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-max-2026-01-23
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-max-2026-01-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- 联网搜索
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-max-preview
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-max-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-max
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-max大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- 联网搜索
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-next-80b-a3b-instruct
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-next-80b-a3b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-next-80b-a3b-thinking
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-next-80b-a3b-thinking大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-omni-flash-2025-12-01
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-omni-flash-2025-12-01多模态大语言模型,支持视觉、智能体思考、视频、音频能力,65536上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
- audio
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-235b-a22b-instruct
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-235b-a22b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-235b-a22b-thinking
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-235b-a22b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-30b-a3b-instruct
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-30b-a3b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-30b-a3b-thinking
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-30b-a3b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-flash
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-flash多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-plus-2025-09-23
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-plus-2025-09-23多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwen3-vl-plus
|
||||||
|
type: chat
|
||||||
|
provider: dashscope
|
||||||
|
description: qwen3-vl-plus多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
- agent-thought
|
||||||
|
- video
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwq-32b
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwq-32b大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwq-plus-0305
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwq-plus-0305大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: qwq-plus
|
||||||
|
type: llm
|
||||||
|
provider: dashscope
|
||||||
|
description: qwq-plus大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: dashscope
|
||||||
|
- name: gte-rerank-v2
|
||||||
|
type: rerank
|
||||||
|
provider: dashscope
|
||||||
|
description: gte-rerank-v2重排序模型,4000上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 重排序模型
|
||||||
|
logo: dashscope
|
||||||
|
- name: gte-rerank
|
||||||
|
type: rerank
|
||||||
|
provider: dashscope
|
||||||
|
description: gte-rerank重排序模型,4000上下文窗口
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 重排序模型
|
||||||
|
logo: dashscope
|
||||||
|
- name: multimodal-embedding-v1
|
||||||
|
type: embedding
|
||||||
|
provider: dashscope
|
||||||
|
description: multimodal-embedding-v1多模态嵌入模型,支持视觉能力,8192上下文窗口,最大分块数10
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 嵌入模型
|
||||||
|
- 多模态模型
|
||||||
|
- vision
|
||||||
|
logo: dashscope
|
||||||
|
- name: text-embedding-v1
|
||||||
|
type: embedding
|
||||||
|
provider: dashscope
|
||||||
|
description: text-embedding-v1文本嵌入模型,2048上下文窗口,最大分块数25
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 嵌入模型
|
||||||
|
- 文本嵌入
|
||||||
|
logo: dashscope
|
||||||
|
- name: text-embedding-v2
|
||||||
|
type: embedding
|
||||||
|
provider: dashscope
|
||||||
|
description: text-embedding-v2文本嵌入模型,2048上下文窗口,最大分块数25
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 嵌入模型
|
||||||
|
- 文本嵌入
|
||||||
|
logo: dashscope
|
||||||
|
- name: text-embedding-v3
|
||||||
|
type: embedding
|
||||||
|
provider: dashscope
|
||||||
|
description: text-embedding-v3文本嵌入模型,8192上下文窗口,最大分块数10
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 嵌入模型
|
||||||
|
- 文本嵌入
|
||||||
|
logo: dashscope
|
||||||
|
- name: text-embedding-v4
|
||||||
|
type: embedding
|
||||||
|
provider: dashscope
|
||||||
|
description: text-embedding-v4文本嵌入模型,8192上下文窗口,最大分块数10
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 嵌入模型
|
||||||
|
- 文本嵌入
|
||||||
|
logo: dashscope
|
||||||
143
api/app/core/models/scripts/loader.py
Normal file
143
api/app/core/models/scripts/loader.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from app.models.models_model import ModelBase, ModelProvider
|
||||||
|
|
||||||
|
|
||||||
|
def _load_yaml_config(provider: ModelProvider) -> list[dict]:
    """Load the predefined model entries for one provider from its YAML file.

    The file is looked up in this module's directory under the name
    ``<provider.value>_models.yaml``.

    Args:
        provider: Provider whose ``value`` selects the YAML file.

    Returns:
        The list stored under the ``models`` key. An empty list is returned
        when the file does not exist, is empty, is not a mapping, has a falsy
        ``enabled`` flag, or contains no models.
    """
    config_file = Path(__file__).parent / f"{provider.value}_models.yaml"

    if not config_file.exists():
        return []

    with open(config_file, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # safe_load yields None for an empty document; anything that is not a
    # mapping cannot carry the 'enabled' / 'models' keys read below.
    if not isinstance(data, dict):
        return []

    # A missing 'enabled' key counts as enabled.
    if not data.get('enabled', True):
        return []

    # A bare "models:" key parses to None, so normalise it to an empty list.
    return data.get('models') or []
|
||||||
|
|
||||||
|
|
||||||
|
def _disable_yaml_config(provider: ModelProvider) -> None:
    """Set the ``enabled`` flag to false in the provider's YAML file.

    The file ``<provider.value>_models.yaml`` in this module's directory is
    parsed, its ``enabled`` key is set to ``False``, and the parsed data is
    written back with ``yaml.dump``. Nothing happens when the file does not
    exist or its top level is not a mapping.

    Args:
        provider: Provider whose ``value`` selects the YAML file.
    """
    config_file = Path(__file__).parent / f"{provider.value}_models.yaml"

    if not config_file.exists():
        return

    with open(config_file, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    if data is None:
        # Empty document: start from an empty mapping instead of failing on
        # item assignment to None.
        data = {}
    elif not isinstance(data, dict):
        # Unexpected top-level layout; leave the file untouched.
        return

    data['enabled'] = False

    # The file is re-serialised from the parsed data; key order is kept
    # (sort_keys=False) and non-ASCII text is written as-is.
    with open(config_file, 'w', encoding='utf-8') as f:
        yaml.dump(data, f, allow_unicode=True, sort_keys=False)
|
||||||
|
|
||||||
|
|
||||||
|
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
    """Import predefined model configurations into the database.

    For each target provider the model entries are read from its YAML file.
    An entry whose ``name`` and ``provider`` match an existing ``ModelBase``
    row updates that row; otherwise a new row is added. Each entry is
    committed on its own, and a failing entry is rolled back and counted
    without stopping the import. After a provider's entries have been
    processed its YAML file is marked as disabled.

    Args:
        db: Database session used to query, add and commit.
        providers: Providers to load, given as strings or ``ModelProvider``
            members. ``None`` (or an empty list) loads every provider
            except ``ModelProvider.COMPOSITE``.
        silent: When True, no progress messages are printed.

    Returns:
        dict: Counters ``{"success": int, "skipped": int, "failed": int}``.
        ``"skipped"` is never incremented by this function.
    """
    result = {"success": 0, "skipped": 0, "failed": 0}

    # Resolve the providers to process.
    # NOTE(review): ModelProvider(p) presumably raises ValueError for an
    # unknown provider string (standard Enum behaviour) - confirm.
    if providers:
        target_providers = [ModelProvider(p) if isinstance(p, str) else p for p in providers]
    else:
        target_providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]

    for provider in target_providers:
        models = _load_yaml_config(provider)

        if not models:
            if not silent:
                print(f"警告: 供应商 '{provider.value}' 暂无预定义模型")
            continue

        if not silent:
            print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...")

        for model_data in models:
            try:
                # Look for an existing row with the same name and provider.
                existing = db.query(ModelBase).filter(
                    ModelBase.name == model_data["name"],
                    ModelBase.provider == model_data["provider"]
                ).first()

                if existing:
                    # Overwrite every configured field on the existing row.
                    for key, value in model_data.items():
                        setattr(existing, key, value)
                    db.commit()
                    if not silent:
                        print(f"更新成功: {model_data['name']}")
                else:
                    db.add(ModelBase(**model_data))
                    db.commit()
                    if not silent:
                        print(f"添加成功: {model_data['name']}")

                result["success"] += 1

            except Exception as e:
                db.rollback()
                if not silent:
                    # The entry may be the very thing that is malformed
                    # (missing 'name', or not a mapping), so the handler
                    # must not index into it unguarded.
                    if isinstance(model_data, dict):
                        name = model_data.get('name', '<unknown>')
                    else:
                        name = repr(model_data)
                    print(f"添加失败: {name} - {str(e)}")
                result["failed"] += 1

        # NOTE(review): the file is disabled even when some entries failed,
        # so failed entries are not retried on the next run - confirm this
        # is intended.
        _disable_yaml_config(provider)

    return result
|
||||||
|
|
||||||
|
|
||||||
|
def load_models_by_provider(db: Session, provider: str) -> dict:
    """Load the predefined models of a single provider into the database.

    Args:
        db: Database session passed through to ``load_models``.
        provider: Provider name as a string, or a ``ModelProvider`` member.

    Returns:
        dict: The result counters returned by ``load_models``.
    """
    # Strings are converted to the enum; anything else is passed on as given.
    if isinstance(provider, str):
        provider = ModelProvider(provider)
    return load_models(db, providers=[provider])
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_providers() -> list[str]:
    """Return the values of all ``ModelProvider`` members except ``COMPOSITE``.

    Returns:
        The ``value`` of each remaining member, in enum iteration order.
        NOTE(review): annotated as ``list[str]`` on the assumption that the
        enum values are strings (they are used to build YAML file names) -
        confirm against ``ModelProvider``.
    """
    return [p.value for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
||||||
|
|
||||||
|
|
||||||
|
def get_models_by_provider(provider: str) -> list[dict]:
    """Return the model configuration entries defined for one provider.

    Args:
        provider: Provider name as a string, or a ``ModelProvider`` member.

    Returns:
        Whatever ``_load_yaml_config`` returns for that provider.
    """
    # Strings are converted to the enum; anything else is passed on as given.
    if isinstance(provider, str):
        provider = ModelProvider(provider)
    return _load_yaml_config(provider)
|
||||||
294
api/app/core/models/scripts/openai_models.yaml
Normal file
294
api/app/core/models/scripts/openai_models.yaml
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
provider: openai
|
||||||
|
enabled: true
|
||||||
|
models:
|
||||||
|
- name: chatgpt-4o-latest
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: chatgpt-4o-latest大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-3.5-turbo-0125
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-3.5-turbo-1106
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-3.5-turbo-16k
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-3.5-turbo-instruct
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-3.5-turbo-instruct大语言模型,4096上下文窗口,文本补全模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-3.5-turbo
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-4-0125-preview
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-4-1106-preview
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-4-turbo-2024-04-09
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-4-turbo-2024-04-09大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-4-turbo-preview
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
logo: openai
|
||||||
|
- name: gpt-4-turbo
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: gpt-4-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
logo: openai
|
||||||
|
- name: o1-preview
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o1-preview大语言模型,支持智能体思考,128000上下文窗口,对话模式,已废弃
|
||||||
|
is_deprecated: true
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
logo: openai
|
||||||
|
- name: o1
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o1大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- multi-tool-call
|
||||||
|
- agent-thought
|
||||||
|
- stream-tool-call
|
||||||
|
- vision
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o3-2025-04-16
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o3-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- vision
|
||||||
|
- stream-tool-call
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o3-mini-2025-01-31
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o3-mini-2025-01-31大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o3-mini
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o3-pro-2025-06-10
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o3-pro-2025-06-10大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- vision
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o3-pro
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o3-pro大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- vision
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o3
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o3大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- vision
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o4-mini-2025-04-16
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o4-mini-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- vision
|
||||||
|
- stream-tool-call
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: o4-mini
|
||||||
|
type: llm
|
||||||
|
provider: openai
|
||||||
|
description: o4-mini大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
- vision
|
||||||
|
- stream-tool-call
|
||||||
|
- structured-output
|
||||||
|
logo: openai
|
||||||
|
- name: text-embedding-3-large
|
||||||
|
type: embedding
|
||||||
|
provider: openai
|
||||||
|
description: text-embedding-3-large文本向量模型,8191上下文窗口,最大分块数32
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本向量模型
|
||||||
|
logo: openai
|
||||||
|
- name: text-embedding-3-small
|
||||||
|
type: embedding
|
||||||
|
provider: openai
|
||||||
|
description: text-embedding-3-small文本向量模型,8191上下文窗口,最大分块数32
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本向量模型
|
||||||
|
logo: openai
|
||||||
|
- name: text-embedding-ada-002
|
||||||
|
type: embedding
|
||||||
|
provider: openai
|
||||||
|
description: text-embedding-ada-002文本向量模型,8097上下文窗口,最大分块数32
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
tags:
|
||||||
|
- 文本向量模型
|
||||||
|
logo: openai
|
||||||
@@ -28,7 +28,9 @@ from app.core.rag.common.float_utils import get_float
|
|||||||
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
||||||
from app.core.rag.llm.chat_model import Base
|
from app.core.rag.llm.chat_model import Base
|
||||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def knowledge_retrieval(
|
def knowledge_retrieval(
|
||||||
query: str,
|
query: str,
|
||||||
@@ -62,7 +64,15 @@ def knowledge_retrieval(
|
|||||||
merge_strategy = config.get("merge_strategy", "weight")
|
merge_strategy = config.get("merge_strategy", "weight")
|
||||||
reranker_id = config.get("reranker_id")
|
reranker_id = config.get("reranker_id")
|
||||||
reranker_top_k = config.get("reranker_top_k", 1024)
|
reranker_top_k = config.get("reranker_top_k", 1024)
|
||||||
use_graph = config.get("use_graph", "false").lower() == "true"
|
# use_graph = config.get("use_graph", "false").lower() == "true"
|
||||||
|
|
||||||
|
use_graph_value = config.get("use_graph", False)
|
||||||
|
if isinstance(use_graph_value, bool):
|
||||||
|
use_graph = use_graph_value
|
||||||
|
elif isinstance(use_graph_value, str):
|
||||||
|
use_graph = use_graph_value.lower() in ("true", "1", "yes")
|
||||||
|
else:
|
||||||
|
use_graph = False
|
||||||
|
|
||||||
file_names_filter = []
|
file_names_filter = []
|
||||||
if user_ids:
|
if user_ids:
|
||||||
@@ -159,13 +169,29 @@ def knowledge_retrieval(
|
|||||||
|
|
||||||
# Use the specified reranker for re-ranking
|
# Use the specified reranker for re-ranking
|
||||||
if reranker_id:
|
if reranker_id:
|
||||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
try:
|
||||||
# use graph
|
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||||
|
except Exception as rerank_error:
|
||||||
|
# If reranker fails, log warning and continue with original results
|
||||||
|
logger.warning(
|
||||||
|
"Reranker failed, falling back to original results",
|
||||||
|
extra={
|
||||||
|
"reranker_id": reranker_id,
|
||||||
|
"query": query,
|
||||||
|
"doc_count": len(all_results),
|
||||||
|
"error": str(rerank_error),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if use_graph:
|
if use_graph:
|
||||||
from app.core.rag.common.settings import kg_retriever
|
try:
|
||||||
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
from app.core.rag.common.settings import kg_retriever
|
||||||
if doc:
|
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||||
all_results.insert(0, doc)
|
if doc:
|
||||||
|
all_results.insert(0, doc)
|
||||||
|
except Exception as graph_error:
|
||||||
|
print(f"Failed to retrieve from knowledge graph: {str(graph_error)}")
|
||||||
|
|
||||||
return all_results
|
return all_results
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
|
|||||||
from app.core.workflow.nodes import WorkflowState
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.base_config import VariableType
|
from app.core.workflow.nodes.base_config import VariableType
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.core.workflow.template_renderer import render_template
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -157,12 +156,137 @@ class WorkflowExecutor:
|
|||||||
"error": result.get("error"),
|
"error": result.get("error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _update_end_activate(self, node_id):
|
def _update_scope_activate(self, scope, status=None):
|
||||||
|
"""
|
||||||
|
Update the activation state of all End nodes based on a completed scope (node or variable).
|
||||||
|
|
||||||
|
Iterates over all End nodes in `self.end_outputs` and calls
|
||||||
|
`update_activate` on each, which may:
|
||||||
|
- Activate variable segments that depend on the completed node/scope.
|
||||||
|
- Activate the entire End node output if all control conditions are met.
|
||||||
|
|
||||||
|
If any End node becomes active and `self.activate_end` is not yet set,
|
||||||
|
this node will be marked as the currently active End node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope (str): The node ID or scope that has completed execution.
|
||||||
|
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||||
|
"""
|
||||||
for node in self.end_outputs.keys():
|
for node in self.end_outputs.keys():
|
||||||
self.end_outputs[node].update_activate(node_id)
|
self.end_outputs[node].update_activate(scope, status)
|
||||||
if self.end_outputs[node].activate and self.activate_end is None:
|
if self.end_outputs[node].activate and self.activate_end is None:
|
||||||
self.activate_end = node
|
self.activate_end = node
|
||||||
|
|
||||||
|
def _update_stream_output_status(self, activate, data):
    """
    Update the stream output state of End nodes based on workflow state updates.

    For every node present in ``data`` whose entry in ``activate`` is truthy,
    the node's own output status is looked up and forwarded to
    ``self._update_scope_activate`` together with the node id.

    Args:
        activate (dict): Mapping of node_id -> flag indicating which
            nodes/scopes are activated.
        data (dict): Mapping of node_id -> node update payload. The status is
            read from ``payload["runtime_vars"][node_id]["output"]``.

    Behavior:
        1. Nodes that are not activated are ignored.
        2. If the payload, its ``runtime_vars``, or the node's own entry in
           ``runtime_vars`` is missing or empty, the status is ``None``
           instead of raising ``AttributeError``.
        3. ``self._update_scope_activate(node_id, status=...)`` is called once
           per activated node, in the iteration order of ``data``.
    """
    for node_id, node_data in data.items():
        if not activate.get(node_id):
            continue

        # An activated node is not guaranteed to have published an entry under
        # its own id; fall back to empty mappings so the lookup chain cannot
        # hit ``None.get`` and abort the stream.
        runtime_vars = (node_data or {}).get('runtime_vars') or {}
        own_vars = runtime_vars.get(node_id) or {}

        self._update_scope_activate(node_id, status=own_vars.get("output"))
|
||||||
|
|
||||||
|
async def _emit_active_chunks(
        self,
        node_outputs: dict,
        variables: dict,
        force=False
):
    """
    Stream the pending output segments of the currently active End node.

    Starting at the End node's ``cursor``, segments are handled one by one:

    - A segment whose ``activate`` flag is falsy stops the stream, unless
      ``force`` is truthy, in which case it is processed anyway.
    - A literal segment (``is_variable`` falsy) contributes its ``literal``
      text unchanged.
    - A variable segment is evaluated with ``evaluate_expression`` using
      ``variables`` and ``node_outputs``, then passed through
      ``self._trans_output_string``. If that raises ``ValueError`` a warning
      is logged and the segment contributes nothing.
    - A non-empty chunk is yielded as a ``"message"`` event; the cursor is
      advanced only after the event has been consumed.

    Once the cursor has reached the end of the segment list, the End node is
    removed from ``self.end_outputs`` and ``self.activate_end`` is reset to
    ``None``.

    Args:
        node_outputs (dict): Runtime node outputs used for variable evaluation.
        variables (dict): Runtime variables used for variable evaluation.
        force (bool, default=False): Process segments even if they are not
            marked active.

    Yields:
        dict: ``{"event": "message", "data": {"chunk": <text>}}`` for every
        segment that produced non-empty text.

    Notes:
        - Only the End node referenced by ``self.activate_end`` is processed;
          the caller must make sure it is set.
    """
    stream_cfg = self.end_outputs[self.activate_end]

    while stream_cfg.cursor < len(stream_cfg.outputs):
        segment = stream_cfg.outputs[stream_cfg.cursor]

        # An inactive segment blocks everything behind it until it is
        # activated, unless emission is forced.
        if not (segment.activate or force):
            break

        text = ''
        if segment.is_variable:
            # Variable segment: resolve the expression and normalise the value
            try:
                text += self._trans_output_string(
                    evaluate_expression(
                        segment.literal,
                        variables=variables,
                        node_outputs=node_outputs
                    )
                )
            except ValueError:
                # A failed evaluation must not interrupt the stream
                logger.warning(f"[STREAM] Failed to evaluate segment: {segment.literal}")
        else:
            # Literal segment: emitted as-is
            text += segment.literal

        if text:
            yield {"event": "message", "data": {"chunk": text}}

        # Move on only after the segment has been handled
        stream_cfg.cursor += 1

    # Every segment has been handled: stop tracking this End node
    if stream_cfg.cursor >= len(stream_cfg.outputs):
        self.end_outputs.pop(self.activate_end)
        self.activate_end = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _trans_output_string(content):
|
def _trans_output_string(content):
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
@@ -218,14 +342,8 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||||
full_content = ''
|
full_content = ''
|
||||||
for end_info in self.end_outputs.values():
|
for end_id in self.end_outputs.keys():
|
||||||
output_template = "".join([output.literal for output in end_info.outputs])
|
full_content += result.get('runtime_vars', {}).get(end_id, {}).get('output', '')
|
||||||
full_content += render_template(
|
|
||||||
output_template,
|
|
||||||
result.get("variables", {}),
|
|
||||||
result.get("runtime_vars", {}),
|
|
||||||
strict=False
|
|
||||||
)
|
|
||||||
result["messages"].extend(
|
result["messages"].extend(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
@@ -306,7 +424,7 @@ class WorkflowExecutor:
|
|||||||
try:
|
try:
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
full_content = ''
|
full_content = ''
|
||||||
|
self._update_scope_activate("sys")
|
||||||
async for event in graph.astream(
|
async for event in graph.astream(
|
||||||
initial_state,
|
initial_state,
|
||||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||||
@@ -333,9 +451,12 @@ class WorkflowExecutor:
|
|||||||
if not end_info or end_info.cursor >= len(end_info.outputs):
|
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||||
continue
|
continue
|
||||||
current_output = end_info.outputs[end_info.cursor]
|
current_output = end_info.outputs[end_info.cursor]
|
||||||
if current_output.is_variable and current_output.depends_on_node(node_id):
|
if current_output.is_variable and current_output.depends_on_scope(node_id):
|
||||||
if data.get("done"):
|
if data.get("done"):
|
||||||
end_info.cursor += 1
|
end_info.cursor += 1
|
||||||
|
if end_info.cursor >= len(end_info.outputs):
|
||||||
|
self.end_outputs.pop(self.activate_end)
|
||||||
|
self.activate_end = None
|
||||||
else:
|
else:
|
||||||
full_content += data.get("chunk")
|
full_content += data.get("chunk")
|
||||||
yield {
|
yield {
|
||||||
@@ -415,91 +536,53 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
elif mode == "updates":
|
elif mode == "updates":
|
||||||
# Handle state updates - store final state
|
# Handle state updates - store final state
|
||||||
for node_id in data.keys():
|
state = graph.get_state(config=self.checkpoint_config).values
|
||||||
self._update_end_activate(node_id)
|
node_outputs = state.get("runtime_vars", {})
|
||||||
wait = False
|
variables = state.get("variables", {})
|
||||||
state = graph.get_state(config=self.checkpoint_config)
|
activate = state.get("activate", {})
|
||||||
node_outputs = state.values.get("runtime_vars", {})
|
for _, node_data in data.items():
|
||||||
for _ in data.keys():
|
node_outputs |= node_data.get("runtime_vars", {})
|
||||||
node_outputs = node_outputs | data.get(_).get("runtime_vars", {})
|
variables |= node_data.get("variables", {})
|
||||||
|
|
||||||
|
self._update_stream_output_status(activate, data)
|
||||||
|
wait = False
|
||||||
while self.activate_end and not wait:
|
while self.activate_end and not wait:
|
||||||
message = ''
|
async for msg_event in self._emit_active_chunks(
|
||||||
logger.info(self.activate_end)
|
node_outputs=node_outputs,
|
||||||
end_info = self.end_outputs[self.activate_end]
|
variables=variables
|
||||||
content = end_info.outputs[end_info.cursor]
|
):
|
||||||
while content.activate:
|
full_content += msg_event["data"]['chunk']
|
||||||
if not content.is_variable:
|
yield msg_event
|
||||||
full_content += content.literal
|
|
||||||
message += content.literal
|
if self.activate_end:
|
||||||
else:
|
|
||||||
try:
|
|
||||||
chunk = evaluate_expression(
|
|
||||||
content.literal,
|
|
||||||
variables={},
|
|
||||||
node_outputs=node_outputs
|
|
||||||
)
|
|
||||||
chunk = self._trans_output_string(chunk)
|
|
||||||
message += chunk
|
|
||||||
full_content += chunk
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
end_info.cursor += 1
|
|
||||||
if end_info.cursor == len(end_info.outputs):
|
|
||||||
break
|
|
||||||
content = end_info.outputs[end_info.cursor]
|
|
||||||
if end_info.cursor != len(end_info.outputs):
|
|
||||||
wait = True
|
wait = True
|
||||||
else:
|
else:
|
||||||
self.end_outputs.pop(self.activate_end)
|
self._update_stream_output_status(activate, data)
|
||||||
self.activate_end = None
|
|
||||||
for node_id in data.keys():
|
|
||||||
self._update_end_activate(node_id)
|
|
||||||
if message:
|
|
||||||
yield {
|
|
||||||
"event": "message",
|
|
||||||
"data": {
|
|
||||||
"chunk": message
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||||
f"- execution_id: {self.execution_id}")
|
f"- execution_id: {self.execution_id}")
|
||||||
|
|
||||||
result = graph.get_state(self.checkpoint_config).values
|
result = graph.get_state(self.checkpoint_config).values
|
||||||
while self.activate_end:
|
node_outputs = result.get("runtime_vars", {})
|
||||||
message = ''
|
variables = result.get("variables", {})
|
||||||
end_info = self.end_outputs[self.activate_end]
|
self.end_outputs = {
|
||||||
content = end_info.outputs[end_info.cursor]
|
node_id: node_info
|
||||||
if not content.is_variable:
|
for node_id, node_info in self.end_outputs.items()
|
||||||
message += content.literal
|
if node_info.activate
|
||||||
else:
|
}
|
||||||
node_outputs = result.get("runtime_vars", {})
|
|
||||||
variables = result.get("variables", {})
|
if self.end_outputs or self.activate_end:
|
||||||
try:
|
while self.activate_end:
|
||||||
chunk = evaluate_expression(
|
async for msg_event in self._emit_active_chunks(
|
||||||
content.literal,
|
node_outputs=node_outputs,
|
||||||
variables=variables,
|
variables=variables,
|
||||||
node_outputs=node_outputs
|
force=True
|
||||||
)
|
):
|
||||||
chunk = self._trans_output_string(chunk)
|
full_content += msg_event["data"]['chunk']
|
||||||
message += chunk
|
yield msg_event
|
||||||
full_content += chunk
|
|
||||||
except ValueError:
|
if not self.activate_end and self.end_outputs:
|
||||||
pass
|
|
||||||
end_info.cursor += 1
|
|
||||||
if end_info.cursor == len(end_info.outputs):
|
|
||||||
self.end_outputs.pop(self.activate_end)
|
|
||||||
self.activate_end = None
|
|
||||||
if self.end_outputs:
|
|
||||||
self.activate_end = list(self.end_outputs.keys())[0]
|
self.activate_end = list(self.end_outputs.keys())[0]
|
||||||
if message:
|
|
||||||
yield {
|
|
||||||
"event": "message",
|
|
||||||
"data": {
|
|
||||||
"chunk": message
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# 计算耗时
|
# 计算耗时
|
||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
|
|||||||
@@ -53,114 +53,110 @@ class OutputContent(BaseModel):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def depends_on_node(self, node_id: str) -> bool:
|
def depends_on_scope(self, scope: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if this output segment depends on a specific node's variable.
|
Check if this segment depends on a given scope.
|
||||||
|
|
||||||
This method examines the `literal` of the output segment to see if it
|
|
||||||
contains a variable placeholder referencing the given node in the form:
|
|
||||||
|
|
||||||
{{ node_id.field_name }}
|
|
||||||
|
|
||||||
It uses a regular expression to match the exact node ID, avoiding
|
|
||||||
false positives from substring matches (e.g., 'node1' should not match 'node10').
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id (str): The ID of the node to check for in this segment's variable placeholders.
|
scope (str): Node ID or special variable prefix (e.g., "sys").
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool:
|
bool: True if this segment references the given scope.
|
||||||
- True if the segment contains a variable referencing the given node.
|
|
||||||
- False otherwise.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
literal = "{{node1.name}}"
|
|
||||||
|
|
||||||
depends_on_node("node1") -> True
|
|
||||||
depends_on_node("node2") -> False
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
This method is primarily used in stream mode to determine whether
|
|
||||||
a particular variable output segment should be activated when a
|
|
||||||
specific upstream node completes execution.
|
|
||||||
"""
|
"""
|
||||||
variable_pattern = rf"\{{\{{\s*{re.escape(node_id)}\.[a-zA-Z0-9_]+\s*\}}\}}"
|
pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}"
|
||||||
pattern = re.compile(variable_pattern)
|
return bool(re.search(pattern, self.literal))
|
||||||
match = pattern.search(self.literal)
|
|
||||||
if match:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class StreamOutputConfig(BaseModel):
|
class StreamOutputConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
Streaming output configuration for an End node.
|
Streaming output configuration for an End node.
|
||||||
|
|
||||||
This structure controls:
|
This configuration describes how the End node output behaves in streaming mode,
|
||||||
- whether the End node output is globally active
|
including:
|
||||||
- which upstream branch nodes are responsible for activation
|
- whether output emission is globally activated
|
||||||
- how each output segment behaves in streaming mode
|
- which upstream branch/control nodes gate the activation
|
||||||
|
- how each parsed output segment is streamed and activated
|
||||||
"""
|
"""
|
||||||
|
|
||||||
activate: bool = Field(
|
activate: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Global activation state of the End node output.\n"
|
"Global activation flag for the End node output.\n"
|
||||||
"If False, no output should be emitted until all control nodes are resolved."
|
"When False, output segments should not be emitted even if available.\n"
|
||||||
|
"This flag typically becomes True once required control branch conditions "
|
||||||
|
"are satisfied."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
control_nodes: list[str] = Field(
|
control_nodes: dict[str, str] = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"List of upstream branch node IDs that control this End node.\n"
|
"Control branch conditions for this End node output.\n"
|
||||||
"Each node must signal completion before output becomes active."
|
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||||
|
"The End node output becomes globally active when a controlling branch node "
|
||||||
|
"reports a matching completion status."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs: list[OutputContent] = Field(
|
outputs: list[OutputContent] = Field(
|
||||||
...,
|
...,
|
||||||
description="Ordered list of output segments parsed from the output template."
|
description=(
|
||||||
|
"Ordered list of output segments parsed from the output template.\n"
|
||||||
|
"Each segment represents either a literal text block or a variable placeholder "
|
||||||
|
"that may be activated independently."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
cursor: int = Field(
|
cursor: int = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Streaming cursor index.\n"
|
"Streaming cursor index.\n"
|
||||||
"Indicates how many output segments have already been emitted."
|
"Indicates the next output segment index to be emitted.\n"
|
||||||
|
"Segments with index < cursor are considered already streamed."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_activate(self, node_id):
|
def update_activate(self, scope: str, status=None):
|
||||||
"""
|
"""
|
||||||
Update activation state based on an upstream node completion.
|
Update streaming activation state based on an upstream node or special variable.
|
||||||
|
|
||||||
This method is typically called when a branch/control node finishes execution.
|
Args:
|
||||||
|
scope (str):
|
||||||
|
Identifier of the completed upstream entity.
|
||||||
|
- If a control branch node, it should match a key in `control_nodes`.
|
||||||
|
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||||
|
status (optional):
|
||||||
|
Completion status of the control branch node.
|
||||||
|
Required when `scope` refers to a control node.
|
||||||
|
|
||||||
Behavior:
|
Behavior:
|
||||||
1. If the node is a control node:
|
1. Control branch nodes:
|
||||||
- Remove it from `control_nodes`
|
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||||
- If all control nodes are resolved, activate the entire output
|
branch label, the End node output becomes globally active (`activate = True`).
|
||||||
|
|
||||||
2. Activate variable output segments that depend on this node:
|
2. Variable output segments:
|
||||||
- If an output segment is a variable
|
- For each segment that is a variable (`is_variable=True`):
|
||||||
- And its literal references the completed node_id
|
- If the segment literal references `scope`, mark the segment as active.
|
||||||
- Mark that segment as active
|
- This applies both to regular node variables (e.g., "node_id.field")
|
||||||
|
and special system variables (e.g., "sys.xxx").
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- This method does not emit output or advance the streaming cursor.
|
||||||
|
- It only updates activation flags based on upstream events or special variables.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Case 1: resolve control branch dependency
|
# Case 1: resolve control branch dependency
|
||||||
if node_id in self.control_nodes:
|
if scope in self.control_nodes.keys():
|
||||||
self.control_nodes.remove(node_id)
|
if status is None:
|
||||||
|
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||||
# All branch constraints resolved → enable output
|
if status == self.control_nodes[scope]:
|
||||||
if not self.control_nodes:
|
|
||||||
self.activate = True
|
self.activate = True
|
||||||
|
|
||||||
# Case 2: activate variable segments related to this node
|
# Case 2: activate variable segments related to this node
|
||||||
for i in range(len(self.outputs)):
|
for i in range(len(self.outputs)):
|
||||||
if (
|
if (
|
||||||
self.outputs[i].is_variable
|
self.outputs[i].is_variable
|
||||||
and self.outputs[i].depends_on_node(node_id)
|
and self.outputs[i].depends_on_scope(scope)
|
||||||
):
|
):
|
||||||
self.outputs[i].activate = True
|
self.outputs[i].activate = True
|
||||||
|
|
||||||
@@ -184,11 +180,11 @@ class GraphBuilder:
|
|||||||
self._find_upstream_branch_node = lru_cache(
|
self._find_upstream_branch_node = lru_cache(
|
||||||
maxsize=len(self.nodes) * 2
|
maxsize=len(self.nodes) * 2
|
||||||
)(self._find_upstream_branch_node)
|
)(self._find_upstream_branch_node)
|
||||||
self._analyze_end_node_output()
|
|
||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph = StateGraph(WorkflowState)
|
||||||
self.add_nodes()
|
self.add_nodes()
|
||||||
self.add_edges()
|
self.add_edges()
|
||||||
|
self._analyze_end_node_output()
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -216,30 +212,53 @@ class GraphBuilder:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
raise RuntimeError(f"Node not found: Id={node_id}")
|
raise RuntimeError(f"Node not found: Id={node_id}")
|
||||||
|
|
||||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[str]]:
|
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
|
||||||
"""Find upstream branch nodes for a given target node in the workflow graph.
|
"""
|
||||||
|
Recursively find all upstream branch (control) nodes that influence the execution
|
||||||
|
of the given target node.
|
||||||
|
|
||||||
This method identifies all upstream control (branch) nodes that can affect
|
This method walks upstream along the workflow graph starting from `target_node`.
|
||||||
the execution of `target_node`. If `target_node` is reachable from a start
|
It distinguishes between:
|
||||||
node (i.e., a node with no upstream nodes), the method returns an empty tuple.
|
- branch nodes (node types listed in `BRANCH_NODES`)
|
||||||
|
- non-branch nodes (ordinary processing nodes)
|
||||||
|
|
||||||
The function distinguishes between branch nodes (defined in `BRANCH_NODES`)
|
Traversal rules:
|
||||||
and non-branch nodes, recursively traversing upstream through non-branch
|
1. For each immediate upstream node:
|
||||||
nodes. If any non-branch upstream path does not lead to a branch node,
|
- If it is a branch node, it is recorded as an affecting control node.
|
||||||
the result will indicate that no valid upstream branch node exists.
|
- If it is a non-branch node, the traversal continues recursively upstream.
|
||||||
|
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
|
||||||
|
a branch node, the traversal is considered invalid:
|
||||||
|
- `has_branch` will be False
|
||||||
|
- no branch nodes are returned.
|
||||||
|
3. Only when ALL upstream non-branch paths eventually lead to at least one
|
||||||
|
branch node will `has_branch` be True.
|
||||||
|
|
||||||
|
Special case:
|
||||||
|
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
|
||||||
|
it is considered directly reachable from the workflow entry, and therefore
|
||||||
|
has no controlling branch nodes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_node (str): The identifier of the target node.
|
target_node (str):
|
||||||
|
The identifier of the node whose upstream control branches
|
||||||
|
are to be resolved.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[bool, tuple[str]]:
|
tuple[bool, tuple[tuple[str, str]]]:
|
||||||
- has_branch (bool): True if all upstream non-branch paths lead to at least
|
- has_branch (bool):
|
||||||
one branch node; False if any path reaches a start node without a branch.
|
True if every upstream path from `target_node` encounters
|
||||||
- branch_nodes (tuple[str]): A deduplicated tuple of upstream branch node IDs
|
at least one branch node.
|
||||||
affecting `target_node`. Returns an empty tuple if `has_branch` is False.
|
False if any path reaches a start node without a branch.
|
||||||
|
- branch_nodes (tuple[tuple[str, str]]):
|
||||||
|
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
|
||||||
|
representing all branch nodes that can influence `target_node`.
|
||||||
|
Returns an empty tuple if `has_branch` is False.
|
||||||
"""
|
"""
|
||||||
source_nodes = [
|
source_nodes = [
|
||||||
edge.get("source")
|
{
|
||||||
|
"id": edge.get("source"),
|
||||||
|
"branch": edge.get("label")
|
||||||
|
}
|
||||||
for edge in self.edges
|
for edge in self.edges
|
||||||
if edge.get("target") == target_node
|
if edge.get("target") == target_node
|
||||||
]
|
]
|
||||||
@@ -249,11 +268,13 @@ class GraphBuilder:
|
|||||||
branch_nodes = []
|
branch_nodes = []
|
||||||
non_branch_nodes = []
|
non_branch_nodes = []
|
||||||
|
|
||||||
for node_id in source_nodes:
|
for node_info in source_nodes:
|
||||||
if self.get_node_type(node_id) in BRANCH_NODES:
|
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
||||||
branch_nodes.append(node_id)
|
branch_nodes.append(
|
||||||
|
(node_info["id"], node_info["branch"])
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
non_branch_nodes.append(node_id)
|
non_branch_nodes.append(node_info["id"])
|
||||||
|
|
||||||
has_branch = True
|
has_branch = True
|
||||||
for node_id in non_branch_nodes:
|
for node_id in non_branch_nodes:
|
||||||
@@ -334,7 +355,7 @@ class GraphBuilder:
|
|||||||
activate=not has_branch,
|
activate=not has_branch,
|
||||||
|
|
||||||
# Branch nodes that control activation of this End node
|
# Branch nodes that control activation of this End node
|
||||||
control_nodes=list(control_nodes),
|
control_nodes=dict(control_nodes),
|
||||||
|
|
||||||
# Convert output segments into OutputContent objects
|
# Convert output segments into OutputContent objects
|
||||||
outputs=list(
|
outputs=list(
|
||||||
@@ -362,7 +383,7 @@ class GraphBuilder:
|
|||||||
else:
|
else:
|
||||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
activate=True,
|
activate=True,
|
||||||
control_nodes=[],
|
control_nodes={},
|
||||||
outputs=list(
|
outputs=list(
|
||||||
[
|
[
|
||||||
OutputContent(
|
OutputContent(
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class CodeNodeConfig(BaseNodeConfig):
|
|||||||
description="code content"
|
description="code content"
|
||||||
)
|
)
|
||||||
|
|
||||||
language: Literal['python3', 'nodejs'] = Field(
|
language: Literal['python3', 'javascript'] = Field(
|
||||||
...,
|
...,
|
||||||
description="language"
|
description="language"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import urllib.parse
|
||||||
from string import Template
|
from string import Template
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -14,7 +15,7 @@ from app.core.workflow.nodes.code.config import CodeNodeConfig
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SCRIPT_TEMPLATE = Template(dedent("""
|
PYTHON_SCRIPT_TEMPLATE = Template(dedent("""
|
||||||
$code
|
$code
|
||||||
|
|
||||||
import json
|
import json
|
||||||
@@ -32,6 +33,20 @@ result = "<<RESULT>>" + output_json + "<<RESULT>>"
|
|||||||
print(result)
|
print(result)
|
||||||
"""))
|
"""))
|
||||||
|
|
||||||
|
NODEJS_SCRIPT_TEMPLATE = Template(dedent("""
|
||||||
|
$code
|
||||||
|
// decode and prepare input object
|
||||||
|
var inputs_obj = JSON.parse(Buffer.from('$inputs_variable', 'base64').toString('utf-8'))
|
||||||
|
|
||||||
|
// execute main function
|
||||||
|
var output_obj = main(inputs_obj)
|
||||||
|
|
||||||
|
// convert output to json and print
|
||||||
|
var output_json = JSON.stringify(output_obj)
|
||||||
|
var result = `<<RESULT>>$${output_json}<<RESULT>>`
|
||||||
|
console.log(result)
|
||||||
|
"""))
|
||||||
|
|
||||||
|
|
||||||
class CodeNode(BaseNode):
|
class CodeNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
@@ -83,18 +98,27 @@ class CodeNode(BaseNode):
|
|||||||
input_variable_dict = {}
|
input_variable_dict = {}
|
||||||
for input_variable in self.typed_config.input_variables:
|
for input_variable in self.typed_config.input_variables:
|
||||||
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
|
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
|
||||||
|
|
||||||
code = base64.b64decode(
|
code = base64.b64decode(
|
||||||
self.typed_config.code
|
self.typed_config.code
|
||||||
).decode("utf-8")
|
).decode("utf-8")
|
||||||
|
code = urllib.parse.unquote(code, encoding='utf-8')
|
||||||
|
|
||||||
input_variable_dict = base64.b64encode(
|
input_variable_dict = base64.b64encode(
|
||||||
json.dumps(input_variable_dict).encode("utf-8")
|
json.dumps(input_variable_dict).encode("utf-8")
|
||||||
).decode("utf-8")
|
).decode("utf-8")
|
||||||
|
if self.typed_config.language == "python3":
|
||||||
final_script = SCRIPT_TEMPLATE.substitute(
|
final_script = PYTHON_SCRIPT_TEMPLATE.substitute(
|
||||||
code=code,
|
code=code,
|
||||||
inputs_variable=input_variable_dict,
|
inputs_variable=input_variable_dict,
|
||||||
)
|
)
|
||||||
|
elif self.typed_config.language == 'javascript':
|
||||||
|
final_script = NODEJS_SCRIPT_TEMPLATE.substitute(
|
||||||
|
code=code,
|
||||||
|
inputs_variable=input_variable_dict,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported language: {self.typed_config.language}")
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
|
|||||||
@@ -25,6 +25,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
|
|||||||
...
|
...
|
||||||
)
|
)
|
||||||
|
|
||||||
config_id: UUID = Field(
|
config_id: UUID | int = Field(
|
||||||
...
|
...
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,9 +36,10 @@ class MemoryReadNode(BaseNode):
|
|||||||
class MemoryWriteNode(BaseNode):
|
class MemoryWriteNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
self.typed_config: MemoryWriteNodeConfig | None = None
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> Any:
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
|
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||||
end_user_id = self.get_variable("sys.user_id", state)
|
end_user_id = self.get_variable("sys.user_id", state)
|
||||||
|
|
||||||
if not end_user_id:
|
if not end_user_id:
|
||||||
|
|||||||
@@ -23,6 +23,18 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||||
|
self.response_metadata = {}
|
||||||
|
|
||||||
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
|
if self.response_metadata:
|
||||||
|
usage = self.response_metadata.get('token_usage')
|
||||||
|
if usage:
|
||||||
|
return {
|
||||||
|
"prompt_tokens": usage.get('prompt_tokens', 0),
|
||||||
|
"completion_tokens": usage.get('completion_tokens', 0),
|
||||||
|
"total_tokens": usage.get('total_tokens', 0)
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_prompt():
|
def _get_prompt():
|
||||||
@@ -171,6 +183,7 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
])
|
])
|
||||||
|
|
||||||
model_resp = await llm.ainvoke(messages)
|
model_resp = await llm.ainvoke(messages)
|
||||||
|
self.response_metadata = model_resp.response_metadata
|
||||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||||
logger.info(f"node: {self.node_id} get params:{result}")
|
logger.info(f"node: {self.node_id} get params:{result}")
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,18 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||||
self.category_to_case_map = {}
|
self.category_to_case_map = {}
|
||||||
|
self.response_metadata = {}
|
||||||
|
|
||||||
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
|
if self.response_metadata:
|
||||||
|
usage = self.response_metadata.get('token_usage')
|
||||||
|
if usage:
|
||||||
|
return {
|
||||||
|
"prompt_tokens": usage.get('prompt_tokens', 0),
|
||||||
|
"completion_tokens": usage.get('completion_tokens', 0),
|
||||||
|
"total_tokens": usage.get('total_tokens', 0)
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
def _get_llm_instance(self) -> RedBearLLM:
|
def _get_llm_instance(self) -> RedBearLLM:
|
||||||
"""获取LLM实例"""
|
"""获取LLM实例"""
|
||||||
@@ -112,6 +124,7 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
|
|
||||||
response = await llm.ainvoke(messages)
|
response = await llm.ainvoke(messages)
|
||||||
result = response.content.strip()
|
result = response.content.strip()
|
||||||
|
self.response_metadata = response.response_metadata
|
||||||
|
|
||||||
if result in category_names:
|
if result in category_names:
|
||||||
category = result
|
category = result
|
||||||
|
|||||||
@@ -4,16 +4,19 @@
|
|||||||
从文件系统加载预定义的工作流模板
|
从文件系统加载预定义的工作流模板
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates')
|
||||||
|
|
||||||
|
|
||||||
class TemplateLoader:
|
class TemplateLoader:
|
||||||
"""工作流模板加载器"""
|
"""工作流模板加载器"""
|
||||||
|
|
||||||
def __init__(self, templates_dir: str = "app/templates/workflows"):
|
def __init__(self, templates_dir: str = TEMPLATE_DIR):
|
||||||
"""初始化模板加载器
|
"""初始化模板加载器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -22,7 +25,7 @@ class TemplateLoader:
|
|||||||
self.templates_dir = Path(templates_dir)
|
self.templates_dir = Path(templates_dir)
|
||||||
if not self.templates_dir.exists():
|
if not self.templates_dir.exists():
|
||||||
raise ValueError(f"模板目录不存在: {templates_dir}")
|
raise ValueError(f"模板目录不存在: {templates_dir}")
|
||||||
|
|
||||||
def list_templates(self) -> list[dict]:
|
def list_templates(self) -> list[dict]:
|
||||||
"""列出所有可用的模板
|
"""列出所有可用的模板
|
||||||
|
|
||||||
@@ -30,22 +33,22 @@ class TemplateLoader:
|
|||||||
模板列表,每个模板包含 id, name, description 等信息
|
模板列表,每个模板包含 id, name, description 等信息
|
||||||
"""
|
"""
|
||||||
templates = []
|
templates = []
|
||||||
|
|
||||||
# 遍历模板目录
|
# 遍历模板目录
|
||||||
for template_dir in self.templates_dir.iterdir():
|
for template_dir in self.templates_dir.iterdir():
|
||||||
if not template_dir.is_dir():
|
if not template_dir.is_dir():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查是否有 template.yml 文件
|
# 检查是否有 template.yml 文件
|
||||||
template_file = template_dir / "template.yml"
|
template_file = template_dir / "template.yml"
|
||||||
if not template_file.exists():
|
if not template_file.exists():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 读取模板配置
|
# 读取模板配置
|
||||||
with open(template_file, 'r', encoding='utf-8') as f:
|
with open(template_file, 'r', encoding='utf-8') as f:
|
||||||
template_data = yaml.safe_load(f)
|
template_data = yaml.safe_load(f)
|
||||||
|
|
||||||
# 提取模板信息
|
# 提取模板信息
|
||||||
templates.append({
|
templates.append({
|
||||||
"id": template_dir.name,
|
"id": template_dir.name,
|
||||||
@@ -59,9 +62,9 @@ class TemplateLoader:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"加载模板 {template_dir.name} 失败: {e}")
|
print(f"加载模板 {template_dir.name} 失败: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return templates
|
return templates
|
||||||
|
|
||||||
def load_template(self, template_id: str) -> Optional[dict]:
|
def load_template(self, template_id: str) -> Optional[dict]:
|
||||||
"""加载指定的模板
|
"""加载指定的模板
|
||||||
|
|
||||||
@@ -73,14 +76,14 @@ class TemplateLoader:
|
|||||||
"""
|
"""
|
||||||
template_dir = self.templates_dir / template_id
|
template_dir = self.templates_dir / template_id
|
||||||
template_file = template_dir / "template.yml"
|
template_file = template_dir / "template.yml"
|
||||||
|
|
||||||
if not template_file.exists():
|
if not template_file.exists():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(template_file, 'r', encoding='utf-8') as f:
|
with open(template_file, 'r', encoding='utf-8') as f:
|
||||||
template_data = yaml.safe_load(f)
|
template_data = yaml.safe_load(f)
|
||||||
|
|
||||||
# 返回工作流配置部分
|
# 返回工作流配置部分
|
||||||
return {
|
return {
|
||||||
"name": template_data.get("name", template_id),
|
"name": template_data.get("name", template_id),
|
||||||
@@ -94,7 +97,7 @@ class TemplateLoader:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"加载模板 {template_id} 失败: {e}")
|
print(f"加载模板 {template_id} 失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_template_readme(self, template_id: str) -> Optional[str]:
|
def get_template_readme(self, template_id: str) -> Optional[str]:
|
||||||
"""获取模板的 README 文档
|
"""获取模板的 README 文档
|
||||||
|
|
||||||
@@ -106,10 +109,10 @@ class TemplateLoader:
|
|||||||
"""
|
"""
|
||||||
template_dir = self.templates_dir / template_id
|
template_dir = self.templates_dir / template_id
|
||||||
readme_file = template_dir / "README.md"
|
readme_file = template_dir / "README.md"
|
||||||
|
|
||||||
if not readme_file.exists():
|
if not readme_file.exists():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(readme_file, 'r', encoding='utf-8') as f:
|
with open(readme_file, 'r', encoding='utf-8') as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ from app.core.error_codes import BizCode, HTTP_MAPPING
|
|||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import LoggingConfig, get_logger
|
from app.core.logging_config import LoggingConfig, get_logger
|
||||||
from app.core.response_utils import fail
|
from app.core.response_utils import fail
|
||||||
|
from app.core.models.scripts.loader import load_models
|
||||||
|
from app.db import get_db_context
|
||||||
|
|
||||||
# Initialize logging system
|
# Initialize logging system
|
||||||
LoggingConfig.setup_logging()
|
LoggingConfig.setup_logging()
|
||||||
@@ -47,6 +49,15 @@ async def lifespan(app: FastAPI):
|
|||||||
else:
|
else:
|
||||||
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
|
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
|
||||||
|
|
||||||
|
# 加载预定义模型
|
||||||
|
logger.info("开始加载预定义模型...")
|
||||||
|
try:
|
||||||
|
with get_db_context() as db:
|
||||||
|
result = load_models(db, silent=True)
|
||||||
|
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
||||||
|
|
||||||
logger.info("应用程序启动完成")
|
logger.info("应用程序启动完成")
|
||||||
yield
|
yield
|
||||||
# 应用关闭事件
|
# 应用关闭事件
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ from .tool_model import (
|
|||||||
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
|
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
|
||||||
)
|
)
|
||||||
from .memory_perceptual_model import MemoryPerceptualModel
|
from .memory_perceptual_model import MemoryPerceptualModel
|
||||||
|
from .ontology_scene import OntologyScene
|
||||||
|
from .ontology_class import OntologyClass
|
||||||
|
from .ontology_scene import OntologyScene
|
||||||
|
from .ontology_class import OntologyClass
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Tenants",
|
"Tenants",
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ class MemoryConfig(Base):
|
|||||||
end_user_id = Column(String, nullable=True, comment="组ID")
|
end_user_id = Column(String, nullable=True, comment="组ID")
|
||||||
user_id = Column(String, nullable=True, comment="用户ID")
|
user_id = Column(String, nullable=True, comment="用户ID")
|
||||||
apply_id = Column(String, nullable=True, comment="应用ID")
|
apply_id = Column(String, nullable=True, comment="应用ID")
|
||||||
|
|
||||||
|
# 本体场景关联
|
||||||
|
scene_id = Column(UUID(as_uuid=True), nullable=True, comment="本体场景ID,关联ontology_scene表")
|
||||||
|
|
||||||
# 模型选择(从workspace继承)
|
# 模型选择(从workspace继承)
|
||||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from enum import StrEnum
|
|||||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table
|
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
from sqlalchemy.sql import func
|
||||||
from app.db import Base
|
from app.db import Base
|
||||||
|
|
||||||
|
|
||||||
@@ -23,6 +24,8 @@ class ModelType(StrEnum):
|
|||||||
CHAT = "chat"
|
CHAT = "chat"
|
||||||
EMBEDDING = "embedding"
|
EMBEDDING = "embedding"
|
||||||
RERANK = "rerank"
|
RERANK = "rerank"
|
||||||
|
# TTS = "tts"
|
||||||
|
# SPEECH2TEXT = "speech2text"
|
||||||
# IMAGE = "image"
|
# IMAGE = "image"
|
||||||
# AUDIO = "audio"
|
# AUDIO = "audio"
|
||||||
# VISION = "vision"
|
# VISION = "vision"
|
||||||
@@ -48,8 +51,7 @@ class ModelProvider(StrEnum):
|
|||||||
class LoadBalanceStrategy(StrEnum):
|
class LoadBalanceStrategy(StrEnum):
|
||||||
"""API Key负载均衡策略枚举"""
|
"""API Key负载均衡策略枚举"""
|
||||||
ROUND_ROBIN = "round_robin" # 轮询
|
ROUND_ROBIN = "round_robin" # 轮询
|
||||||
WEIGHTED_ROUND_ROBIN = "weighted_round_robin" # 加权轮询
|
NONE = "none" # 无
|
||||||
RANDOM = "random" # 随机
|
|
||||||
|
|
||||||
|
|
||||||
# 多对多关联表
|
# 多对多关联表
|
||||||
@@ -90,7 +92,8 @@ class ModelConfig(BaseModel):
|
|||||||
|
|
||||||
# 状态管理
|
# 状态管理
|
||||||
is_public = Column(Boolean, default=False, nullable=False, comment="是否公开")
|
is_public = Column(Boolean, default=False, nullable=False, comment="是否公开")
|
||||||
load_balance_strategy = Column(String, nullable=True, comment="负载均衡策略")
|
load_balance_strategy = Column(String, nullable=True, comment="负载均衡策略", default=LoadBalanceStrategy.NONE,
|
||||||
|
server_default=LoadBalanceStrategy.NONE)
|
||||||
|
|
||||||
# 关联关系
|
# 关联关系
|
||||||
model_base = relationship("ModelBase", back_populates="configs")
|
model_base = relationship("ModelBase", back_populates="configs")
|
||||||
@@ -151,6 +154,7 @@ class ModelBase(Base):
|
|||||||
is_official = Column(Boolean, default=True, comment="是否供应商官方模型(区分自定义)")
|
is_official = Column(Boolean, default=True, comment="是否供应商官方模型(区分自定义)")
|
||||||
tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作'])")
|
tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作'])")
|
||||||
add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数")
|
add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数")
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间", server_default=func.now())
|
||||||
|
|
||||||
# 关联关系
|
# 关联关系
|
||||||
configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan")
|
configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan")
|
||||||
|
|||||||
40
api/app/models/ontology_class.py
Normal file
40
api/app/models/ontology_class.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""本体类型模型
|
||||||
|
|
||||||
|
本模块定义本体类型的数据模型。
|
||||||
|
|
||||||
|
Classes:
|
||||||
|
OntologyClass: 本体类型表模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
from sqlalchemy import Column, String, DateTime, Text, ForeignKey
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
from app.db import Base
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyClass(Base):
|
||||||
|
"""本体类型表 - 用于存储某个场景提取出来的本体类型信息"""
|
||||||
|
__tablename__ = "ontology_class"
|
||||||
|
|
||||||
|
# 主键
|
||||||
|
class_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="类型ID")
|
||||||
|
|
||||||
|
# 类型信息
|
||||||
|
class_name = Column(String(200), nullable=False, comment="类型名称")
|
||||||
|
class_description = Column(Text, nullable=True, comment="类型描述")
|
||||||
|
|
||||||
|
# 外键:关联到本体场景
|
||||||
|
scene_id = Column(UUID(as_uuid=True), ForeignKey("ontology_scene.scene_id", ondelete="CASCADE"), nullable=False, index=True, comment="所属场景ID")
|
||||||
|
|
||||||
|
# 时间戳
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间")
|
||||||
|
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间")
|
||||||
|
|
||||||
|
# 关系:类型属于某个场景
|
||||||
|
scene = relationship("OntologyScene", back_populates="classes")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<OntologyClass(id={self.class_id}, name={self.class_name}, scene_id={self.scene_id})>"
|
||||||
43
api/app/models/ontology_scene.py
Normal file
43
api/app/models/ontology_scene.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""本体场景模型
|
||||||
|
|
||||||
|
本模块定义本体场景的数据模型。
|
||||||
|
|
||||||
|
Classes:
|
||||||
|
OntologyScene: 本体场景表模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
from app.db import Base
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyScene(Base):
|
||||||
|
"""本体场景表 - 用于存储本体场景下不同的类型信息"""
|
||||||
|
__tablename__ = "ontology_scene"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint('workspace_id', 'scene_name', name='uq_workspace_scene_name'),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 主键
|
||||||
|
scene_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="场景ID")
|
||||||
|
|
||||||
|
# 场景信息
|
||||||
|
scene_name = Column(String(200), nullable=False, comment="场景名称")
|
||||||
|
scene_description = Column(Text, nullable=True, comment="场景描述")
|
||||||
|
|
||||||
|
# 外键:关联到工作空间
|
||||||
|
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True, comment="所属工作空间ID")
|
||||||
|
|
||||||
|
# 时间戳
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间")
|
||||||
|
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间")
|
||||||
|
|
||||||
|
# 关系:一个场景可以有多个类型
|
||||||
|
classes = relationship("OntologyClass", back_populates="scene", cascade="all, delete-orphan")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<OntologyScene(id={self.scene_id}, name={self.scene_name})>"
|
||||||
@@ -2,7 +2,7 @@ import datetime
|
|||||||
import uuid
|
import uuid
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index
|
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index, Boolean
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
from app.db import Base
|
from app.db import Base
|
||||||
@@ -121,10 +121,33 @@ class PromptOptimizerSessionHistory(Base):
|
|||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
|
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
|
||||||
# app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
|
# app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
|
||||||
session_id = Column(UUID(as_uuid=True), ForeignKey("prompt_opt_session_list.id"),nullable=False, comment="Session ID")
|
session_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("prompt_opt_session_list.id"),
|
||||||
|
nullable=False,
|
||||||
|
comment="Session ID"
|
||||||
|
)
|
||||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID")
|
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID")
|
||||||
role = Column(String, nullable=False, comment="Message Role")
|
role = Column(String, nullable=False, comment="Message Role")
|
||||||
content = Column(Text, nullable=False, comment="Message Content")
|
content = Column(Text, nullable=False, comment="Message Content")
|
||||||
# prompt = Column(Text, nullable=False, comment="Prompt")
|
# prompt = Column(Text, nullable=False, comment="Prompt")
|
||||||
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
|
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptHistory(Base):
    """ORM model for the ``prompt_history`` table.

    Each row stores a titled prompt together with the tenant and the
    ``prompt_opt_session_list`` row it references.
    """

    __tablename__ = "prompt_history"

    # Primary key: UUID generated client-side by uuid.uuid4.
    id = Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        index=True,
    )

    # Owning tenant (required).
    tenant_id = Column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id"),
        nullable=False,
        comment="Tenant ID",
    )

    # Referenced session row (required).
    session_id = Column(
        UUID(as_uuid=True),
        ForeignKey("prompt_opt_session_list.id"),
        nullable=False,
        comment="Session ID",
    )

    # Prompt payload.
    title = Column(String, nullable=False, comment="Title")
    prompt = Column(Text, nullable=False, comment="Prompt")

    # Creation time, indexed; defaults to datetime.datetime.now (naive local time).
    created_at = Column(
        DateTime,
        default=datetime.datetime.now,
        comment="Creation Time",
        index=True,
    )

    # Boolean flag defaulting to False.
    # NOTE(review): presumably a soft-delete marker - confirm against callers.
    is_delete = Column(Boolean, default=False, comment="Delete")
|
||||||
|
|||||||
@@ -24,12 +24,16 @@ from app.schemas.memory_storage_schema import (
|
|||||||
from sqlalchemy import desc, select
|
from sqlalchemy import desc, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
# 获取数据库专用日志器
|
# 获取数据库专用日志器
|
||||||
db_logger = get_db_logger()
|
db_logger = get_db_logger()
|
||||||
# 获取配置专用日志器
|
# 获取配置专用日志器
|
||||||
config_logger = get_config_logger()
|
config_logger = get_config_logger()
|
||||||
|
|
||||||
TABLE_NAME = "memory_config"
|
TABLE_NAME = "memory_config"
|
||||||
|
|
||||||
|
|
||||||
class MemoryConfigRepository:
|
class MemoryConfigRepository:
|
||||||
"""记忆配置Repository
|
"""记忆配置Repository
|
||||||
|
|
||||||
@@ -82,7 +86,8 @@ class MemoryConfigRepository:
|
|||||||
n.description AS description,
|
n.description AS description,
|
||||||
n.entity_type AS entity_type,
|
n.entity_type AS entity_type,
|
||||||
n.name AS name,
|
n.name AS name,
|
||||||
COALESCE(n.fact_summary, '') AS fact_summary,
|
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
// COALESCE(n.fact_summary, '') AS fact_summary,
|
||||||
n.end_user_id AS end_user_id,
|
n.end_user_id AS end_user_id,
|
||||||
n.apply_id AS apply_id,
|
n.apply_id AS apply_id,
|
||||||
n.user_id AS user_id,
|
n.user_id AS user_id,
|
||||||
@@ -152,7 +157,7 @@ class MemoryConfigRepository:
|
|||||||
return memory_config_obj
|
return memory_config_obj
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID) -> MemoryConfig:
|
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
|
||||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -187,7 +192,6 @@ class MemoryConfigRepository:
|
|||||||
raise RuntimeError("reflection config not found")
|
raise RuntimeError("reflection config not found")
|
||||||
return memory_config
|
return memory_config
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
|
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
|
||||||
"""构建查询所有配置的语句(SQLAlchemy text() 命名参数)
|
"""构建查询所有配置的语句(SQLAlchemy text() 命名参数)
|
||||||
@@ -227,9 +231,12 @@ class MemoryConfigRepository:
|
|||||||
config_name=params.config_name,
|
config_name=params.config_name,
|
||||||
config_desc=params.config_desc,
|
config_desc=params.config_desc,
|
||||||
workspace_id=params.workspace_id,
|
workspace_id=params.workspace_id,
|
||||||
|
scene_id=params.scene_id,
|
||||||
llm_id=params.llm_id,
|
llm_id=params.llm_id,
|
||||||
embedding_id=params.embedding_id,
|
embedding_id=params.embedding_id,
|
||||||
rerank_id=params.rerank_id,
|
rerank_id=params.rerank_id,
|
||||||
|
reflection_model_id=params.reflection_model_id,
|
||||||
|
emotion_model_id=params.emotion_model_id,
|
||||||
)
|
)
|
||||||
db.add(db_config)
|
db.add(db_config)
|
||||||
db.flush() # 获取自增ID但不提交事务
|
db.flush() # 获取自增ID但不提交事务
|
||||||
@@ -272,6 +279,9 @@ class MemoryConfigRepository:
|
|||||||
if update.config_desc is not None:
|
if update.config_desc is not None:
|
||||||
db_config.config_desc = update.config_desc
|
db_config.config_desc = update.config_desc
|
||||||
has_update = True
|
has_update = True
|
||||||
|
if update.scene_id is not None:
|
||||||
|
db_config.scene_id = update.scene_id
|
||||||
|
has_update = True
|
||||||
|
|
||||||
if not has_update:
|
if not has_update:
|
||||||
raise ValueError("No fields to update")
|
raise ValueError("No fields to update")
|
||||||
@@ -287,7 +297,6 @@ class MemoryConfigRepository:
|
|||||||
db_logger.error(f"更新记忆配置失败: config_id={update.config_id} - {str(e)}")
|
db_logger.error(f"更新记忆配置失败: config_id={update.config_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[MemoryConfig]:
|
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[MemoryConfig]:
|
||||||
"""更新记忆萃取引擎配置
|
"""更新记忆萃取引擎配置
|
||||||
@@ -410,7 +419,7 @@ class MemoryConfigRepository:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_extracted_config(db: Session, config_id: UUID) -> Optional[Dict]:
|
def get_extracted_config(db: Session, config_id: UUID | int) -> Optional[Dict]:
|
||||||
"""获取萃取配置,通过主键查询某条配置
|
"""获取萃取配置,通过主键查询某条配置
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -420,8 +429,8 @@ class MemoryConfigRepository:
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[Dict]: 萃取配置字典,不存在则返回None
|
Optional[Dict]: 萃取配置字典,不存在则返回None
|
||||||
"""
|
"""
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
db_logger.debug(f"查询萃取配置: config_id={config_id}")
|
db_logger.debug(f"查询萃取配置: config_id={config_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
|
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
|
||||||
if not db_config:
|
if not db_config:
|
||||||
@@ -514,26 +523,27 @@ class MemoryConfigRepository:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}")
|
db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_config_with_workspace(db: Session, config_id: uuid.UUID) -> Optional[tuple]:
|
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]:
|
||||||
"""Get memory config and its associated workspace information
|
"""Get memory config and its associated workspace information
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session
|
db: Database session
|
||||||
config_id: Configuration ID
|
config_id: Configuration ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[tuple]: (MemoryConfig, Workspace) tuple, None if not found
|
Optional[tuple]: (MemoryConfig, Workspace) tuple, None if not found
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: Raised when config exists but workspace doesn't
|
ValueError: Raised when config exists but workspace doesn't
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from app.models.workspace_model import Workspace
|
from app.models.workspace_model import Workspace
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
|
|
||||||
# Log configuration loading start
|
# Log configuration loading start
|
||||||
config_logger.info(
|
config_logger.info(
|
||||||
"Loading configuration with workspace",
|
"Loading configuration with workspace",
|
||||||
@@ -542,17 +552,16 @@ class MemoryConfigRepository:
|
|||||||
"config_id": config_id
|
"config_id": config_id
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
db_logger.debug(f"Querying memory config and workspace: config_id={config_id}")
|
db_logger.debug(f"Querying memory config and workspace: config_id={config_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use join query to get both config and workspace
|
# Use join query to get both config and workspace
|
||||||
result = db.query(MemoryConfig, Workspace).join(
|
result = db.query(MemoryConfig, Workspace).join(
|
||||||
Workspace, MemoryConfig.workspace_id == Workspace.id
|
Workspace, MemoryConfig.workspace_id == Workspace.id
|
||||||
).filter(MemoryConfig.config_id == config_id).first()
|
).filter(MemoryConfig.config_id == config_id).first()
|
||||||
|
|
||||||
elapsed_ms = (time.time() - start_time) * 1000
|
elapsed_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
# Check if config exists but workspace is missing
|
# Check if config exists but workspace is missing
|
||||||
config_only = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
|
config_only = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
|
||||||
@@ -581,9 +590,11 @@ class MemoryConfigRepository:
|
|||||||
"elapsed_ms": elapsed_ms
|
"elapsed_ms": elapsed_ms
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db_logger.error(f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}")
|
db_logger.error(
|
||||||
raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
|
f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}")
|
||||||
|
raise ValueError(
|
||||||
|
f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
|
||||||
|
|
||||||
config_logger.debug(
|
config_logger.debug(
|
||||||
"Configuration not found",
|
"Configuration not found",
|
||||||
extra={
|
extra={
|
||||||
@@ -595,9 +606,9 @@ class MemoryConfigRepository:
|
|||||||
)
|
)
|
||||||
db_logger.debug(f"Memory config not found: config_id={config_id}")
|
db_logger.debug(f"Memory config not found: config_id={config_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
config, workspace = result
|
config, workspace = result
|
||||||
|
|
||||||
# Log successful configuration loading
|
# Log successful configuration loading
|
||||||
config_logger.info(
|
config_logger.info(
|
||||||
"Configuration with workspace loaded successfully",
|
"Configuration with workspace loaded successfully",
|
||||||
@@ -612,16 +623,17 @@ class MemoryConfigRepository:
|
|||||||
"elapsed_ms": elapsed_ms
|
"elapsed_ms": elapsed_ms
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
db_logger.debug(f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
db_logger.debug(
|
||||||
|
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
||||||
return (config, workspace)
|
return (config, workspace)
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# Re-raise known business exceptions
|
# Re-raise known business exceptions
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed_ms = (time.time() - start_time) * 1000
|
elapsed_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
config_logger.error(
|
config_logger.error(
|
||||||
"Failed to load configuration with workspace",
|
"Failed to load configuration with workspace",
|
||||||
extra={
|
extra={
|
||||||
@@ -634,32 +646,36 @@ class MemoryConfigRepository:
|
|||||||
},
|
},
|
||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
|
|
||||||
db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}")
|
db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]:
|
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[Tuple[MemoryConfig, Optional[str]]]:
|
||||||
"""获取所有配置参数
|
"""获取所有配置参数,包含关联的场景名称
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
workspace_id: 工作空间ID,用于过滤查询结果
|
workspace_id: 工作空间ID,用于过滤查询结果
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[MemoryConfig]: 配置列表
|
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
|
||||||
"""
|
"""
|
||||||
|
from app.models.ontology_scene import OntologyScene
|
||||||
|
|
||||||
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query = db.query(MemoryConfig)
|
query = db.query(MemoryConfig, OntologyScene.scene_name).outerjoin(
|
||||||
|
OntologyScene, MemoryConfig.scene_id == OntologyScene.scene_id
|
||||||
|
)
|
||||||
|
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
query = query.filter(MemoryConfig.workspace_id == workspace_id)
|
query = query.filter(MemoryConfig.workspace_id == workspace_id)
|
||||||
|
|
||||||
configs = query.order_by(desc(MemoryConfig.updated_at)).all()
|
results = query.order_by(desc(MemoryConfig.updated_at)).all()
|
||||||
|
|
||||||
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
|
db_logger.debug(f"配置列表查询成功: 数量={len(results)}")
|
||||||
return configs
|
return results
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")
|
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ class ModelConfigRepository:
|
|||||||
total = base_query.count()
|
total = base_query.count()
|
||||||
|
|
||||||
# 分页查询
|
# 分页查询
|
||||||
models = base_query.order_by(desc(ModelConfig.updated_at)).offset(
|
models = base_query.order_by(desc(ModelConfig.created_at)).offset(
|
||||||
(query.page - 1) * query.pagesize
|
(query.page - 1) * query.pagesize
|
||||||
).limit(query.pagesize).all()
|
).limit(query.pagesize).all()
|
||||||
|
|
||||||
@@ -234,7 +234,7 @@ class ModelConfigRepository:
|
|||||||
# 获取总数
|
# 获取总数
|
||||||
total = base_query.count()
|
total = base_query.count()
|
||||||
|
|
||||||
query_results = base_query.order_by(desc(ModelConfig.updated_at)).all()
|
query_results = base_query.order_by(desc(ModelConfig.created_at)).all()
|
||||||
|
|
||||||
provider_groups: Dict[str, List[ModelConfig]] = {}
|
provider_groups: Dict[str, List[ModelConfig]] = {}
|
||||||
for model_config in query_results:
|
for model_config in query_results:
|
||||||
@@ -433,6 +433,7 @@ class ModelConfigRepository:
|
|||||||
ModelConfig.is_public
|
ModelConfig.is_public
|
||||||
),
|
),
|
||||||
ModelBase.provider == provider,
|
ModelBase.provider == provider,
|
||||||
|
ModelConfig.is_active,
|
||||||
~ModelConfig.is_composite
|
~ModelConfig.is_composite
|
||||||
)
|
)
|
||||||
).distinct().all()
|
).distinct().all()
|
||||||
@@ -621,7 +622,7 @@ class ModelBaseRepository:
|
|||||||
if filters:
|
if filters:
|
||||||
q = q.filter(and_(*filters))
|
q = q.filter(and_(*filters))
|
||||||
|
|
||||||
return q.order_by(ModelBase.add_count.desc()).all()
|
return q.order_by(ModelBase.add_count.desc(), ModelBase.created_at.desc()).all()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(db: Session, data: dict) -> 'ModelBase':
|
def create(db: Session, data: dict) -> 'ModelBase':
|
||||||
@@ -629,6 +630,13 @@ class ModelBaseRepository:
|
|||||||
db.add(model_base)
|
db.add(model_base)
|
||||||
return model_base
|
return model_base
|
||||||
|
|
||||||
|
@staticmethod
def get_by_name_and_provider(db: Session, name: str, provider: str) -> Optional['ModelBase']:
    """Return the first ModelBase row matching both name and provider.

    Args:
        db: Database session used to run the query.
        name: Value compared against ``ModelBase.name``.
        provider: Value compared against ``ModelBase.provider``.

    Returns:
        The first matching ``ModelBase``, or ``None`` when no row matches.
    """
    lookup = db.query(ModelBase)
    # Both conditions are passed to a single filter() call, so they are ANDed.
    lookup = lookup.filter(ModelBase.name == name, ModelBase.provider == provider)
    return lookup.first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update(db: Session, model_base_id: uuid.UUID, data: dict) -> Optional['ModelBase']:
|
def update(db: Session, model_base_id: uuid.UUID, data: dict) -> Optional['ModelBase']:
|
||||||
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
|
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
|
||||||
@@ -636,6 +644,17 @@ class ModelBaseRepository:
|
|||||||
return None
|
return None
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
setattr(model_base, key, value)
|
setattr(model_base, key, value)
|
||||||
|
|
||||||
|
# 同步更新绑定的非组合模型配置
|
||||||
|
if any(k in data for k in ['name', 'description', 'logo']):
|
||||||
|
db.query(ModelConfig).filter(
|
||||||
|
ModelConfig.model_id == model_base_id,
|
||||||
|
ModelConfig.is_composite == False
|
||||||
|
).update({
|
||||||
|
k: v for k, v in data.items()
|
||||||
|
if k in ['name', 'description', 'logo']
|
||||||
|
}, synchronize_session=False)
|
||||||
|
|
||||||
return model_base
|
return model_base
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -79,7 +79,8 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
|
|||||||
try:
|
try:
|
||||||
edges: List[dict] = []
|
edges: List[dict] = []
|
||||||
for s in summaries:
|
for s in summaries:
|
||||||
for chunk_id in getattr(s, "chunk_ids", []) or []:
|
chunk_ids = getattr(s, "chunk_ids", []) or []
|
||||||
|
for chunk_id in chunk_ids:
|
||||||
edges.append({
|
edges.append({
|
||||||
"summary_id": s.id,
|
"summary_id": s.id,
|
||||||
"chunk_id": chunk_id,
|
"chunk_id": chunk_id,
|
||||||
@@ -91,12 +92,11 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
|
|||||||
|
|
||||||
if not edges:
|
if not edges:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
result = await connector.execute_query(
|
result = await connector.execute_query(
|
||||||
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
|
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
|
||||||
edges=edges
|
edges=edges
|
||||||
)
|
)
|
||||||
created = [record.get("uuid") for record in result] if result else []
|
created = [record.get("uuid") for record in result] if result else []
|
||||||
return created
|
return created
|
||||||
except Exception:
|
except Exception as e:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -217,8 +217,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
|||||||
summaries=flattened
|
summaries=flattened
|
||||||
)
|
)
|
||||||
created_ids = [record.get("uuid") for record in result]
|
created_ids = [record.get("uuid") for record in result]
|
||||||
|
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
||||||
return created_ids
|
return created_ids
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -101,10 +101,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
|||||||
e.name_embedding = CASE
|
e.name_embedding = CASE
|
||||||
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
|
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
|
||||||
ELSE e.name_embedding END,
|
ELSE e.name_embedding END,
|
||||||
e.fact_summary = CASE
|
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
|
// e.fact_summary = CASE
|
||||||
AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
|
// WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
|
||||||
THEN entity.fact_summary ELSE e.fact_summary END,
|
// AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
|
||||||
|
// THEN entity.fact_summary ELSE e.fact_summary END,
|
||||||
e.connect_strength = CASE
|
e.connect_strength = CASE
|
||||||
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
|
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
|
||||||
ELSE CASE
|
ELSE CASE
|
||||||
@@ -321,7 +322,8 @@ RETURN e.id AS id,
|
|||||||
e.description AS description,
|
e.description AS description,
|
||||||
e.aliases AS aliases,
|
e.aliases AS aliases,
|
||||||
e.name_embedding AS name_embedding,
|
e.name_embedding AS name_embedding,
|
||||||
COALESCE(e.fact_summary, '') AS fact_summary,
|
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
|
// COALESCE(e.fact_summary, '') AS fact_summary,
|
||||||
e.connect_strength AS connect_strength,
|
e.connect_strength AS connect_strength,
|
||||||
collect(DISTINCT s.id) AS statement_ids,
|
collect(DISTINCT s.id) AS statement_ids,
|
||||||
collect(DISTINCT c.id) AS chunk_ids,
|
collect(DISTINCT c.id) AS chunk_ids,
|
||||||
@@ -877,7 +879,8 @@ RETURN
|
|||||||
CASE
|
CASE
|
||||||
WHEN ms:ExtractedEntity THEN {
|
WHEN ms:ExtractedEntity THEN {
|
||||||
text: ms.name,
|
text: ms.name,
|
||||||
created_at: ms.created_at
|
created_at: ms.created_at,
|
||||||
|
type: "情景记忆"
|
||||||
}
|
}
|
||||||
END
|
END
|
||||||
) AS ExtractedEntity,
|
) AS ExtractedEntity,
|
||||||
@@ -887,7 +890,8 @@ RETURN
|
|||||||
CASE
|
CASE
|
||||||
WHEN n:MemorySummary THEN {
|
WHEN n:MemorySummary THEN {
|
||||||
text: n.content,
|
text: n.content,
|
||||||
created_at: n.created_at
|
created_at: n.created_at,
|
||||||
|
type: "长期沉淀"
|
||||||
}
|
}
|
||||||
END
|
END
|
||||||
) AS MemorySummary,
|
) AS MemorySummary,
|
||||||
@@ -895,7 +899,8 @@ RETURN
|
|||||||
collect(
|
collect(
|
||||||
DISTINCT {
|
DISTINCT {
|
||||||
text: e.statement,
|
text: e.statement,
|
||||||
created_at: e.created_at
|
created_at: e.created_at,
|
||||||
|
type: "情绪记忆"
|
||||||
}
|
}
|
||||||
) AS statement;
|
) AS statement;
|
||||||
"""
|
"""
|
||||||
@@ -999,3 +1004,58 @@ RETURN DISTINCT
|
|||||||
x.statement as statement,x.created_at as created_at
|
x.statement as statement,x.created_at as created_at
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
Graph_Node_query = """
|
||||||
|
MATCH (n:MemorySummary)
|
||||||
|
WHERE n.end_user_id = $end_user_id
|
||||||
|
RETURN
|
||||||
|
elementId(n) AS id,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS properties,
|
||||||
|
0 AS priority
|
||||||
|
LIMIT $limit
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
MATCH (n:Dialogue)
|
||||||
|
WHERE n.end_user_id = $end_user_id
|
||||||
|
RETURN
|
||||||
|
elementId(n) AS id,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS properties,
|
||||||
|
1 AS priority
|
||||||
|
LIMIT 1
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
MATCH (n:Statement)
|
||||||
|
WHERE n.end_user_id = $end_user_id
|
||||||
|
RETURN
|
||||||
|
elementId(n) AS id,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS properties,
|
||||||
|
1 AS priority
|
||||||
|
LIMIT $limit
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
MATCH (n:ExtractedEntity)
|
||||||
|
WHERE n.end_user_id = $end_user_id
|
||||||
|
RETURN
|
||||||
|
elementId(n) AS id,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS properties,
|
||||||
|
2 AS priority
|
||||||
|
LIMIT $limit
|
||||||
|
|
||||||
|
UNION ALL
|
||||||
|
|
||||||
|
MATCH (n:Chunk)
|
||||||
|
WHERE n.end_user_id = $end_user_id
|
||||||
|
RETURN
|
||||||
|
elementId(n) AS id,
|
||||||
|
labels(n) AS labels,
|
||||||
|
properties(n) AS properties,
|
||||||
|
3 AS priority
|
||||||
|
LIMIT $limit
|
||||||
|
|
||||||
|
"""
|
||||||
@@ -21,7 +21,8 @@ from app.core.memory.models.graph_models import (
|
|||||||
ExtractedEntityNode,
|
ExtractedEntityNode,
|
||||||
EntityEntityEdge,
|
EntityEntityEdge,
|
||||||
)
|
)
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
async def save_entities_and_relationships(
|
async def save_entities_and_relationships(
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
@@ -41,8 +42,8 @@ async def save_entities_and_relationships(
|
|||||||
'statement': edge.statement,
|
'statement': edge.statement,
|
||||||
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
||||||
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
||||||
'created_at': edge.created_at.isoformat(),
|
'created_at': edge.created_at.isoformat() if edge.created_at else None,
|
||||||
'expired_at': edge.expired_at.isoformat(),
|
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
|
||||||
'run_id': edge.run_id,
|
'run_id': edge.run_id,
|
||||||
'end_user_id': edge.end_user_id,
|
'end_user_id': edge.end_user_id,
|
||||||
}
|
}
|
||||||
@@ -147,14 +148,14 @@ async def save_statement_entity_edges(
|
|||||||
|
|
||||||
|
|
||||||
async def save_dialog_and_statements_to_neo4j(
|
async def save_dialog_and_statements_to_neo4j(
|
||||||
dialogue_nodes: List[DialogueNode],
|
dialogue_nodes: List[DialogueNode],
|
||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
statement_nodes: List[StatementNode],
|
statement_nodes: List[StatementNode],
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
entity_edges: List[EntityEntityEdge],
|
entity_edges: List[EntityEntityEdge],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||||
|
|
||||||
@@ -171,40 +172,127 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if successful, False otherwise
|
bool: True if successful, False otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
# Save all dialogue nodes in batch
|
# 定义事务函数,将所有写操作放在一个事务中
|
||||||
dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector)
|
async def _save_all_in_transaction(tx):
|
||||||
if dialogue_uuids:
|
"""在单个事务中执行所有保存操作,避免死锁"""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
# 1. Save all dialogue nodes in batch
|
||||||
|
if dialogue_nodes:
|
||||||
|
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE
|
||||||
|
dialogue_data = [node.model_dump() for node in dialogue_nodes]
|
||||||
|
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
|
||||||
|
dialogue_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['dialogues'] = dialogue_uuids
|
||||||
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
||||||
else:
|
|
||||||
print("Failed to save dialogues to Neo4j")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Save all chunk nodes in batch
|
# 2. Save all chunk nodes in batch
|
||||||
await save_chunk_nodes(chunk_nodes, connector)
|
if chunk_nodes:
|
||||||
|
from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE
|
||||||
|
chunk_data = [node.model_dump() for node in chunk_nodes]
|
||||||
|
result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data)
|
||||||
|
chunk_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['chunks'] = chunk_uuids
|
||||||
|
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
|
||||||
|
|
||||||
# Save all statement nodes in batch
|
# 3. Save all statement nodes in batch
|
||||||
if statement_nodes:
|
if statement_nodes:
|
||||||
statement_uuids = await add_statement_nodes(statement_nodes, connector)
|
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
|
||||||
if statement_uuids:
|
statement_data = [node.model_dump() for node in statement_nodes]
|
||||||
print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
|
result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data)
|
||||||
else:
|
statement_uuids = [record["uuid"] async for record in result]
|
||||||
print("Failed to save statement nodes to Neo4j")
|
results['statements'] = statement_uuids
|
||||||
return False
|
logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
|
||||||
else:
|
|
||||||
print("No statement nodes to save")
|
|
||||||
|
|
||||||
# Save entities and relationships
|
# 4. Save entities
|
||||||
await save_entities_and_relationships(entity_nodes, entity_edges, connector)
|
if entity_nodes:
|
||||||
print("Successfully saved entities and relationships to Neo4j")
|
from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE
|
||||||
|
entity_data = [entity.model_dump() for entity in entity_nodes]
|
||||||
|
result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data)
|
||||||
|
entity_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['entities'] = entity_uuids
|
||||||
|
logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
|
||||||
|
|
||||||
# Save new edges
|
# 5. Create entity relationships
|
||||||
await save_statement_chunk_edges(statement_chunk_edges, connector)
|
if entity_edges:
|
||||||
await save_statement_entity_edges(statement_entity_edges, connector)
|
from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE
|
||||||
|
relationship_data = []
|
||||||
|
for edge in entity_edges:
|
||||||
|
relationship_data.append({
|
||||||
|
'source_id': edge.source,
|
||||||
|
'target_id': edge.target,
|
||||||
|
'predicate': edge.relation_type,
|
||||||
|
'statement_id': edge.source_statement_id,
|
||||||
|
'value': edge.relation_value,
|
||||||
|
'statement': edge.statement,
|
||||||
|
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
||||||
|
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
||||||
|
'created_at': edge.created_at.isoformat() if edge.created_at else None,
|
||||||
|
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
|
||||||
|
'run_id': edge.run_id,
|
||||||
|
'end_user_id': edge.end_user_id,
|
||||||
|
})
|
||||||
|
result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data)
|
||||||
|
rel_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['entity_relationships'] = rel_uuids
|
||||||
|
logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j")
|
||||||
|
|
||||||
|
# 6. Save statement-chunk edges
|
||||||
|
if statement_chunk_edges:
|
||||||
|
from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE
|
||||||
|
sc_edge_data = []
|
||||||
|
for edge in statement_chunk_edges:
|
||||||
|
sc_edge_data.append({
|
||||||
|
"id": edge.id,
|
||||||
|
"source": edge.source,
|
||||||
|
"target": edge.target,
|
||||||
|
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||||
|
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
||||||
|
"run_id": edge.run_id,
|
||||||
|
"end_user_id": edge.end_user_id,
|
||||||
|
})
|
||||||
|
result = await tx.run(CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data)
|
||||||
|
sc_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['statement_chunk_edges'] = sc_uuids
|
||||||
|
logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j")
|
||||||
|
|
||||||
|
# 7. Save statement-entity edges
|
||||||
|
if statement_entity_edges:
|
||||||
|
from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE
|
||||||
|
se_edge_data = []
|
||||||
|
for edge in statement_entity_edges:
|
||||||
|
se_edge_data.append({
|
||||||
|
"source": edge.source,
|
||||||
|
"target": edge.target,
|
||||||
|
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||||
|
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
||||||
|
"run_id": edge.run_id,
|
||||||
|
"end_user_id": edge.end_user_id,
|
||||||
|
"connect_strength": getattr(edge, "connect_strength", "strong"),
|
||||||
|
})
|
||||||
|
result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data)
|
||||||
|
se_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['statement_entity_edges'] = se_uuids
|
||||||
|
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用显式写事务执行所有操作,避免死锁
|
||||||
|
results = await connector.execute_write_transaction(_save_all_in_transaction)
|
||||||
|
summary = {
|
||||||
|
key: len(value)
|
||||||
|
for key, value in results.items()
|
||||||
|
if isinstance(value, (list, tuple, set))
|
||||||
|
}
|
||||||
|
logger.info("Transaction completed. Summary: %s", summary)
|
||||||
|
logger.debug("Full transaction results: %r", results)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(f"Neo4j integration error: {e}", exc_info=True)
|
||||||
print(f"Neo4j integration error: {e}")
|
print(f"Neo4j integration error: {e}")
|
||||||
print("Continuing without database storage...")
|
print("Continuing without database storage...")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
404
api/app/repositories/ontology_class_repository.py
Normal file
404
api/app/repositories/ontology_class_repository.py
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""本体类型Repository层
|
||||||
|
|
||||||
|
本模块提供本体类型的数据访问层实现。
|
||||||
|
|
||||||
|
Classes:
|
||||||
|
OntologyClassRepository: 本体类型数据访问类
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from app.core.logging_config import get_db_logger
|
||||||
|
from app.models.ontology_class import OntologyClass
|
||||||
|
from app.models.ontology_scene import OntologyScene
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_db_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyClassRepository:
    """Data-access layer for ontology classes.

    Offers CRUD operations, name lookup/search and a workspace ownership
    check for ``OntologyClass`` rows. Write methods only ``flush()`` the
    session and never commit, so transaction control stays with the caller.
    Every method logs unexpected exceptions and re-raises them unchanged.

    Attributes:
        db: SQLAlchemy session used for every query.
    """

    # Escape character passed to ILIKE so that user-supplied "%" and "_" are
    # matched literally instead of acting as wildcards.
    _LIKE_ESCAPE = "/"

    def __init__(self, db: Session):
        """Store the session the repository works with.

        Args:
            db: SQLAlchemy database session.
        """
        self.db = db

    @staticmethod
    def _escape_like(keyword: str, escape_char: str = "/") -> str:
        """Return *keyword* with LIKE metacharacters escaped.

        The escape character itself is doubled first, then ``%`` and ``_``
        are prefixed with it, so the result can be embedded in a LIKE/ILIKE
        pattern and matched as literal text.

        Args:
            keyword: Raw search text (coerced with ``str()``).
            escape_char: Character used as the LIKE escape character.

        Returns:
            The escaped text.

        Examples:
            >>> OntologyClassRepository._escape_like("50%_off")
            '50/%/_off'
        """
        text = str(keyword)
        text = text.replace(escape_char, escape_char * 2)
        text = text.replace("%", escape_char + "%")
        return text.replace("_", escape_char + "_")

    def create(self, class_data: dict, scene_id: UUID) -> OntologyClass:
        """Create an ontology class inside a scene.

        Args:
            class_data: Dict read for the keys ``class_name`` and
                ``class_description`` (missing keys become ``None``).
            scene_id: ID of the scene the class belongs to.

        Returns:
            OntologyClass: The newly added (flushed, uncommitted) object.

        Raises:
            Exception: Any error raised by the database layer is re-raised.

        Examples:
            >>> repo = OntologyClassRepository(db)
            >>> ontology_class = repo.create(
            ...     {"class_name": "Patient", "class_description": "..."},
            ...     scene_id
            ... )
        """
        try:
            logger.info(
                f"Creating ontology class - "
                f"name={class_data.get('class_name')}, "
                f"scene_id={scene_id}"
            )

            ontology_class = OntologyClass(
                class_name=class_data.get("class_name"),
                class_description=class_data.get("class_description"),
                scene_id=scene_id,
            )

            self.db.add(ontology_class)
            # Flush so the generated ID is available, but do not commit.
            self.db.flush()

            logger.info(
                f"Ontology class created successfully - "
                f"class_id={ontology_class.class_id}"
            )
            return ontology_class

        except Exception as e:
            logger.error(
                f"Failed to create ontology class: {str(e)}",
                exc_info=True
            )
            raise

    def get_by_id(self, class_id: UUID) -> Optional[OntologyClass]:
        """Fetch a class by its ID.

        Args:
            class_id: Class ID.

        Returns:
            Optional[OntologyClass]: The class, or ``None`` if no row matches.

        Examples:
            >>> repo = OntologyClassRepository(db)
            >>> ontology_class = repo.get_by_id(class_id)
        """
        try:
            logger.debug(f"Getting ontology class by ID: {class_id}")

            ontology_class = (
                self.db.query(OntologyClass)
                .filter(OntologyClass.class_id == class_id)
                .first()
            )

            if ontology_class:
                logger.debug(f"Ontology class found: {class_id}")
            else:
                logger.debug(f"Ontology class not found: {class_id}")

            return ontology_class

        except Exception as e:
            logger.error(
                f"Failed to get ontology class by ID: {str(e)}",
                exc_info=True
            )
            raise

    def get_by_name(self, class_name: str, scene_id: UUID) -> Optional[OntologyClass]:
        """Fetch a class by exact name within a scene.

        Args:
            class_name: Class name (compared with ``==``).
            scene_id: Scene ID.

        Returns:
            Optional[OntologyClass]: The class, or ``None`` if no row matches.

        Examples:
            >>> repo = OntologyClassRepository(db)
            >>> ontology_class = repo.get_by_name("Patient", scene_id)
        """
        try:
            logger.debug(f"Getting ontology class by name: {class_name}, scene_id: {scene_id}")

            ontology_class = (
                self.db.query(OntologyClass)
                .filter(
                    OntologyClass.class_name == class_name,
                    OntologyClass.scene_id == scene_id,
                )
                .first()
            )

            if ontology_class:
                logger.debug(f"Ontology class found: {class_name}")
            else:
                logger.debug(f"Ontology class not found: {class_name}")

            return ontology_class

        except Exception as e:
            logger.error(
                f"Failed to get ontology class by name: {str(e)}",
                exc_info=True
            )
            raise

    def search_by_name(self, keyword: str, scene_id: UUID) -> List[OntologyClass]:
        """Search classes in a scene whose name contains *keyword*.

        Uses a case-insensitive ``ILIKE '%keyword%'`` match. The keyword is
        escaped first, so ``%`` and ``_`` typed by the user are matched
        literally rather than as wildcards. Results are ordered by
        ``created_at`` descending.

        Args:
            keyword: Search keyword.
            scene_id: Scene ID.

        Returns:
            List[OntologyClass]: Matching classes.

        Examples:
            >>> repo = OntologyClassRepository(db)
            >>> classes = repo.search_by_name("Patient", scene_id)
        """
        try:
            logger.debug(
                f"Searching ontology classes by keyword - "
                f"keyword={keyword}, scene_id={scene_id}"
            )

            # Escape LIKE metacharacters so the keyword is a literal substring.
            pattern = f"%{self._escape_like(keyword, self._LIKE_ESCAPE)}%"

            classes = (
                self.db.query(OntologyClass)
                .filter(
                    OntologyClass.class_name.ilike(pattern, escape=self._LIKE_ESCAPE),
                    OntologyClass.scene_id == scene_id,
                )
                .order_by(OntologyClass.created_at.desc())
                .all()
            )

            logger.info(
                f"Found {len(classes)} ontology classes matching keyword '{keyword}' "
                f"in scene {scene_id}"
            )
            return classes

        except Exception as e:
            logger.error(
                f"Failed to search ontology classes by keyword: {str(e)}",
                exc_info=True
            )
            raise

    def get_by_scene(self, scene_id: UUID) -> List[OntologyClass]:
        """List every class of a scene, newest first.

        Results are ordered by ``created_at`` descending.

        Args:
            scene_id: Scene ID.

        Returns:
            List[OntologyClass]: Classes of the scene.

        Examples:
            >>> repo = OntologyClassRepository(db)
            >>> classes = repo.get_by_scene(scene_id)
        """
        try:
            logger.debug(f"Getting ontology classes by scene: {scene_id}")

            classes = (
                self.db.query(OntologyClass)
                .filter(OntologyClass.scene_id == scene_id)
                .order_by(OntologyClass.created_at.desc())
                .all()
            )

            logger.info(
                f"Found {len(classes)} ontology classes in scene {scene_id}"
            )
            return classes

        except Exception as e:
            logger.error(
                f"Failed to get ontology classes by scene: {str(e)}",
                exc_info=True
            )
            raise

    def update(self, class_id: UUID, update_data: dict) -> Optional[OntologyClass]:
        """Update the name and/or description of a class.

        ``class_name`` is only applied when present and not ``None``.
        ``class_description`` is applied whenever the key is present, so an
        explicit ``None`` clears the description.

        Args:
            class_id: Class ID.
            update_data: Dict of fields to update.

        Returns:
            Optional[OntologyClass]: The updated (flushed, uncommitted)
            object, or ``None`` if the class does not exist.

        Raises:
            Exception: Any error raised by the database layer is re-raised.

        Examples:
            >>> repo = OntologyClassRepository(db)
            >>> ontology_class = repo.update(
            ...     class_id,
            ...     {"class_name": "New name"}
            ... )
        """
        try:
            logger.info(f"Updating ontology class: {class_id}")

            ontology_class = self.get_by_id(class_id)
            if not ontology_class:
                logger.warning(f"Ontology class not found for update: {class_id}")
                return None

            if update_data.get("class_name") is not None:
                ontology_class.class_name = update_data["class_name"]

            if "class_description" in update_data:
                ontology_class.class_description = update_data["class_description"]

            self.db.flush()

            logger.info(f"Ontology class updated successfully: {class_id}")
            return ontology_class

        except Exception as e:
            logger.error(
                f"Failed to update ontology class: {str(e)}",
                exc_info=True
            )
            raise

    def delete(self, class_id: UUID) -> bool:
        """Delete a class.

        Args:
            class_id: Class ID.

        Returns:
            bool: ``True`` if the class was deleted (flushed, uncommitted),
            ``False`` if it does not exist.

        Raises:
            Exception: Any error raised by the database layer is re-raised.

        Examples:
            >>> repo = OntologyClassRepository(db)
            >>> success = repo.delete(class_id)
        """
        try:
            logger.info(f"Deleting ontology class: {class_id}")

            ontology_class = self.get_by_id(class_id)
            if not ontology_class:
                logger.warning(f"Ontology class not found for delete: {class_id}")
                return False

            self.db.delete(ontology_class)
            self.db.flush()

            logger.info(f"Ontology class deleted successfully: {class_id}")
            return True

        except Exception as e:
            logger.error(
                f"Failed to delete ontology class: {str(e)}",
                exc_info=True
            )
            raise

    def check_ownership(self, class_id: UUID, workspace_id: UUID) -> bool:
        """Check whether a class belongs to a workspace via its scene.

        Joins ``OntologyClass`` to ``OntologyScene`` on ``scene_id`` and
        counts rows matching both the class ID and the workspace ID.

        Args:
            class_id: Class ID.
            workspace_id: Workspace ID.

        Returns:
            bool: ``True`` when at least one matching row exists.

        Examples:
            >>> repo = OntologyClassRepository(db)
            >>> is_owner = repo.check_ownership(class_id, workspace_id)
        """
        try:
            logger.debug(
                f"Checking class ownership - "
                f"class_id={class_id}, workspace_id={workspace_id}"
            )

            count = (
                self.db.query(OntologyClass)
                .join(
                    OntologyScene,
                    OntologyClass.scene_id == OntologyScene.scene_id,
                )
                .filter(
                    OntologyClass.class_id == class_id,
                    OntologyScene.workspace_id == workspace_id,
                )
                .count()
            )

            is_owner = count > 0

            logger.debug(
                f"Class ownership check result: {is_owner} - "
                f"class_id={class_id}"
            )
            return is_owner

        except Exception as e:
            logger.error(
                f"Failed to check class ownership: {str(e)}",
                exc_info=True
            )
            raise

    def get_scene_id_by_class(self, class_id: UUID) -> Optional[UUID]:
        """Return the ID of the scene a class belongs to.

        Args:
            class_id: Class ID.

        Returns:
            Optional[UUID]: The scene ID, or ``None`` if the class does not
            exist.

        Examples:
            >>> repo = OntologyClassRepository(db)
            >>> scene_id = repo.get_scene_id_by_class(class_id)
        """
        try:
            logger.debug(f"Getting scene ID by class: {class_id}")

            ontology_class = self.get_by_id(class_id)
            if not ontology_class:
                logger.debug(f"Class not found: {class_id}")
                return None

            logger.debug(
                f"Found scene ID: {ontology_class.scene_id} for class: {class_id}"
            )
            return ontology_class.scene_id

        except Exception as e:
            logger.error(
                f"Failed to get scene ID by class: {str(e)}",
                exc_info=True
            )
            raise
|
||||||
439
api/app/repositories/ontology_scene_repository.py
Normal file
439
api/app/repositories/ontology_scene_repository.py
Normal file
@@ -0,0 +1,439 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""本体场景Repository层
|
||||||
|
|
||||||
|
本模块提供本体场景的数据访问层实现。
|
||||||
|
|
||||||
|
Classes:
|
||||||
|
OntologySceneRepository: 本体场景数据访问类
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from app.core.logging_config import get_db_logger
|
||||||
|
from app.models.ontology_scene import OntologyScene
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_db_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class OntologySceneRepository:
    """Data-access layer for ontology scenes.

    Offers CRUD operations, name lookup/search, paginated listing and a
    workspace ownership check for ``OntologyScene`` rows. Write methods only
    ``flush()`` the session and never commit, so transaction control stays
    with the caller. Every method logs unexpected exceptions and re-raises
    them unchanged.

    Attributes:
        db: SQLAlchemy session used for every query.
    """

    # Escape character passed to ILIKE so that user-supplied "%" and "_" are
    # matched literally instead of acting as wildcards.
    _LIKE_ESCAPE = "/"

    def __init__(self, db: Session):
        """Store the session the repository works with.

        Args:
            db: SQLAlchemy database session.
        """
        self.db = db

    @staticmethod
    def _escape_like(keyword: str, escape_char: str = "/") -> str:
        """Return *keyword* with LIKE metacharacters escaped.

        The escape character itself is doubled first, then ``%`` and ``_``
        are prefixed with it, so the result can be embedded in a LIKE/ILIKE
        pattern and matched as literal text.

        Args:
            keyword: Raw search text (coerced with ``str()``).
            escape_char: Character used as the LIKE escape character.

        Returns:
            The escaped text.

        Examples:
            >>> OntologySceneRepository._escape_like("50%_off")
            '50/%/_off'
        """
        text = str(keyword)
        text = text.replace(escape_char, escape_char * 2)
        text = text.replace("%", escape_char + "%")
        return text.replace("_", escape_char + "_")

    def create(self, scene_data: dict, workspace_id: UUID) -> OntologyScene:
        """Create an ontology scene inside a workspace.

        Args:
            scene_data: Dict read for the keys ``scene_name`` and
                ``scene_description`` (missing keys become ``None``).
            workspace_id: ID of the workspace the scene belongs to.

        Returns:
            OntologyScene: The newly added (flushed, uncommitted) object.

        Raises:
            Exception: Any error raised by the database layer is re-raised.

        Examples:
            >>> repo = OntologySceneRepository(db)
            >>> scene = repo.create(
            ...     {"scene_name": "Medical", "scene_description": "..."},
            ...     workspace_id
            ... )
        """
        try:
            logger.info(
                f"Creating ontology scene - "
                f"name={scene_data.get('scene_name')}, "
                f"workspace_id={workspace_id}"
            )

            scene = OntologyScene(
                scene_name=scene_data.get("scene_name"),
                scene_description=scene_data.get("scene_description"),
                workspace_id=workspace_id,
            )

            self.db.add(scene)
            # Flush so the generated ID is available, but do not commit.
            self.db.flush()

            logger.info(
                f"Ontology scene created successfully - "
                f"scene_id={scene.scene_id}"
            )
            return scene

        except Exception as e:
            logger.error(
                f"Failed to create ontology scene: {str(e)}",
                exc_info=True
            )
            raise

    def get_by_id(self, scene_id: UUID) -> Optional[OntologyScene]:
        """Fetch a scene by its ID.

        Args:
            scene_id: Scene ID.

        Returns:
            Optional[OntologyScene]: The scene, or ``None`` if no row matches.

        Examples:
            >>> repo = OntologySceneRepository(db)
            >>> scene = repo.get_by_id(scene_id)
        """
        try:
            logger.debug(f"Getting ontology scene by ID: {scene_id}")

            scene = (
                self.db.query(OntologyScene)
                .filter(OntologyScene.scene_id == scene_id)
                .first()
            )

            if scene:
                logger.debug(f"Ontology scene found: {scene_id}")
            else:
                logger.debug(f"Ontology scene not found: {scene_id}")

            return scene

        except Exception as e:
            logger.error(
                f"Failed to get ontology scene by ID: {str(e)}",
                exc_info=True
            )
            raise

    def get_by_name(self, scene_name: str, workspace_id: UUID) -> Optional[OntologyScene]:
        """Fetch a scene by exact name within a workspace.

        The scene's ``classes`` relationship is eager-loaded with
        ``joinedload``.

        Args:
            scene_name: Scene name (compared with ``==``).
            workspace_id: Workspace ID.

        Returns:
            Optional[OntologyScene]: The scene, or ``None`` if no row matches.

        Examples:
            >>> repo = OntologySceneRepository(db)
            >>> scene = repo.get_by_name("Medical", workspace_id)
        """
        try:
            logger.debug(
                f"Getting ontology scene by name - "
                f"scene_name={scene_name}, workspace_id={workspace_id}"
            )

            scene = (
                self.db.query(OntologyScene)
                .options(joinedload(OntologyScene.classes))
                .filter(
                    OntologyScene.scene_name == scene_name,
                    OntologyScene.workspace_id == workspace_id,
                )
                .first()
            )

            if scene:
                logger.debug(f"Ontology scene found: {scene_name}")
            else:
                logger.debug(f"Ontology scene not found: {scene_name}")

            return scene

        except Exception as e:
            logger.error(
                f"Failed to get ontology scene by name: {str(e)}",
                exc_info=True
            )
            raise

    def search_by_name(self, keyword: str, workspace_id: UUID) -> List[OntologyScene]:
        """Search scenes in a workspace whose name contains *keyword*.

        Uses a case-insensitive ``ILIKE '%keyword%'`` match. The keyword is
        escaped first, so ``%`` and ``_`` typed by the user are matched
        literally rather than as wildcards. The ``classes`` relationship is
        eager-loaded and results are ordered by ``updated_at`` descending.

        Args:
            keyword: Search keyword.
            workspace_id: Workspace ID.

        Returns:
            List[OntologyScene]: Matching scenes.

        Examples:
            >>> repo = OntologySceneRepository(db)
            >>> scenes = repo.search_by_name("Medical", workspace_id)
        """
        try:
            logger.debug(
                f"Searching ontology scenes by keyword - "
                f"keyword={keyword}, workspace_id={workspace_id}"
            )

            # Escape LIKE metacharacters so the keyword is a literal substring.
            pattern = f"%{self._escape_like(keyword, self._LIKE_ESCAPE)}%"

            scenes = (
                self.db.query(OntologyScene)
                .options(joinedload(OntologyScene.classes))
                .filter(
                    OntologyScene.scene_name.ilike(pattern, escape=self._LIKE_ESCAPE),
                    OntologyScene.workspace_id == workspace_id,
                )
                .order_by(OntologyScene.updated_at.desc())
                .all()
            )

            logger.info(
                f"Found {len(scenes)} ontology scenes matching keyword '{keyword}' "
                f"in workspace {workspace_id}"
            )
            return scenes

        except Exception as e:
            logger.error(
                f"Failed to search ontology scenes by keyword: {str(e)}",
                exc_info=True
            )
            raise

    def get_by_workspace(self, workspace_id: UUID, page: Optional[int] = None, page_size: Optional[int] = None) -> tuple:
        """List the scenes of a workspace, optionally paginated.

        The ``classes`` relationship is eager-loaded with ``joinedload`` and
        results are ordered by ``updated_at`` descending. Pagination is
        applied only when BOTH ``page`` and ``page_size`` are given; the
        values are passed to the database without validation.

        Args:
            workspace_id: Workspace ID.
            page: Optional page number, starting at 1.
            page_size: Optional number of scenes per page.

        Returns:
            tuple: ``(scenes, total)`` where ``total`` is the number of scenes
            in the workspace regardless of pagination.

        Examples:
            >>> repo = OntologySceneRepository(db)
            >>> scenes, total = repo.get_by_workspace(workspace_id)
            >>> scenes, total = repo.get_by_workspace(workspace_id, page=1, page_size=10)
        """
        try:
            logger.debug(f"Getting ontology scenes by workspace: {workspace_id}, page={page}, page_size={page_size}")

            # Shared filter; counted on its own so the COUNT statement carries
            # neither the ORDER BY nor the eager-load option.
            base_query = self.db.query(OntologyScene).filter(
                OntologyScene.workspace_id == workspace_id
            )
            total = base_query.count()

            query = base_query.options(
                joinedload(OntologyScene.classes)
            ).order_by(
                OntologyScene.updated_at.desc()
            )

            if page is not None and page_size is not None:
                offset = (page - 1) * page_size
                query = query.offset(offset).limit(page_size)
                logger.debug(f"Applying pagination: offset={offset}, limit={page_size}")

            scenes = query.all()

            logger.info(
                f"Found {len(scenes)} ontology scenes (total: {total}) in workspace {workspace_id}"
            )
            return scenes, total

        except Exception as e:
            logger.error(
                f"Failed to get ontology scenes by workspace: {str(e)}",
                exc_info=True
            )
            raise

    def update(self, scene_id: UUID, update_data: dict) -> Optional[OntologyScene]:
        """Update the name and/or description of a scene.

        ``scene_name`` is only applied when present and not ``None``.
        ``scene_description`` is applied whenever the key is present, so an
        explicit ``None`` clears the description.

        Args:
            scene_id: Scene ID.
            update_data: Dict of fields to update.

        Returns:
            Optional[OntologyScene]: The updated (flushed, uncommitted)
            object, or ``None`` if the scene does not exist.

        Raises:
            Exception: Any error raised by the database layer is re-raised.

        Examples:
            >>> repo = OntologySceneRepository(db)
            >>> scene = repo.update(
            ...     scene_id,
            ...     {"scene_name": "New name"}
            ... )
        """
        try:
            logger.info(f"Updating ontology scene: {scene_id}")

            scene = self.get_by_id(scene_id)
            if not scene:
                logger.warning(f"Ontology scene not found for update: {scene_id}")
                return None

            if update_data.get("scene_name") is not None:
                scene.scene_name = update_data["scene_name"]

            if "scene_description" in update_data:
                scene.scene_description = update_data["scene_description"]

            self.db.flush()

            logger.info(f"Ontology scene updated successfully: {scene_id}")
            return scene

        except Exception as e:
            logger.error(
                f"Failed to update ontology scene: {str(e)}",
                exc_info=True
            )
            raise

    def delete(self, scene_id: UUID) -> bool:
        """Delete a scene.

        NOTE(review): the original author states that the scene's classes are
        removed through cascade configuration (``ondelete="CASCADE"``); that
        configuration lives in the models/schema and is not visible here.

        Args:
            scene_id: Scene ID.

        Returns:
            bool: ``True`` if the scene was deleted (flushed, uncommitted),
            ``False`` if it does not exist.

        Raises:
            Exception: Any error raised by the database layer is re-raised.

        Examples:
            >>> repo = OntologySceneRepository(db)
            >>> success = repo.delete(scene_id)
        """
        try:
            logger.info(f"Deleting ontology scene: {scene_id}")

            scene = self.get_by_id(scene_id)
            if not scene:
                logger.warning(f"Ontology scene not found for delete: {scene_id}")
                return False

            self.db.delete(scene)
            self.db.flush()

            logger.info(
                f"Ontology scene deleted successfully (cascade): {scene_id}"
            )
            return True

        except Exception as e:
            logger.error(
                f"Failed to delete ontology scene: {str(e)}",
                exc_info=True
            )
            raise

    def check_ownership(self, scene_id: UUID, workspace_id: UUID) -> bool:
        """Check whether a scene belongs to a workspace.

        Args:
            scene_id: Scene ID.
            workspace_id: Workspace ID.

        Returns:
            bool: ``True`` when a scene with this ID exists in the workspace.

        Examples:
            >>> repo = OntologySceneRepository(db)
            >>> is_owner = repo.check_ownership(scene_id, workspace_id)
        """
        try:
            logger.debug(
                f"Checking scene ownership - "
                f"scene_id={scene_id}, workspace_id={workspace_id}"
            )

            count = (
                self.db.query(OntologyScene)
                .filter(
                    OntologyScene.scene_id == scene_id,
                    OntologyScene.workspace_id == workspace_id,
                )
                .count()
            )

            is_owner = count > 0

            logger.debug(
                f"Scene ownership check result: {is_owner} - "
                f"scene_id={scene_id}"
            )
            return is_owner

        except Exception as e:
            logger.error(
                f"Failed to check scene ownership: {str(e)}",
                exc_info=True
            )
            raise

    def get_simple_list(self, workspace_id: UUID) -> List[dict]:
        """Return a lightweight ID/name list of a workspace's scenes.

        Only the ``scene_id`` and ``scene_name`` columns are selected (no
        related classes are loaded); rows are ordered by ``updated_at``
        descending. Intended for drop-down style selectors.

        Args:
            workspace_id: Workspace ID.

        Returns:
            List[dict]: One dict per scene with ``scene_id`` (converted to
            ``str``) and ``scene_name``.

        Examples:
            >>> repo = OntologySceneRepository(db)
            >>> scenes = repo.get_simple_list(workspace_id)
            >>> # [{"scene_id": "xxx", "scene_name": "Scene 1"}, ...]
        """
        try:
            logger.debug(f"Getting simple scene list for workspace: {workspace_id}")

            # Select just the two needed columns instead of full entities.
            rows = (
                self.db.query(OntologyScene.scene_id, OntologyScene.scene_name)
                .filter(OntologyScene.workspace_id == workspace_id)
                .order_by(OntologyScene.updated_at.desc())
                .all()
            )

            scenes = [
                {"scene_id": str(row.scene_id), "scene_name": row.scene_name}
                for row in rows
            ]

            logger.info(f"Found {len(scenes)} scenes (simple list) in workspace {workspace_id}")
            return scenes

        except Exception as e:
            logger.error(
                f"Failed to get simple scene list: {str(e)}",
                exc_info=True
            )
            raise
|
||||||
@@ -4,7 +4,10 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.core.logging_config import get_db_logger
|
from app.core.logging_config import get_db_logger
|
||||||
from app.models.prompt_optimizer_model import (
|
from app.models.prompt_optimizer_model import (
|
||||||
PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType
|
PromptOptimizerSession,
|
||||||
|
PromptOptimizerSessionHistory,
|
||||||
|
RoleType,
|
||||||
|
PromptHistory
|
||||||
)
|
)
|
||||||
|
|
||||||
db_logger = get_db_logger()
|
db_logger = get_db_logger()
|
||||||
@@ -16,6 +19,12 @@ class PromptOptimizerSessionRepository:
|
|||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
|
def get_session_by_id(self, session_id: uuid.UUID) -> PromptOptimizerSession | None:
    """Fetch a prompt-optimizer session by its ID.

    Args:
        session_id: ID of the session to look up.

    Returns:
        The first ``PromptOptimizerSession`` whose ``id`` equals
        ``session_id``, or ``None`` when no row matches.
    """
    lookup = self.db.query(PromptOptimizerSession)
    matching = lookup.filter(PromptOptimizerSession.id == session_id)
    return matching.first()
|
||||||
|
|
||||||
def create_session(
|
def create_session(
|
||||||
self,
|
self,
|
||||||
tenant_id: uuid.UUID,
|
tenant_id: uuid.UUID,
|
||||||
@@ -38,12 +47,9 @@ class PromptOptimizerSessionRepository:
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
self.db.add(session)
|
self.db.add(session)
|
||||||
self.db.commit()
|
|
||||||
self.db.refresh(session)
|
|
||||||
db_logger.debug(f"Prompt optimization session created: ID:{session.id}")
|
|
||||||
return session
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"Error creating prompt optimization session: user_id={user_id} - {str(e)}")
|
db_logger.error(f"Error creating prompt optimization session: - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_session_history(
|
def get_session_history(
|
||||||
@@ -71,10 +77,10 @@ class PromptOptimizerSessionRepository:
|
|||||||
PromptOptimizerSession.id == session_id,
|
PromptOptimizerSession.id == session_id,
|
||||||
PromptOptimizerSession.user_id == user_id
|
PromptOptimizerSession.user_id == user_id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
history = self.db.query(PromptOptimizerSessionHistory).filter(
|
history = self.db.query(PromptOptimizerSessionHistory).filter(
|
||||||
PromptOptimizerSessionHistory.session_id == session.id,
|
PromptOptimizerSessionHistory.session_id == session.id,
|
||||||
PromptOptimizerSessionHistory.user_id == user_id
|
PromptOptimizerSessionHistory.user_id == user_id
|
||||||
@@ -104,11 +110,11 @@ class PromptOptimizerSessionRepository:
|
|||||||
PromptOptimizerSession.user_id == user_id,
|
PromptOptimizerSession.user_id == user_id,
|
||||||
PromptOptimizerSession.tenant_id == tenant_id
|
PromptOptimizerSession.tenant_id == tenant_id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
db_logger.error(f"Session {session_id} not found for user {user_id}")
|
db_logger.error(f"Session {session_id} not found for user {user_id}")
|
||||||
raise ValueError(f"Session {session_id} not found for user {user_id}")
|
raise ValueError(f"Session {session_id} not found for user {user_id}")
|
||||||
|
|
||||||
message = PromptOptimizerSessionHistory(
|
message = PromptOptimizerSessionHistory(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
session_id=session.id,
|
session_id=session.id,
|
||||||
@@ -117,8 +123,199 @@ class PromptOptimizerSessionRepository:
|
|||||||
content=content,
|
content=content,
|
||||||
)
|
)
|
||||||
self.db.add(message)
|
self.db.add(message)
|
||||||
self.db.commit()
|
|
||||||
return message
|
return message
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}")
|
db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def get_first_user_message(self, session_id: uuid.UUID) -> str | None:
|
||||||
|
"""
|
||||||
|
Get the first user message from a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id (uuid.UUID): The session ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str | None: The content of the first user message, or None if not found.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
message = self.db.query(PromptOptimizerSessionHistory).filter(
|
||||||
|
PromptOptimizerSessionHistory.session_id == session_id,
|
||||||
|
PromptOptimizerSessionHistory.role == RoleType.USER.value
|
||||||
|
).order_by(
|
||||||
|
PromptOptimizerSessionHistory.created_at.asc()
|
||||||
|
).first()
|
||||||
|
|
||||||
|
return message.content if message else None
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Error getting first user message: session_id={session_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class PromptReleaseRepository:
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def get_prompt_by_session_id(self, session_id: uuid.UUID) -> PromptHistory | None:
|
||||||
|
prompt_obj = self.db.query(PromptHistory).filter(
|
||||||
|
PromptHistory.session_id == session_id,
|
||||||
|
PromptHistory.is_delete.is_(False)
|
||||||
|
).first()
|
||||||
|
return prompt_obj
|
||||||
|
|
||||||
|
def create_prompt_release(
|
||||||
|
self,
|
||||||
|
tenant_id: uuid.UUID,
|
||||||
|
title: str,
|
||||||
|
session_id: uuid.UUID,
|
||||||
|
prompt: str,
|
||||||
|
) -> PromptHistory:
|
||||||
|
try:
|
||||||
|
prompt_obj = PromptHistory(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
title=title,
|
||||||
|
session_id=session_id,
|
||||||
|
prompt=prompt,
|
||||||
|
)
|
||||||
|
self.db.add(prompt_obj)
|
||||||
|
return prompt_obj
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Error creating prompt release: session_id={session_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def soft_delete_prompt(self, prompt_obj: PromptHistory) -> None:
|
||||||
|
"""
|
||||||
|
Soft delete a prompt release by setting is_delete flag to True.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_obj (PromptHistory): The prompt release object to delete.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
prompt_obj.is_delete = True
|
||||||
|
db_logger.debug(f"Soft deleted prompt release: id={prompt_obj.id}, session_id={prompt_obj.session_id}")
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Error soft deleting prompt release: id={prompt_obj.id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_prompt_by_id(self, prompt_id: uuid.UUID) -> PromptHistory | None:
|
||||||
|
"""
|
||||||
|
Get a prompt release by its ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_id (uuid.UUID): The prompt release ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PromptHistory | None: The prompt release object or None if not found.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
prompt_obj = self.db.query(PromptHistory).filter(
|
||||||
|
PromptHistory.id == prompt_id
|
||||||
|
).first()
|
||||||
|
return prompt_obj
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Error getting prompt release by id: id={prompt_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def count_prompts(self, tenant_id: uuid.UUID) -> int:
|
||||||
|
"""
|
||||||
|
Count total number of non-deleted prompts for a tenant.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id (uuid.UUID): The tenant ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Total count of prompts.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
count = self.db.query(PromptHistory).filter(
|
||||||
|
PromptHistory.tenant_id == tenant_id,
|
||||||
|
PromptHistory.is_delete.is_(False)
|
||||||
|
).count()
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Error counting prompts: tenant_id={tenant_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_prompts_paginated(
|
||||||
|
self,
|
||||||
|
tenant_id: uuid.UUID,
|
||||||
|
offset: int,
|
||||||
|
limit: int
|
||||||
|
) -> list[PromptHistory]:
|
||||||
|
"""
|
||||||
|
Get paginated list of prompt releases for a tenant.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id (uuid.UUID): The tenant ID.
|
||||||
|
offset (int): Number of records to skip.
|
||||||
|
limit (int): Maximum number of records to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[PromptHistory]: List of prompt releases.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
prompts = self.db.query(PromptHistory).filter(
|
||||||
|
PromptHistory.tenant_id == tenant_id,
|
||||||
|
PromptHistory.is_delete.is_(False)
|
||||||
|
).order_by(
|
||||||
|
PromptHistory.created_at.desc()
|
||||||
|
).offset(offset).limit(limit).all()
|
||||||
|
return prompts
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Error getting paginated prompts: tenant_id={tenant_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def count_prompts_by_keyword(self, tenant_id: uuid.UUID, keyword: str) -> int:
|
||||||
|
"""
|
||||||
|
Count total number of non-deleted prompts matching keyword for a tenant.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id (uuid.UUID): The tenant ID.
|
||||||
|
keyword (str): Search keyword for title.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Total count of matching prompts.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
count = self.db.query(PromptHistory).filter(
|
||||||
|
PromptHistory.tenant_id == tenant_id,
|
||||||
|
PromptHistory.is_delete.is_(False),
|
||||||
|
PromptHistory.title.ilike(f"%{keyword}%")
|
||||||
|
).count()
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Error counting prompts by keyword: tenant_id={tenant_id}, keyword={keyword} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def search_prompts_paginated(
|
||||||
|
self,
|
||||||
|
tenant_id: uuid.UUID,
|
||||||
|
keyword: str,
|
||||||
|
offset: int,
|
||||||
|
limit: int
|
||||||
|
) -> list[PromptHistory]:
|
||||||
|
"""
|
||||||
|
Search prompt releases by keyword in title with pagination.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id (uuid.UUID): The tenant ID.
|
||||||
|
keyword (str): Search keyword for title.
|
||||||
|
offset (int): Number of records to skip.
|
||||||
|
limit (int): Maximum number of records to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[PromptHistory]: List of matching prompt releases.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
prompts = self.db.query(PromptHistory).filter(
|
||||||
|
PromptHistory.tenant_id == tenant_id,
|
||||||
|
PromptHistory.is_delete.is_(False),
|
||||||
|
PromptHistory.title.ilike(f"%{keyword}%")
|
||||||
|
).order_by(
|
||||||
|
PromptHistory.created_at.desc()
|
||||||
|
).offset(offset).limit(limit).all()
|
||||||
|
return prompts
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Error searching prompts: tenant_id={tenant_id}, keyword={keyword} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ class KnowledgeBaseConfig(BaseModel):
|
|||||||
kb_id: str = Field(..., description="知识库ID")
|
kb_id: str = Field(..., description="知识库ID")
|
||||||
top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量")
|
top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量")
|
||||||
similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值")
|
similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值")
|
||||||
strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
|
# strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
|
||||||
weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
|
# weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
|
||||||
vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重")
|
vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重")
|
||||||
retrieve_type: str = Field(default="hybrid", description="检索方式participle| semantic|hybrid")
|
retrieve_type: str = Field(default="hybrid", description="检索方式participle| semantic|hybrid")
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from abc import ABC
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -14,4 +15,15 @@ class UserInput(BaseModel):
|
|||||||
class Write_UserInput(BaseModel):
|
class Write_UserInput(BaseModel):
|
||||||
messages: list[dict]
|
messages: list[dict]
|
||||||
end_user_id: str
|
end_user_id: str
|
||||||
config_id: Optional[str] = None
|
config_id: Optional[str] = None
|
||||||
|
|
||||||
|
class AgentMemory_Long_Term(ABC):
|
||||||
|
"""长期记忆配置常量"""
|
||||||
|
STORAGE_NEO4J = "neo4j"
|
||||||
|
STORAGE_RAG = "rag"
|
||||||
|
STRATEGY_AGGREGATE = "aggregate"
|
||||||
|
STRATEGY_CHUNK = "chunk"
|
||||||
|
STRATEGY_TIME = "time"
|
||||||
|
DEFAULT_SCOPE = 6
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
|
import uuid
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
@@ -10,7 +12,7 @@ class OptimizationStrategy(str, Enum):
|
|||||||
ACCURACY_FIRST = "accuracy_first"
|
ACCURACY_FIRST = "accuracy_first"
|
||||||
BALANCED = "balanced"
|
BALANCED = "balanced"
|
||||||
class Memory_Reflection(BaseModel):
|
class Memory_Reflection(BaseModel):
|
||||||
config_id: Optional[UUID] = None
|
config_id: Union[uuid.UUID, int, str] = None
|
||||||
reflection_enabled: bool
|
reflection_enabled: bool
|
||||||
reflection_period_in_hours: str
|
reflection_period_in_hours: str
|
||||||
reflexion_range: Optional[str] = "partial"
|
reflexion_range: Optional[str] = "partial"
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ class ReflexionResultSchema(BaseModel):
|
|||||||
# Composite key identifying a config row
|
# Composite key identifying a config row
|
||||||
class ConfigKey(BaseModel): # 配置参数键模型
|
class ConfigKey(BaseModel): # 配置参数键模型
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
config_id: uuid.UUID = Field("config_id", description="配置唯一标识(UUID)")
|
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||||
user_id: str = Field("user_id", description="用户标识(字符串)")
|
user_id: str = Field("user_id", description="用户标识(字符串)")
|
||||||
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
|
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
|
||||||
|
|
||||||
@@ -229,26 +229,32 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
|
|||||||
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
||||||
workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)")
|
workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)")
|
||||||
|
|
||||||
|
# 本体场景关联(可选)
|
||||||
|
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表")
|
||||||
|
|
||||||
# 模型配置字段(可选,用于手动指定或自动填充)
|
# 模型配置字段(可选,用于手动指定或自动填充)
|
||||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||||
|
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
||||||
|
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
||||||
|
|
||||||
|
|
||||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||||
config_id: uuid.UUID = Field("配置ID", description="配置ID(UUID)")
|
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
|
||||||
|
|
||||||
|
|
||||||
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||||
config_id: Optional[uuid.UUID] = None
|
config_id: Union[uuid.UUID, int, str] = None
|
||||||
config_name: str = Field("配置名称", description="配置名称(字符串)")
|
config_name: Optional[str] = Field(None, description="配置名称(字符串)")
|
||||||
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
config_desc: Optional[str] = Field(None, description="配置描述(字符串)")
|
||||||
|
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID")
|
||||||
|
|
||||||
|
|
||||||
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||||
config_id: Optional[uuid.UUID] = None
|
config_id:Union[uuid.UUID, int, str] = None
|
||||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||||
@@ -315,14 +321,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
|
|||||||
|
|
||||||
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
|
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
|
||||||
# 遗忘引擎配置参数更新模型
|
# 遗忘引擎配置参数更新模型
|
||||||
config_id: Optional[uuid.UUID] = None
|
config_id:Union[uuid.UUID, int, str] = None
|
||||||
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5")
|
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5")
|
||||||
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5")
|
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5")
|
||||||
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0")
|
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0")
|
||||||
|
|
||||||
|
|
||||||
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
||||||
config_id: uuid.UUID = Field(..., description="配置ID(唯一)")
|
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)")
|
||||||
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
|
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
|
||||||
@@ -330,7 +336,7 @@ class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
|||||||
class ConfigFilter(BaseModel): # 查询配置参数时使用的模型
|
class ConfigFilter(BaseModel): # 查询配置参数时使用的模型
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
|
||||||
config_id: Optional[uuid.UUID] = None
|
config_id: Union[uuid.UUID, int, str] = None
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
apply_id: Optional[str] = None
|
apply_id: Optional[str] = None
|
||||||
|
|
||||||
@@ -406,7 +412,7 @@ class ForgettingConfigResponse(BaseModel):
|
|||||||
"""遗忘引擎配置响应模型"""
|
"""遗忘引擎配置响应模型"""
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
|
||||||
config_id: uuid.UUID = Field(..., description="配置ID")
|
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
|
||||||
decay_constant: float = Field(..., description="衰减常数 d")
|
decay_constant: float = Field(..., description="衰减常数 d")
|
||||||
lambda_time: float = Field(..., description="时间衰减参数")
|
lambda_time: float = Field(..., description="时间衰减参数")
|
||||||
lambda_mem: float = Field(..., description="记忆衰减参数")
|
lambda_mem: float = Field(..., description="记忆衰减参数")
|
||||||
@@ -423,8 +429,8 @@ class ForgettingConfigResponse(BaseModel):
|
|||||||
class ForgettingConfigUpdateRequest(BaseModel):
|
class ForgettingConfigUpdateRequest(BaseModel):
|
||||||
"""遗忘引擎配置更新请求模型"""
|
"""遗忘引擎配置更新请求模型"""
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
|
||||||
config_id: uuid.UUID = Field(..., description="配置ID")
|
config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||||
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
|
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
|
||||||
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
|
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
|
||||||
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
|
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
|
||||||
@@ -499,7 +505,7 @@ class ForgettingCurveRequest(BaseModel):
|
|||||||
|
|
||||||
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)")
|
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)")
|
||||||
days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)")
|
days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)")
|
||||||
config_id: Optional[uuid.UUID] = Field(None, description="配置ID(可选,如果为None则使用默认配置)")
|
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
||||||
|
|
||||||
|
|
||||||
class ForgettingCurveResponse(BaseModel):
|
class ForgettingCurveResponse(BaseModel):
|
||||||
|
|||||||
@@ -3,14 +3,12 @@ from typing import Optional, List, Dict, Any
|
|||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from app.models.models_model import ModelProvider, ModelType
|
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
|
||||||
schema_logger = get_business_logger()
|
schema_logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ModelConfig Schemas
|
# ModelConfig Schemas
|
||||||
class ModelConfigBase(BaseModel):
|
class ModelConfigBase(BaseModel):
|
||||||
"""模型配置基础Schema"""
|
"""模型配置基础Schema"""
|
||||||
@@ -22,6 +20,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
|
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
|
||||||
is_active: bool = Field(True, description="是否激活")
|
is_active: bool = Field(True, description="是否激活")
|
||||||
is_public: bool = Field(False, description="是否公开")
|
is_public: bool = Field(False, description="是否公开")
|
||||||
|
load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略")
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyCreateNested(BaseModel):
|
class ApiKeyCreateNested(BaseModel):
|
||||||
@@ -44,13 +43,14 @@ class ModelConfigCreate(ModelConfigBase):
|
|||||||
class CompositeModelCreate(BaseModel):
|
class CompositeModelCreate(BaseModel):
|
||||||
"""创建组合模型Schema"""
|
"""创建组合模型Schema"""
|
||||||
name: str = Field(..., description="组合模型名称", max_length=255)
|
name: str = Field(..., description="组合模型名称", max_length=255)
|
||||||
type: ModelType = Field(..., description="模型类型")
|
type: Optional[ModelType] = Field(None, description="模型类型")
|
||||||
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
|
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
|
||||||
description: Optional[str] = Field(None, description="模型描述")
|
description: Optional[str] = Field(None, description="模型描述")
|
||||||
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
|
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
|
||||||
is_active: bool = Field(True, description="是否激活")
|
is_active: bool = Field(True, description="是否激活")
|
||||||
is_public: bool = Field(False, description="是否公开")
|
is_public: bool = Field(False, description="是否公开")
|
||||||
api_key_ids: List[uuid.UUID] = Field(..., description="绑定的API Key ID列表")
|
api_key_ids: List[uuid.UUID] = Field(..., description="绑定的API Key ID列表")
|
||||||
|
load_balance_strategy: Optional[str] = Field(default=LoadBalanceStrategy.NONE.value, description="负载均衡策略")
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigUpdate(BaseModel):
|
class ModelConfigUpdate(BaseModel):
|
||||||
|
|||||||
461
api/app/schemas/ontology_schemas.py
Normal file
461
api/app/schemas/ontology_schemas.py
Normal file
@@ -0,0 +1,461 @@
|
|||||||
|
"""本体提取API的请求和响应模型
|
||||||
|
|
||||||
|
本模块定义了本体提取系统的所有API请求和响应的Pydantic模型。
|
||||||
|
|
||||||
|
Classes:
|
||||||
|
ExtractionRequest: 本体提取请求模型
|
||||||
|
ExtractionResponse: 本体提取响应模型
|
||||||
|
ExportRequest: OWL文件导出请求模型
|
||||||
|
ExportResponse: OWL文件导出响应模型
|
||||||
|
OntologyResultResponse: 本体提取结果响应模型(带毫秒时间戳)
|
||||||
|
SceneCreateRequest: 场景创建请求模型
|
||||||
|
SceneUpdateRequest: 场景更新请求模型
|
||||||
|
SceneResponse: 场景响应模型
|
||||||
|
SceneListResponse: 场景列表响应模型
|
||||||
|
ClassCreateRequest: 类型创建请求模型
|
||||||
|
ClassUpdateRequest: 类型更新请求模型
|
||||||
|
ClassResponse: 类型响应模型
|
||||||
|
ClassListResponse: 类型列表响应模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||||
|
|
||||||
|
from app.core.memory.models.ontology_models import OntologyClass
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionRequest(BaseModel):
|
||||||
|
"""本体提取请求模型
|
||||||
|
|
||||||
|
用于POST /api/ontology/extract端点的请求体。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
scenario: 场景描述文本,不能为空
|
||||||
|
domain: 可选的领域提示(如Healthcare, Education等)
|
||||||
|
llm_id: LLM模型ID,必须提供
|
||||||
|
scene_id: 场景ID,必须提供,用于将提取的类保存到指定场景
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> request = ExtractionRequest(
|
||||||
|
... scenario="医院管理患者记录...",
|
||||||
|
... domain="Healthcare",
|
||||||
|
... llm_id="550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
... scene_id="660e8400-e29b-41d4-a716-446655440000"
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
scenario: str = Field(..., description="场景描述文本", min_length=1)
|
||||||
|
domain: Optional[str] = Field(None, description="可选的领域提示")
|
||||||
|
llm_id: str = Field(..., description="LLM模型ID")
|
||||||
|
scene_id: UUID = Field(..., description="场景ID,用于将提取的类保存到指定场景")
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionResponse(BaseModel):
|
||||||
|
"""本体提取响应模型
|
||||||
|
|
||||||
|
用于POST /api/ontology/extract端点的响应体。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
classes: 提取的本体类列表
|
||||||
|
domain: 识别的领域
|
||||||
|
extracted_count: 提取的类数量
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> response = ExtractionResponse(
|
||||||
|
... classes=[...],
|
||||||
|
... domain="Healthcare",
|
||||||
|
... extracted_count=7
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
classes: List[OntologyClass] = Field(default_factory=list, description="提取的本体类列表")
|
||||||
|
domain: str = Field(..., description="识别的领域")
|
||||||
|
extracted_count: int = Field(..., description="提取的类数量")
|
||||||
|
|
||||||
|
|
||||||
|
class ExportRequest(BaseModel):
|
||||||
|
"""OWL文件导出请求模型
|
||||||
|
|
||||||
|
用于POST /api/ontology/export端点的请求体。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
classes: 要导出的本体类列表
|
||||||
|
format: 导出格式,可选值: rdfxml, turtle, ntriples, json
|
||||||
|
include_metadata: 是否包含完整的OWL元数据(命名空间等),默认True
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> request = ExportRequest(
|
||||||
|
... classes=[...],
|
||||||
|
... format="rdfxml",
|
||||||
|
... include_metadata=True
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
classes: List[OntologyClass] = Field(..., description="要导出的本体类列表", min_length=1)
|
||||||
|
format: str = Field("rdfxml", description="导出格式: rdfxml, turtle, ntriples, json")
|
||||||
|
include_metadata: bool = Field(True, description="是否包含完整的OWL元数据")
|
||||||
|
|
||||||
|
|
||||||
|
class ExportResponse(BaseModel):
|
||||||
|
"""OWL文件导出响应模型
|
||||||
|
|
||||||
|
用于POST /api/ontology/export端点的响应体。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
owl_content: OWL文件内容
|
||||||
|
format: 导出格式
|
||||||
|
classes_count: 导出的类数量
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> response = ExportResponse(
|
||||||
|
... owl_content="<?xml version='1.0'?>...",
|
||||||
|
... format="rdfxml",
|
||||||
|
... classes_count=7
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
owl_content: str = Field(..., description="OWL文件内容")
|
||||||
|
format: str = Field(..., description="导出格式")
|
||||||
|
classes_count: int = Field(..., description="导出的类数量")
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyResultResponse(BaseModel):
|
||||||
|
"""本体提取结果响应模型
|
||||||
|
|
||||||
|
用于返回数据库中存储的提取结果,时间戳为毫秒级。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
id: 结果ID (UUID)
|
||||||
|
scenario: 场景描述文本
|
||||||
|
domain: 领域
|
||||||
|
classes_json: 提取的本体类数据(JSON格式)
|
||||||
|
extracted_count: 提取的类数量
|
||||||
|
user_id: 用户ID
|
||||||
|
created_at: 创建时间(毫秒时间戳)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> response = OntologyResultResponse(
|
||||||
|
... id=uuid.uuid4(),
|
||||||
|
... scenario="医院管理患者记录...",
|
||||||
|
... domain="Healthcare",
|
||||||
|
... classes_json={"classes": [...]},
|
||||||
|
... extracted_count=7,
|
||||||
|
... user_id=123,
|
||||||
|
... created_at=datetime.now()
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
id: UUID = Field(..., description="结果ID")
|
||||||
|
scenario: str = Field(..., description="场景描述文本")
|
||||||
|
domain: Optional[str] = Field(None, description="领域")
|
||||||
|
classes_json: dict = Field(..., description="提取的本体类数据(JSON格式)")
|
||||||
|
extracted_count: int = Field(..., description="提取的类数量")
|
||||||
|
user_id: Optional[int] = Field(None, description="用户ID")
|
||||||
|
created_at: datetime.datetime = Field(..., description="创建时间")
|
||||||
|
|
||||||
|
@field_serializer("created_at", when_used="json")
|
||||||
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
|
"""将创建时间序列化为毫秒时间戳"""
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 本体场景相关 Schema ====================
|
||||||
|
|
||||||
|
class SceneCreateRequest(BaseModel):
|
||||||
|
"""场景创建请求模型
|
||||||
|
|
||||||
|
用于创建新的本体场景。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
scene_name: 场景名称,必填,1-200字符
|
||||||
|
scene_description: 场景描述,可选
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> request = SceneCreateRequest(
|
||||||
|
... scene_name="医疗场景",
|
||||||
|
... scene_description="用于医疗领域的本体建模"
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
scene_name: str = Field(..., min_length=1, max_length=200, description="场景名称")
|
||||||
|
scene_description: Optional[str] = Field(None, description="场景描述")
|
||||||
|
|
||||||
|
|
||||||
|
class SceneUpdateRequest(BaseModel):
|
||||||
|
"""场景更新请求模型
|
||||||
|
|
||||||
|
用于更新已有本体场景信息。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
scene_name: 场景名称,可选,1-200字符
|
||||||
|
scene_description: 场景描述,可选
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> request = SceneUpdateRequest(
|
||||||
|
... scene_name="更新后的场景名称",
|
||||||
|
... scene_description="更新后的描述"
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
scene_name: Optional[str] = Field(None, min_length=1, max_length=200, description="场景名称")
|
||||||
|
scene_description: Optional[str] = Field(None, description="场景描述")
|
||||||
|
|
||||||
|
|
||||||
|
class SceneResponse(BaseModel):
|
||||||
|
"""场景响应模型
|
||||||
|
|
||||||
|
用于返回本体场景信息。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
scene_id: 场景ID
|
||||||
|
scene_name: 场景名称
|
||||||
|
scene_description: 场景描述
|
||||||
|
type_num: 类型数量
|
||||||
|
workspace_id: 所属工作空间ID
|
||||||
|
created_at: 创建时间(毫秒时间戳)
|
||||||
|
updated_at: 更新时间(毫秒时间戳)
|
||||||
|
classes_count: 类型数量
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> response = SceneResponse(
|
||||||
|
... scene_id=uuid.uuid4(),
|
||||||
|
... scene_name="医疗场景",
|
||||||
|
... scene_description="用于医疗领域的本体建模",
|
||||||
|
... type_num=0,
|
||||||
|
... workspace_id=uuid.uuid4(),
|
||||||
|
... created_at=datetime.now(),
|
||||||
|
... updated_at=datetime.now(),
|
||||||
|
... classes_count=5
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
scene_id: UUID = Field(..., description="场景ID")
|
||||||
|
scene_name: str = Field(..., description="场景名称")
|
||||||
|
scene_description: Optional[str] = Field(None, description="场景描述")
|
||||||
|
type_num: int = Field(..., description="类型数量")
|
||||||
|
entity_type: Optional[List[str]] = Field(None, description="实体类型列表(最多3个class_name)")
|
||||||
|
workspace_id: UUID = Field(..., description="所属工作空间ID")
|
||||||
|
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
|
||||||
|
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
|
||||||
|
classes_count: int = Field(0, description="类型数量")
|
||||||
|
|
||||||
|
@field_serializer("created_at", when_used="json")
|
||||||
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
|
"""将创建时间序列化为毫秒时间戳"""
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
@field_serializer("updated_at", when_used="json")
|
||||||
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
|
"""将更新时间序列化为毫秒时间戳"""
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class PaginationInfo(BaseModel):
|
||||||
|
"""分页信息模型
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
page: 当前页码
|
||||||
|
pagesize: 每页数量
|
||||||
|
total: 总数量
|
||||||
|
hasnext: 是否有下一页
|
||||||
|
"""
|
||||||
|
page: int = Field(..., description="当前页码")
|
||||||
|
pagesize: int = Field(..., description="每页数量")
|
||||||
|
total: int = Field(..., description="总数量")
|
||||||
|
hasnext: bool = Field(..., description="是否有下一页")
|
||||||
|
|
||||||
|
|
||||||
|
class SceneListResponse(BaseModel):
|
||||||
|
"""场景列表响应模型(支持分页)
|
||||||
|
|
||||||
|
用于返回本体场景列表。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
items: 场景列表
|
||||||
|
page: 分页信息(可选,分页时返回)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # 不分页
|
||||||
|
>>> response = SceneListResponse(
|
||||||
|
... items=[scene1, scene2]
|
||||||
|
... )
|
||||||
|
>>> # 分页
|
||||||
|
>>> response = SceneListResponse(
|
||||||
|
... items=[scene1, scene2, ...],
|
||||||
|
... page=PaginationInfo(page=1, pagesize=100, total=150, hasnext=True)
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
items: List[SceneResponse] = Field(..., description="场景列表")
|
||||||
|
page: Optional[PaginationInfo] = Field(None, description="分页信息")
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 本体类型相关 Schema ====================
|
||||||
|
|
||||||
|
class ClassItem(BaseModel):
|
||||||
|
"""单个类型信息模型
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
class_name: 类型名称,必填,1-200字符
|
||||||
|
class_description: 类型描述,可选
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> item = ClassItem(
|
||||||
|
... class_name="患者",
|
||||||
|
... class_description="医院患者信息"
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
class_name: str = Field(..., min_length=1, max_length=200, description="类型名称")
|
||||||
|
class_description: Optional[str] = Field(None, description="类型描述")
|
||||||
|
|
||||||
|
|
||||||
|
class ClassCreateRequest(BaseModel):
|
||||||
|
"""类型创建请求模型(统一使用列表形式)
|
||||||
|
|
||||||
|
通过列表中元素数量决定创建模式:
|
||||||
|
- 列表包含 1 个元素:单个创建
|
||||||
|
- 列表包含多个元素:批量创建
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
scene_id: 所属场景ID,必填
|
||||||
|
classes: 类型列表,必填,至少包含 1 个元素
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
# 单个创建(列表中 1 个元素)
|
||||||
|
>>> request = ClassCreateRequest(
|
||||||
|
... scene_id=uuid.uuid4(),
|
||||||
|
... classes=[
|
||||||
|
... ClassItem(class_name="患者", class_description="医院患者信息")
|
||||||
|
... ]
|
||||||
|
... )
|
||||||
|
|
||||||
|
# 批量创建(列表中多个元素)
|
||||||
|
>>> request = ClassCreateRequest(
|
||||||
|
... scene_id=uuid.uuid4(),
|
||||||
|
... classes=[
|
||||||
|
... ClassItem(class_name="患者", class_description="医院患者信息"),
|
||||||
|
... ClassItem(class_name="医生", class_description="医院医生信息"),
|
||||||
|
... ClassItem(class_name="药品", class_description="医院药品信息")
|
||||||
|
... ]
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
scene_id: UUID = Field(..., description="所属场景ID")
|
||||||
|
classes: List[ClassItem] = Field(..., min_length=1, description="类型列表,至少包含 1 个元素")
|
||||||
|
|
||||||
|
|
||||||
|
class ClassUpdateRequest(BaseModel):
|
||||||
|
"""类型更新请求模型
|
||||||
|
|
||||||
|
用于更新已有本体类型信息。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
class_name: 类型名称,可选,1-200字符
|
||||||
|
class_description: 类型描述,可选
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> request = ClassUpdateRequest(
|
||||||
|
... class_name="更新后的类型名称",
|
||||||
|
... class_description="更新后的描述"
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
class_name: Optional[str] = Field(None, min_length=1, max_length=200, description="类型名称")
|
||||||
|
class_description: Optional[str] = Field(None, description="类型描述")
|
||||||
|
|
||||||
|
|
||||||
|
class ClassResponse(BaseModel):
|
||||||
|
"""类型响应模型
|
||||||
|
|
||||||
|
用于返回本体类型信息。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
class_id: 类型ID
|
||||||
|
class_name: 类型名称
|
||||||
|
class_description: 类型描述
|
||||||
|
scene_id: 所属场景ID
|
||||||
|
created_at: 创建时间(毫秒时间戳)
|
||||||
|
updated_at: 更新时间(毫秒时间戳)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> response = ClassResponse(
|
||||||
|
... class_id=uuid.uuid4(),
|
||||||
|
... class_name="患者",
|
||||||
|
... class_description="医院患者信息",
|
||||||
|
... scene_id=uuid.uuid4(),
|
||||||
|
... created_at=datetime.now(),
|
||||||
|
... updated_at=datetime.now()
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
class_id: UUID = Field(..., description="类型ID")
|
||||||
|
class_name: str = Field(..., description="类型名称")
|
||||||
|
class_description: Optional[str] = Field(None, description="类型描述")
|
||||||
|
scene_id: UUID = Field(..., description="所属场景ID")
|
||||||
|
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
|
||||||
|
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
|
||||||
|
|
||||||
|
@field_serializer("created_at", when_used="json")
|
||||||
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
|
"""将创建时间序列化为毫秒时间戳"""
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
@field_serializer("updated_at", when_used="json")
|
||||||
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
|
"""将更新时间序列化为毫秒时间戳"""
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassBatchCreateResponse(BaseModel):
|
||||||
|
"""批量创建类型响应模型
|
||||||
|
|
||||||
|
用于返回批量创建的结果统计和详情。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
total: 总共尝试创建的数量
|
||||||
|
success_count: 成功创建的数量
|
||||||
|
failed_count: 失败的数量
|
||||||
|
items: 成功创建的类型列表
|
||||||
|
errors: 失败的错误信息列表(可选)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> response = ClassBatchCreateResponse(
|
||||||
|
... total=3,
|
||||||
|
... success_count=2,
|
||||||
|
... failed_count=1,
|
||||||
|
... items=[class1, class2],
|
||||||
|
... errors=["创建类型 '药品' 失败: 类型名称已存在"]
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
total: int = Field(..., description="总共尝试创建的数量")
|
||||||
|
success_count: int = Field(..., description="成功创建的数量")
|
||||||
|
failed_count: int = Field(0, description="失败的数量")
|
||||||
|
items: List[ClassResponse] = Field(..., description="成功创建的类型列表")
|
||||||
|
errors: Optional[List[str]] = Field(None, description="失败的错误信息列表")
|
||||||
|
|
||||||
|
|
||||||
|
class ClassListResponse(BaseModel):
|
||||||
|
"""类型列表响应模型
|
||||||
|
|
||||||
|
用于返回本体类型列表。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
total: 总数量
|
||||||
|
scene_id: 所属场景ID
|
||||||
|
scene_name: 场景名称
|
||||||
|
scene_description: 场景描述
|
||||||
|
items: 类型列表
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> response = ClassListResponse(
|
||||||
|
... total=3,
|
||||||
|
... scene_id=uuid.uuid4(),
|
||||||
|
... scene_name="医疗场景",
|
||||||
|
... scene_description="用于医疗领域的本体建模",
|
||||||
|
... items=[class1, class2, class3]
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
total: int = Field(..., description="总数量")
|
||||||
|
scene_id: UUID = Field(..., description="所属场景ID")
|
||||||
|
scene_name: str = Field(..., description="场景名称")
|
||||||
|
scene_description: Optional[str] = Field(None, description="场景描述")
|
||||||
|
items: List[ClassResponse] = Field(..., description="类型列表")
|
||||||
@@ -22,6 +22,23 @@ class PromptOptMessage(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptSaveRequest(BaseModel):
|
||||||
|
session_id: UUID = Field(
|
||||||
|
...,
|
||||||
|
description="Session ID"
|
||||||
|
)
|
||||||
|
|
||||||
|
title: str = Field(
|
||||||
|
...,
|
||||||
|
description="Prompt Title"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt: str = Field(
|
||||||
|
...,
|
||||||
|
description="Optimized prompt content"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PromptOptModelSet(BaseModel):
|
class PromptOptModelSet(BaseModel):
|
||||||
id: UUID | None = Field(
|
id: UUID | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@@ -171,7 +171,14 @@ class AppChatService:
|
|||||||
self.conversation_service.save_conversation_messages(
|
self.conversation_service.save_conversation_messages(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
user_message=message,
|
user_message=message,
|
||||||
assistant_message=result["content"]
|
assistant_message=result["content"],
|
||||||
|
meta_data={
|
||||||
|
"usage": result.get("usage", {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
})
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
@@ -310,6 +317,7 @@ class AppChatService:
|
|||||||
|
|
||||||
# 流式调用 Agent
|
# 流式调用 Agent
|
||||||
full_content = ""
|
full_content = ""
|
||||||
|
total_tokens = 0
|
||||||
async for chunk in agent.chat_stream(
|
async for chunk in agent.chat_stream(
|
||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
@@ -320,9 +328,12 @@ class AppChatService:
|
|||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag
|
||||||
):
|
):
|
||||||
full_content += chunk
|
if isinstance(chunk, int):
|
||||||
# 发送消息块事件
|
total_tokens = chunk
|
||||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
else:
|
||||||
|
full_content += chunk
|
||||||
|
# 发送消息块事件
|
||||||
|
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
@@ -339,7 +350,7 @@ class AppChatService:
|
|||||||
content=full_content,
|
content=full_content,
|
||||||
meta_data={
|
meta_data={
|
||||||
"model": api_key_obj.model_name,
|
"model": api_key_obj.model_name,
|
||||||
"usage": {}
|
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -416,7 +427,11 @@ class AppChatService:
|
|||||||
meta_data={
|
meta_data={
|
||||||
"mode": result.get("mode"),
|
"mode": result.get("mode"),
|
||||||
"elapsed_time": result.get("elapsed_time"),
|
"elapsed_time": result.get("elapsed_time"),
|
||||||
"sub_results": result.get("sub_results")
|
"usage": result.get("usage", {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
})
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -458,6 +473,7 @@ class AppChatService:
|
|||||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
full_content = ""
|
full_content = ""
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
# 2. 创建编排器
|
# 2. 创建编排器
|
||||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||||
@@ -474,16 +490,26 @@ class AppChatService:
|
|||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
):
|
):
|
||||||
yield event
|
if "sub_usage" in event:
|
||||||
# 尝试提取内容(用于保存)
|
if "data:" in event:
|
||||||
if "data:" in event:
|
try:
|
||||||
try:
|
data_line = event.split("data: ", 1)[1].strip()
|
||||||
data_line = event.split("data: ", 1)[1].strip()
|
data = json.loads(data_line)
|
||||||
data = json.loads(data_line)
|
if "total_tokens" in data:
|
||||||
if "content" in data:
|
total_tokens += data["total_tokens"]
|
||||||
full_content += data["content"]
|
except:
|
||||||
except:
|
pass
|
||||||
pass
|
else:
|
||||||
|
yield event
|
||||||
|
# 尝试提取内容(用于保存)
|
||||||
|
if "data:" in event:
|
||||||
|
try:
|
||||||
|
data_line = event.split("data: ", 1)[1].strip()
|
||||||
|
data = json.loads(data_line)
|
||||||
|
if "content" in data:
|
||||||
|
full_content += data["content"]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
@@ -499,7 +525,12 @@ class AppChatService:
|
|||||||
role="assistant",
|
role="assistant",
|
||||||
content=full_content,
|
content=full_content,
|
||||||
meta_data={
|
meta_data={
|
||||||
"elapsed_time": elapsed_time
|
"elapsed_time": elapsed_time,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": total_tokens
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -187,7 +187,7 @@ class AppStatisticsService:
|
|||||||
daily_tokens[date_str] = 0
|
daily_tokens[date_str] = 0
|
||||||
daily_tokens[date_str] += int(tokens)
|
daily_tokens[date_str] += int(tokens)
|
||||||
|
|
||||||
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||||
total = sum(row["tokens"] for row in daily_data)
|
total = sum(row["count"] for row in daily_data)
|
||||||
|
|
||||||
return {"daily": daily_data, "total": total}
|
return {"daily": daily_data, "total": total}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""会话服务"""
|
"""会话服务"""
|
||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
@@ -298,7 +299,8 @@ class ConversationService:
|
|||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
user_message: str,
|
user_message: str,
|
||||||
assistant_message: str
|
assistant_message: str,
|
||||||
|
meta_data: Optional[dict] = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Save a pair of user and assistant messages to the conversation.
|
Save a pair of user and assistant messages to the conversation.
|
||||||
@@ -307,6 +309,7 @@ class ConversationService:
|
|||||||
conversation_id (uuid.UUID): Conversation UUID.
|
conversation_id (uuid.UUID): Conversation UUID.
|
||||||
user_message (str): User's message content.
|
user_message (str): User's message content.
|
||||||
assistant_message (str): Assistant's response content.
|
assistant_message (str): Assistant's response content.
|
||||||
|
meta_data (Optional[dict]): Optional metadata for the messages.
|
||||||
"""
|
"""
|
||||||
self.add_message(
|
self.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
@@ -317,7 +320,8 @@ class ConversationService:
|
|||||||
self.add_message(
|
self.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=assistant_message
|
content=assistant_message,
|
||||||
|
meta_data=meta_data
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -526,12 +530,12 @@ class ConversationService:
|
|||||||
takeaways=[],
|
takeaways=[],
|
||||||
info_score=0,
|
info_score=0,
|
||||||
)
|
)
|
||||||
|
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||||
with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f:
|
with open(os.path.join(prompt_path, 'conversation_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||||
system_prompt = f.read()
|
system_prompt = f.read()
|
||||||
rendered_system_message = Template(system_prompt).render()
|
rendered_system_message = Template(system_prompt).render()
|
||||||
|
|
||||||
with open('app/services/prompt/conversation_summary_user.jinja2', 'r', encoding='utf-8') as f:
|
with open(os.path.join(prompt_path, 'conversation_summary_user.jinja2'), 'r', encoding='utf-8') as f:
|
||||||
user_prompt = f.read()
|
user_prompt = f.read()
|
||||||
rendered_user_message = Template(user_prompt).render(
|
rendered_user_message = Template(user_prompt).render(
|
||||||
language=language,
|
language=language,
|
||||||
|
|||||||
@@ -110,6 +110,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
result = task_service.get_task_memory_read_result(task.id)
|
result = task_service.get_task_memory_read_result(task.id)
|
||||||
status = result.get("status")
|
status = result.get("status")
|
||||||
logger.info(f"读取任务状态:{status}")
|
logger.info(f"读取任务状态:{status}")
|
||||||
|
if memory_content:
|
||||||
|
memory_content = memory_content['answer']
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
@@ -123,7 +125,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
"content_length": len(str(memory_content))
|
"content_length": len(str(memory_content))
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return f"检索到以下历史记忆:\n\n{memory_content}"
|
return f"检索到以下历史记忆:\n\n{memory_content}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
||||||
@@ -442,7 +443,14 @@ class DraftRunService:
|
|||||||
user_message=message,
|
user_message=message,
|
||||||
assistant_message=result["content"],
|
assistant_message=result["content"],
|
||||||
app_id=agent_config.app_id,
|
app_id=agent_config.app_id,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
|
meta_data={
|
||||||
|
"usage": result.get("usage", {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
})
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
@@ -649,6 +657,7 @@ class DraftRunService:
|
|||||||
|
|
||||||
# 9. 流式调用 Agent
|
# 9. 流式调用 Agent
|
||||||
full_content = ""
|
full_content = ""
|
||||||
|
total_tokens = 0
|
||||||
async for chunk in agent.chat_stream(
|
async for chunk in agent.chat_stream(
|
||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
@@ -659,14 +668,22 @@ class DraftRunService:
|
|||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag
|
||||||
):
|
):
|
||||||
full_content += chunk
|
if isinstance(chunk, int):
|
||||||
# 发送消息块事件
|
total_tokens = chunk
|
||||||
yield self._format_sse_event("message", {
|
else:
|
||||||
"content": chunk
|
full_content += chunk
|
||||||
})
|
# 发送消息块事件
|
||||||
|
yield self._format_sse_event("message", {
|
||||||
|
"content": chunk
|
||||||
|
})
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
if sub_agent:
|
||||||
|
yield self._format_sse_event("sub_usage", {
|
||||||
|
"total_tokens": total_tokens
|
||||||
|
})
|
||||||
|
|
||||||
# 10. 保存会话消息
|
# 10. 保存会话消息
|
||||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||||
await self._save_conversation_message(
|
await self._save_conversation_message(
|
||||||
@@ -674,7 +691,10 @@ class DraftRunService:
|
|||||||
user_message=message,
|
user_message=message,
|
||||||
assistant_message=full_content,
|
assistant_message=full_content,
|
||||||
app_id=agent_config.app_id,
|
app_id=agent_config.app_id,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
|
meta_data={
|
||||||
|
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 11. 发送结束事件
|
# 11. 发送结束事件
|
||||||
@@ -898,6 +918,7 @@ class DraftRunService:
|
|||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
user_message: str,
|
user_message: str,
|
||||||
assistant_message: str,
|
assistant_message: str,
|
||||||
|
meta_data: dict,
|
||||||
app_id: Optional[uuid.UUID] = None,
|
app_id: Optional[uuid.UUID] = None,
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -909,6 +930,7 @@ class DraftRunService:
|
|||||||
assistant_message: AI 回复消息
|
assistant_message: AI 回复消息
|
||||||
app_id: 应用ID(未使用,保留用于兼容性)
|
app_id: 应用ID(未使用,保留用于兼容性)
|
||||||
user_id: 用户ID(未使用,保留用于兼容性)
|
user_id: 用户ID(未使用,保留用于兼容性)
|
||||||
|
meta_data: token消耗
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
@@ -927,7 +949,8 @@ class DraftRunService:
|
|||||||
conversation_service.add_message(
|
conversation_service.add_message(
|
||||||
conversation_id=conv_uuid,
|
conversation_id=conv_uuid,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=assistant_message
|
content=assistant_message,
|
||||||
|
meta_data=meta_data
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|||||||
@@ -17,12 +17,15 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
class EmotionSuggestion(BaseModel):
|
class EmotionSuggestion(BaseModel):
|
||||||
"""情绪建议模型"""
|
"""情绪建议模型"""
|
||||||
type: str = Field(..., description="建议类型:emotion_balance/activity_recommendation/social_connection/stress_management")
|
type: str = Field(...,
|
||||||
|
description="建议类型:emotion_balance/activity_recommendation/social_connection/stress_management")
|
||||||
title: str = Field(..., description="建议标题")
|
title: str = Field(..., description="建议标题")
|
||||||
content: str = Field(..., description="建议内容")
|
content: str = Field(..., description="建议内容")
|
||||||
priority: str = Field(..., description="优先级:high/medium/low")
|
priority: str = Field(..., description="优先级:high/medium/low")
|
||||||
@@ -37,33 +40,33 @@ class EmotionSuggestionsResponse(BaseModel):
|
|||||||
|
|
||||||
class EmotionAnalyticsService:
|
class EmotionAnalyticsService:
|
||||||
"""情绪分析服务
|
"""情绪分析服务
|
||||||
|
|
||||||
提供情绪数据的分析和统计功能,包括:
|
提供情绪数据的分析和统计功能,包括:
|
||||||
- 情绪标签统计
|
- 情绪标签统计
|
||||||
- 情绪词云数据
|
- 情绪词云数据
|
||||||
- 情绪健康指数计算
|
- 情绪健康指数计算
|
||||||
- 个性化情绪建议生成
|
- 个性化情绪建议生成
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
emotion_repo: 情绪数据仓储实例
|
emotion_repo: 情绪数据仓储实例
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化情绪分析服务"""
|
"""初始化情绪分析服务"""
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
self.emotion_repo = EmotionRepository(connector)
|
self.emotion_repo = EmotionRepository(connector)
|
||||||
logger.info("情绪分析服务初始化完成")
|
logger.info("情绪分析服务初始化完成")
|
||||||
|
|
||||||
async def get_emotion_tags(
|
async def get_emotion_tags(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
emotion_type: Optional[str] = None,
|
emotion_type: Optional[str] = None,
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
limit: int = 10
|
limit: int = 10
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""获取情绪标签统计
|
"""获取情绪标签统计
|
||||||
|
|
||||||
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
|
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
|
||||||
确保返回所有6个情绪维度(joy、sadness、anger、fear、surprise、neutral),
|
确保返回所有6个情绪维度(joy、sadness、anger、fear、surprise、neutral),
|
||||||
即使某些维度没有数据也会返回count=0的记录。
|
即使某些维度没有数据也会返回count=0的记录。
|
||||||
@@ -71,8 +74,8 @@ class EmotionAnalyticsService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"获取情绪标签统计: user={end_user_id}, type={emotion_type}, "
|
logger.info(f"获取情绪标签统计: user={end_user_id}, type={emotion_type}, "
|
||||||
f"start={start_date}, end={end_date}, limit={limit}")
|
f"start={start_date}, end={end_date}, limit={limit}")
|
||||||
|
|
||||||
# 调用仓储层查询
|
# 调用仓储层查询
|
||||||
tags = await self.emotion_repo.get_emotion_tags(
|
tags = await self.emotion_repo.get_emotion_tags(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -81,13 +84,13 @@ class EmotionAnalyticsService:
|
|||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
|
|
||||||
# 定义所有6个情绪维度
|
# 定义所有6个情绪维度
|
||||||
all_emotion_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']
|
all_emotion_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']
|
||||||
|
|
||||||
# 将查询结果转换为字典,方便查找
|
# 将查询结果转换为字典,方便查找
|
||||||
tags_dict = {tag["emotion_type"]: tag for tag in tags}
|
tags_dict = {tag["emotion_type"]: tag for tag in tags}
|
||||||
|
|
||||||
# 补全缺失的情绪维度
|
# 补全缺失的情绪维度
|
||||||
complete_tags = []
|
complete_tags = []
|
||||||
for emotion in all_emotion_types:
|
for emotion in all_emotion_types:
|
||||||
@@ -101,52 +104,52 @@ class EmotionAnalyticsService:
|
|||||||
"percentage": 0.0,
|
"percentage": 0.0,
|
||||||
"avg_intensity": 0.0
|
"avg_intensity": 0.0
|
||||||
})
|
})
|
||||||
|
|
||||||
# 计算总数
|
# 计算总数
|
||||||
total_count = sum(tag["count"] for tag in complete_tags)
|
total_count = sum(tag["count"] for tag in complete_tags)
|
||||||
|
|
||||||
# 如果有数据,重新计算百分比(因为补全了0值项)
|
# 如果有数据,重新计算百分比(因为补全了0值项)
|
||||||
if total_count > 0:
|
if total_count > 0:
|
||||||
for tag in complete_tags:
|
for tag in complete_tags:
|
||||||
if tag["count"] > 0:
|
if tag["count"] > 0:
|
||||||
tag["percentage"] = round((tag["count"] / total_count) * 100, 2)
|
tag["percentage"] = round((tag["count"] / total_count) * 100, 2)
|
||||||
|
|
||||||
# 构建时间范围信息
|
# 构建时间范围信息
|
||||||
time_range = {}
|
time_range = {}
|
||||||
if start_date:
|
if start_date:
|
||||||
time_range["start_date"] = start_date
|
time_range["start_date"] = start_date
|
||||||
if end_date:
|
if end_date:
|
||||||
time_range["end_date"] = end_date
|
time_range["end_date"] = end_date
|
||||||
|
|
||||||
# 格式化响应
|
# 格式化响应
|
||||||
response = {
|
response = {
|
||||||
"tags": complete_tags,
|
"tags": complete_tags,
|
||||||
"total_count": total_count,
|
"total_count": total_count,
|
||||||
"time_range": time_range if time_range else None
|
"time_range": time_range if time_range else None
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"情绪标签统计完成: total_count={total_count}, tags_count={len(complete_tags)}")
|
logger.info(f"情绪标签统计完成: total_count={total_count}, tags_count={len(complete_tags)}")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取情绪标签统计失败: {str(e)}", exc_info=True)
|
logger.error(f"获取情绪标签统计失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_emotion_wordcloud(
|
async def get_emotion_wordcloud(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
emotion_type: Optional[str] = None,
|
emotion_type: Optional[str] = None,
|
||||||
limit: int = 50
|
limit: int = 50
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""获取情绪词云数据
|
"""获取情绪词云数据
|
||||||
|
|
||||||
查询情绪关键词及其频率,用于生成词云可视化。
|
查询情绪关键词及其频率,用于生成词云可视化。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 宿主ID(用户组ID)
|
end_user_id: 宿主ID(用户组ID)
|
||||||
emotion_type: 可选的情绪类型过滤
|
emotion_type: 可选的情绪类型过滤
|
||||||
limit: 返回关键词的最大数量
|
limit: 返回关键词的最大数量
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 包含情绪词云数据的响应:
|
Dict: 包含情绪词云数据的响应:
|
||||||
- keywords: 关键词列表
|
- keywords: 关键词列表
|
||||||
@@ -154,39 +157,39 @@ class EmotionAnalyticsService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"获取情绪词云数据: user={end_user_id}, type={emotion_type}, limit={limit}")
|
logger.info(f"获取情绪词云数据: user={end_user_id}, type={emotion_type}, limit={limit}")
|
||||||
|
|
||||||
# 调用仓储层查询
|
# 调用仓储层查询
|
||||||
keywords = await self.emotion_repo.get_emotion_wordcloud(
|
keywords = await self.emotion_repo.get_emotion_wordcloud(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
emotion_type=emotion_type,
|
emotion_type=emotion_type,
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
|
|
||||||
# 计算总关键词数量
|
# 计算总关键词数量
|
||||||
total_keywords = len(keywords)
|
total_keywords = len(keywords)
|
||||||
|
|
||||||
# 格式化响应
|
# 格式化响应
|
||||||
response = {
|
response = {
|
||||||
"keywords": keywords,
|
"keywords": keywords,
|
||||||
"total_keywords": total_keywords
|
"total_keywords": total_keywords
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"情绪词云数据获取完成: total_keywords={total_keywords}")
|
logger.info(f"情绪词云数据获取完成: total_keywords={total_keywords}")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取情绪词云数据失败: {str(e)}", exc_info=True)
|
logger.error(f"获取情绪词云数据失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _calculate_positivity_rate(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
|
def _calculate_positivity_rate(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
"""计算积极率
|
"""计算积极率
|
||||||
|
|
||||||
根据情绪类型分类正面、负面和中性情绪,计算积极率。
|
根据情绪类型分类正面、负面和中性情绪,计算积极率。
|
||||||
公式:(正面数 / (正面数 + 负面数)) * 100
|
公式:(正面数 / (正面数 + 负面数)) * 100
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
emotions: 情绪数据列表,每个包含 emotion_type 字段
|
emotions: 情绪数据列表,每个包含 emotion_type 字段
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 包含积极率计算结果:
|
Dict: 包含积极率计算结果:
|
||||||
- score: 积极率分数(0-100)
|
- score: 积极率分数(0-100)
|
||||||
@@ -197,38 +200,38 @@ class EmotionAnalyticsService:
|
|||||||
# 定义情绪分类
|
# 定义情绪分类
|
||||||
positive_emotions = {'joy', 'surprise'}
|
positive_emotions = {'joy', 'surprise'}
|
||||||
negative_emotions = {'sadness', 'anger', 'fear'}
|
negative_emotions = {'sadness', 'anger', 'fear'}
|
||||||
|
|
||||||
# 统计各类情绪数量
|
# 统计各类情绪数量
|
||||||
positive_count = sum(1 for e in emotions if e.get('emotion_type') in positive_emotions)
|
positive_count = sum(1 for e in emotions if e.get('emotion_type') in positive_emotions)
|
||||||
negative_count = sum(1 for e in emotions if e.get('emotion_type') in negative_emotions)
|
negative_count = sum(1 for e in emotions if e.get('emotion_type') in negative_emotions)
|
||||||
neutral_count = sum(1 for e in emotions if e.get('emotion_type') == 'neutral')
|
neutral_count = sum(1 for e in emotions if e.get('emotion_type') == 'neutral')
|
||||||
|
|
||||||
# 计算积极率
|
# 计算积极率
|
||||||
total_non_neutral = positive_count + negative_count
|
total_non_neutral = positive_count + negative_count
|
||||||
if total_non_neutral > 0:
|
if total_non_neutral > 0:
|
||||||
score = (positive_count / total_non_neutral) * 100
|
score = (positive_count / total_non_neutral) * 100
|
||||||
else:
|
else:
|
||||||
score = 50.0 # 如果没有非中性情绪,默认为50
|
score = 50.0 # 如果没有非中性情绪,默认为50
|
||||||
|
|
||||||
logger.debug(f"积极率计算: positive={positive_count}, negative={negative_count}, "
|
logger.debug(f"积极率计算: positive={positive_count}, negative={negative_count}, "
|
||||||
f"neutral={neutral_count}, score={score:.2f}")
|
f"neutral={neutral_count}, score={score:.2f}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"score": round(score, 2),
|
"score": round(score, 2),
|
||||||
"positive_count": positive_count,
|
"positive_count": positive_count,
|
||||||
"negative_count": negative_count,
|
"negative_count": negative_count,
|
||||||
"neutral_count": neutral_count
|
"neutral_count": neutral_count
|
||||||
}
|
}
|
||||||
|
|
||||||
def _calculate_stability(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
|
def _calculate_stability(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
"""计算稳定性
|
"""计算稳定性
|
||||||
|
|
||||||
基于情绪强度的标准差计算情绪稳定性。
|
基于情绪强度的标准差计算情绪稳定性。
|
||||||
公式:(1 - min(std_deviation, 1.0)) * 100
|
公式:(1 - min(std_deviation, 1.0)) * 100
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
emotions: 情绪数据列表,每个包含 emotion_intensity 字段
|
emotions: 情绪数据列表,每个包含 emotion_intensity 字段
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 包含稳定性计算结果:
|
Dict: 包含稳定性计算结果:
|
||||||
- score: 稳定性分数(0-100)
|
- score: 稳定性分数(0-100)
|
||||||
@@ -236,7 +239,7 @@ class EmotionAnalyticsService:
|
|||||||
"""
|
"""
|
||||||
# 提取所有情绪强度
|
# 提取所有情绪强度
|
||||||
intensities = [e.get('emotion_intensity', 0.0) for e in emotions if e.get('emotion_intensity') is not None]
|
intensities = [e.get('emotion_intensity', 0.0) for e in emotions if e.get('emotion_intensity') is not None]
|
||||||
|
|
||||||
# 计算标准差
|
# 计算标准差
|
||||||
if len(intensities) >= 2:
|
if len(intensities) >= 2:
|
||||||
std_deviation = statistics.stdev(intensities)
|
std_deviation = statistics.stdev(intensities)
|
||||||
@@ -244,29 +247,29 @@ class EmotionAnalyticsService:
|
|||||||
std_deviation = 0.0 # 只有一个数据点,标准差为0
|
std_deviation = 0.0 # 只有一个数据点,标准差为0
|
||||||
else:
|
else:
|
||||||
std_deviation = 0.0 # 没有数据,标准差为0
|
std_deviation = 0.0 # 没有数据,标准差为0
|
||||||
|
|
||||||
# 计算稳定性分数
|
# 计算稳定性分数
|
||||||
# 标准差越小,稳定性越高
|
# 标准差越小,稳定性越高
|
||||||
score = (1 - min(std_deviation, 1.0)) * 100
|
score = (1 - min(std_deviation, 1.0)) * 100
|
||||||
|
|
||||||
logger.debug(f"稳定性计算: intensities_count={len(intensities)}, "
|
logger.debug(f"稳定性计算: intensities_count={len(intensities)}, "
|
||||||
f"std_deviation={std_deviation:.3f}, score={score:.2f}")
|
f"std_deviation={std_deviation:.3f}, score={score:.2f}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"score": round(score, 2),
|
"score": round(score, 2),
|
||||||
"std_deviation": round(std_deviation, 3)
|
"std_deviation": round(std_deviation, 3)
|
||||||
}
|
}
|
||||||
|
|
||||||
def _calculate_resilience(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
|
def _calculate_resilience(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
"""计算恢复力
|
"""计算恢复力
|
||||||
|
|
||||||
分析情绪转换模式,统计从负面情绪恢复到正面情绪的能力。
|
分析情绪转换模式,统计从负面情绪恢复到正面情绪的能力。
|
||||||
公式:(负面到正面转换次数 / 总负面情绪数) * 100
|
公式:(负面到正面转换次数 / 总负面情绪数) * 100
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
emotions: 情绪数据列表,每个包含 emotion_type 和 created_at 字段
|
emotions: 情绪数据列表,每个包含 emotion_type 和 created_at 字段
|
||||||
应该按时间顺序排列
|
应该按时间顺序排列
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 包含恢复力计算结果:
|
Dict: 包含恢复力计算结果:
|
||||||
- score: 恢复力分数(0-100)
|
- score: 恢复力分数(0-100)
|
||||||
@@ -275,24 +278,24 @@ class EmotionAnalyticsService:
|
|||||||
# 定义情绪分类
|
# 定义情绪分类
|
||||||
positive_emotions = {'joy', 'surprise'}
|
positive_emotions = {'joy', 'surprise'}
|
||||||
negative_emotions = {'sadness', 'anger', 'fear'}
|
negative_emotions = {'sadness', 'anger', 'fear'}
|
||||||
|
|
||||||
# 统计负面到正面的转换次数
|
# 统计负面到正面的转换次数
|
||||||
recovery_count = 0
|
recovery_count = 0
|
||||||
negative_count = 0
|
negative_count = 0
|
||||||
|
|
||||||
for i in range(len(emotions)):
|
for i in range(len(emotions)):
|
||||||
current_emotion = emotions[i].get('emotion_type')
|
current_emotion = emotions[i].get('emotion_type')
|
||||||
|
|
||||||
# 统计负面情绪总数
|
# 统计负面情绪总数
|
||||||
if current_emotion in negative_emotions:
|
if current_emotion in negative_emotions:
|
||||||
negative_count += 1
|
negative_count += 1
|
||||||
|
|
||||||
# 检查下一个情绪是否为正面
|
# 检查下一个情绪是否为正面
|
||||||
if i + 1 < len(emotions):
|
if i + 1 < len(emotions):
|
||||||
next_emotion = emotions[i + 1].get('emotion_type')
|
next_emotion = emotions[i + 1].get('emotion_type')
|
||||||
if next_emotion in positive_emotions:
|
if next_emotion in positive_emotions:
|
||||||
recovery_count += 1
|
recovery_count += 1
|
||||||
|
|
||||||
# 计算恢复力分数
|
# 计算恢复力分数
|
||||||
if negative_count > 0:
|
if negative_count > 0:
|
||||||
recovery_rate = recovery_count / negative_count
|
recovery_rate = recovery_count / negative_count
|
||||||
@@ -301,28 +304,28 @@ class EmotionAnalyticsService:
|
|||||||
# 如果没有负面情绪,恢复力设为100(最佳状态)
|
# 如果没有负面情绪,恢复力设为100(最佳状态)
|
||||||
recovery_rate = 1.0
|
recovery_rate = 1.0
|
||||||
score = 100.0
|
score = 100.0
|
||||||
|
|
||||||
logger.debug(f"恢复力计算: negative_count={negative_count}, "
|
logger.debug(f"恢复力计算: negative_count={negative_count}, "
|
||||||
f"recovery_count={recovery_count}, score={score:.2f}")
|
f"recovery_count={recovery_count}, score={score:.2f}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"score": round(score, 2),
|
"score": round(score, 2),
|
||||||
"recovery_rate": round(recovery_rate, 3)
|
"recovery_rate": round(recovery_rate, 3)
|
||||||
}
|
}
|
||||||
|
|
||||||
async def calculate_emotion_health_index(
|
async def calculate_emotion_health_index(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
time_range: str = "30d"
|
time_range: str = "30d"
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""计算情绪健康指数
|
"""计算情绪健康指数
|
||||||
|
|
||||||
综合积极率、稳定性和恢复力计算情绪健康指数。
|
综合积极率、稳定性和恢复力计算情绪健康指数。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 宿主ID(用户组ID)
|
end_user_id: 宿主ID(用户组ID)
|
||||||
time_range: 时间范围(7d/30d/90d)
|
time_range: 时间范围(7d/30d/90d)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 包含情绪健康指数的完整响应:
|
Dict: 包含情绪健康指数的完整响应:
|
||||||
- health_score: 综合健康分数(0-100)
|
- health_score: 综合健康分数(0-100)
|
||||||
@@ -336,13 +339,13 @@ class EmotionAnalyticsService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"计算情绪健康指数: user={end_user_id}, time_range={time_range}")
|
logger.info(f"计算情绪健康指数: user={end_user_id}, time_range={time_range}")
|
||||||
|
|
||||||
# 获取时间范围内的情绪数据
|
# 获取时间范围内的情绪数据
|
||||||
emotions = await self.emotion_repo.get_emotions_in_range(
|
emotions = await self.emotion_repo.get_emotions_in_range(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
time_range=time_range
|
time_range=time_range
|
||||||
)
|
)
|
||||||
|
|
||||||
# 如果没有数据,返回默认值
|
# 如果没有数据,返回默认值
|
||||||
if not emotions:
|
if not emotions:
|
||||||
logger.warning(f"用户 {end_user_id} 在时间范围 {time_range} 内没有情绪数据")
|
logger.warning(f"用户 {end_user_id} 在时间范围 {time_range} 内没有情绪数据")
|
||||||
@@ -357,20 +360,20 @@ class EmotionAnalyticsService:
|
|||||||
"emotion_distribution": {},
|
"emotion_distribution": {},
|
||||||
"time_range": time_range
|
"time_range": time_range
|
||||||
}
|
}
|
||||||
|
|
||||||
# 计算各维度指标
|
# 计算各维度指标
|
||||||
positivity_rate = self._calculate_positivity_rate(emotions)
|
positivity_rate = self._calculate_positivity_rate(emotions)
|
||||||
stability = self._calculate_stability(emotions)
|
stability = self._calculate_stability(emotions)
|
||||||
resilience = self._calculate_resilience(emotions)
|
resilience = self._calculate_resilience(emotions)
|
||||||
|
|
||||||
# 计算综合健康分数
|
# 计算综合健康分数
|
||||||
# 公式:positivity_rate * 0.4 + stability * 0.3 + resilience * 0.3
|
# 公式:positivity_rate * 0.4 + stability * 0.3 + resilience * 0.3
|
||||||
health_score = (
|
health_score = (
|
||||||
positivity_rate["score"] * 0.4 +
|
positivity_rate["score"] * 0.4 +
|
||||||
stability["score"] * 0.3 +
|
stability["score"] * 0.3 +
|
||||||
resilience["score"] * 0.3
|
resilience["score"] * 0.3
|
||||||
)
|
)
|
||||||
|
|
||||||
# 确定健康等级
|
# 确定健康等级
|
||||||
if health_score >= 80:
|
if health_score >= 80:
|
||||||
level = "优秀"
|
level = "优秀"
|
||||||
@@ -380,13 +383,13 @@ class EmotionAnalyticsService:
|
|||||||
level = "一般"
|
level = "一般"
|
||||||
else:
|
else:
|
||||||
level = "较差"
|
level = "较差"
|
||||||
|
|
||||||
# 统计情绪分布
|
# 统计情绪分布
|
||||||
emotion_distribution = {}
|
emotion_distribution = {}
|
||||||
for emotion_type in ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']:
|
for emotion_type in ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']:
|
||||||
count = sum(1 for e in emotions if e.get('emotion_type') == emotion_type)
|
count = sum(1 for e in emotions if e.get('emotion_type') == emotion_type)
|
||||||
emotion_distribution[emotion_type] = count
|
emotion_distribution[emotion_type] = count
|
||||||
|
|
||||||
# 格式化响应
|
# 格式化响应
|
||||||
response = {
|
response = {
|
||||||
"health_score": round(health_score, 2),
|
"health_score": round(health_score, 2),
|
||||||
@@ -399,22 +402,22 @@ class EmotionAnalyticsService:
|
|||||||
"emotion_distribution": emotion_distribution,
|
"emotion_distribution": emotion_distribution,
|
||||||
"time_range": time_range
|
"time_range": time_range
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"情绪健康指数计算完成: score={health_score:.2f}, level={level}")
|
logger.info(f"情绪健康指数计算完成: score={health_score:.2f}, level={level}")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"计算情绪健康指数失败: {str(e)}", exc_info=True)
|
logger.error(f"计算情绪健康指数失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _analyze_emotion_patterns(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
|
def _analyze_emotion_patterns(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
"""分析情绪模式
|
"""分析情绪模式
|
||||||
|
|
||||||
识别主要负面情绪、情绪触发因素和波动时段。
|
识别主要负面情绪、情绪触发因素和波动时段。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
emotions: 情绪数据列表,每个包含 emotion_type、emotion_intensity、created_at 字段
|
emotions: 情绪数据列表,每个包含 emotion_type、emotion_intensity、created_at 字段
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 包含情绪模式分析结果:
|
Dict: 包含情绪模式分析结果:
|
||||||
- dominant_negative_emotion: 主要负面情绪类型
|
- dominant_negative_emotion: 主要负面情绪类型
|
||||||
@@ -422,19 +425,19 @@ class EmotionAnalyticsService:
|
|||||||
- emotion_volatility: 情绪波动性(高/中/低)
|
- emotion_volatility: 情绪波动性(高/中/低)
|
||||||
"""
|
"""
|
||||||
negative_emotions = {'sadness', 'anger', 'fear'}
|
negative_emotions = {'sadness', 'anger', 'fear'}
|
||||||
|
|
||||||
# 统计负面情绪分布
|
# 统计负面情绪分布
|
||||||
negative_emotion_counts = {}
|
negative_emotion_counts = {}
|
||||||
for emotion in emotions:
|
for emotion in emotions:
|
||||||
emotion_type = emotion.get('emotion_type')
|
emotion_type = emotion.get('emotion_type')
|
||||||
if emotion_type in negative_emotions:
|
if emotion_type in negative_emotions:
|
||||||
negative_emotion_counts[emotion_type] = negative_emotion_counts.get(emotion_type, 0) + 1
|
negative_emotion_counts[emotion_type] = negative_emotion_counts.get(emotion_type, 0) + 1
|
||||||
|
|
||||||
# 识别主要负面情绪
|
# 识别主要负面情绪
|
||||||
dominant_negative_emotion = None
|
dominant_negative_emotion = None
|
||||||
if negative_emotion_counts:
|
if negative_emotion_counts:
|
||||||
dominant_negative_emotion = max(negative_emotion_counts, key=negative_emotion_counts.get)
|
dominant_negative_emotion = max(negative_emotion_counts, key=negative_emotion_counts.get)
|
||||||
|
|
||||||
# 识别高强度情绪(强度 >= 0.7)
|
# 识别高强度情绪(强度 >= 0.7)
|
||||||
high_intensity_emotions = [
|
high_intensity_emotions = [
|
||||||
{
|
{
|
||||||
@@ -445,7 +448,7 @@ class EmotionAnalyticsService:
|
|||||||
for e in emotions
|
for e in emotions
|
||||||
if e.get('emotion_intensity', 0) >= 0.7
|
if e.get('emotion_intensity', 0) >= 0.7
|
||||||
]
|
]
|
||||||
|
|
||||||
# 评估情绪波动性
|
# 评估情绪波动性
|
||||||
intensities = [e.get('emotion_intensity', 0.0) for e in emotions if e.get('emotion_intensity') is not None]
|
intensities = [e.get('emotion_intensity', 0.0) for e in emotions if e.get('emotion_intensity') is not None]
|
||||||
if len(intensities) >= 2:
|
if len(intensities) >= 2:
|
||||||
@@ -458,29 +461,29 @@ class EmotionAnalyticsService:
|
|||||||
volatility = "低"
|
volatility = "低"
|
||||||
else:
|
else:
|
||||||
volatility = "未知"
|
volatility = "未知"
|
||||||
|
|
||||||
logger.debug(f"情绪模式分析: dominant_negative={dominant_negative_emotion}, "
|
logger.debug(f"情绪模式分析: dominant_negative={dominant_negative_emotion}, "
|
||||||
f"high_intensity_count={len(high_intensity_emotions)}, volatility={volatility}")
|
f"high_intensity_count={len(high_intensity_emotions)}, volatility={volatility}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"dominant_negative_emotion": dominant_negative_emotion,
|
"dominant_negative_emotion": dominant_negative_emotion,
|
||||||
"high_intensity_emotions": high_intensity_emotions[:5], # 最多返回5个
|
"high_intensity_emotions": high_intensity_emotions[:5], # 最多返回5个
|
||||||
"emotion_volatility": volatility
|
"emotion_volatility": volatility
|
||||||
}
|
}
|
||||||
|
|
||||||
async def generate_emotion_suggestions(
|
async def generate_emotion_suggestions(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
db: Session,
|
db: Session,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""生成个性化情绪建议
|
"""生成个性化情绪建议
|
||||||
|
|
||||||
基于情绪健康数据和用户画像生成个性化建议。
|
基于情绪健康数据和用户画像生成个性化建议。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 宿主ID(用户组ID)
|
end_user_id: 宿主ID(用户组ID)
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 包含个性化建议的响应:
|
Dict: 包含个性化建议的响应:
|
||||||
- health_summary: 健康状态摘要
|
- health_summary: 健康状态摘要
|
||||||
@@ -488,17 +491,17 @@ class EmotionAnalyticsService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"生成个性化情绪建议: user={end_user_id}")
|
logger.info(f"生成个性化情绪建议: user={end_user_id}")
|
||||||
|
|
||||||
# 1. 从 end_user_id 获取关联的 memory_config_id
|
# 1. 从 end_user_id 获取关联的 memory_config_id
|
||||||
llm_client = None
|
llm_client = None
|
||||||
try:
|
try:
|
||||||
from app.services.memory_agent_service import (
|
from app.services.memory_agent_service import (
|
||||||
get_end_user_connected_config,
|
get_end_user_connected_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
if config_id is not None:
|
if config_id is not None:
|
||||||
from app.services.memory_config_service import (
|
from app.services.memory_config_service import (
|
||||||
MemoryConfigService,
|
MemoryConfigService,
|
||||||
@@ -513,35 +516,35 @@ class EmotionAnalyticsService:
|
|||||||
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
|
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
|
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
|
||||||
|
|
||||||
# 2. 获取情绪健康数据
|
# 2. 获取情绪健康数据
|
||||||
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
|
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
|
||||||
|
|
||||||
# 3. 获取情绪数据用于模式分析
|
# 3. 获取情绪数据用于模式分析
|
||||||
emotions = await self.emotion_repo.get_emotions_in_range(
|
emotions = await self.emotion_repo.get_emotions_in_range(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
time_range="30d"
|
time_range="30d"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. 分析情绪模式
|
# 4. 分析情绪模式
|
||||||
patterns = self._analyze_emotion_patterns(emotions)
|
patterns = self._analyze_emotion_patterns(emotions)
|
||||||
|
|
||||||
# 5. 获取用户画像数据(简化版,直接从Neo4j获取)
|
# 5. 获取用户画像数据(简化版,直接从Neo4j获取)
|
||||||
user_profile = await self._get_simple_user_profile(end_user_id)
|
user_profile = await self._get_simple_user_profile(end_user_id)
|
||||||
|
|
||||||
# 6. 构建LLM prompt
|
# 6. 构建LLM prompt
|
||||||
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
|
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
|
||||||
|
|
||||||
# 7. 调用LLM生成建议(使用配置中的LLM)
|
# 7. 调用LLM生成建议(使用配置中的LLM)
|
||||||
if llm_client is None:
|
if llm_client is None:
|
||||||
# 无法获取配置时,抛出错误而不是使用默认配置
|
# 无法获取配置时,抛出错误而不是使用默认配置
|
||||||
raise ValueError("无法获取LLM配置,请确保end_user关联了有效的memory_config")
|
raise ValueError("无法获取LLM配置,请确保end_user关联了有效的memory_config")
|
||||||
|
|
||||||
# 将 prompt 转换为 messages 格式
|
# 将 prompt 转换为 messages 格式
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": prompt}
|
{"role": "user", "content": prompt}
|
||||||
]
|
]
|
||||||
|
|
||||||
# 8. 使用结构化输出直接获取 Pydantic 模型
|
# 8. 使用结构化输出直接获取 Pydantic 模型
|
||||||
try:
|
try:
|
||||||
suggestions_response = await llm_client.response_structured(
|
suggestions_response = await llm_client.response_structured(
|
||||||
@@ -552,7 +555,7 @@ class EmotionAnalyticsService:
|
|||||||
logger.error(f"LLM 结构化输出失败: {str(e)}")
|
logger.error(f"LLM 结构化输出失败: {str(e)}")
|
||||||
# 返回默认建议
|
# 返回默认建议
|
||||||
suggestions_response = self._get_default_suggestions(health_data)
|
suggestions_response = self._get_default_suggestions(health_data)
|
||||||
|
|
||||||
# 8. 验证建议数量(3-5条)
|
# 8. 验证建议数量(3-5条)
|
||||||
if len(suggestions_response.suggestions) < 3:
|
if len(suggestions_response.suggestions) < 3:
|
||||||
logger.warning(f"建议数量不足: {len(suggestions_response.suggestions)}")
|
logger.warning(f"建议数量不足: {len(suggestions_response.suggestions)}")
|
||||||
@@ -560,7 +563,7 @@ class EmotionAnalyticsService:
|
|||||||
elif len(suggestions_response.suggestions) > 5:
|
elif len(suggestions_response.suggestions) > 5:
|
||||||
logger.warning(f"建议数量过多: {len(suggestions_response.suggestions)}")
|
logger.warning(f"建议数量过多: {len(suggestions_response.suggestions)}")
|
||||||
suggestions_response.suggestions = suggestions_response.suggestions[:5]
|
suggestions_response.suggestions = suggestions_response.suggestions[:5]
|
||||||
|
|
||||||
# 9. 格式化响应
|
# 9. 格式化响应
|
||||||
response = {
|
response = {
|
||||||
"health_summary": suggestions_response.health_summary,
|
"health_summary": suggestions_response.health_summary,
|
||||||
@@ -575,26 +578,26 @@ class EmotionAnalyticsService:
|
|||||||
for s in suggestions_response.suggestions
|
for s in suggestions_response.suggestions
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"个性化建议生成完成: suggestions_count={len(response['suggestions'])}")
|
logger.info(f"个性化建议生成完成: suggestions_count={len(response['suggestions'])}")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成个性化建议失败: {str(e)}", exc_info=True)
|
logger.error(f"生成个性化建议失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _get_simple_user_profile(self, end_user_id: str) -> Dict[str, Any]:
|
async def _get_simple_user_profile(self, end_user_id: str) -> Dict[str, Any]:
|
||||||
"""获取简化的用户画像数据
|
"""获取简化的用户画像数据
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 用户ID
|
end_user_id: 用户ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 用户画像数据
|
Dict: 用户画像数据
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
|
|
||||||
# 查询用户的实体和标签
|
# 查询用户的实体和标签
|
||||||
query = """
|
query = """
|
||||||
MATCH (e:Entity)
|
MATCH (e:Entity)
|
||||||
@@ -603,59 +606,59 @@ class EmotionAnalyticsService:
|
|||||||
ORDER BY e.created_at DESC
|
ORDER BY e.created_at DESC
|
||||||
LIMIT 20
|
LIMIT 20
|
||||||
"""
|
"""
|
||||||
|
|
||||||
entities = await connector.execute_query(query, end_user_id=end_user_id)
|
entities = await connector.execute_query(query, end_user_id=end_user_id)
|
||||||
|
|
||||||
# 提取兴趣标签
|
# 提取兴趣标签
|
||||||
interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5]
|
interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5]
|
||||||
# 后期会引入用户的习惯。。
|
# 后期会引入用户的习惯。。
|
||||||
return {
|
return {
|
||||||
"interests": interests if interests else ["未知"]
|
"interests": interests if interests else ["未知"]
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取用户画像失败: {str(e)}")
|
logger.error(f"获取用户画像失败: {str(e)}")
|
||||||
return {"interests": ["未知"]}
|
return {"interests": ["未知"]}
|
||||||
|
|
||||||
async def _build_suggestion_prompt(
|
async def _build_suggestion_prompt(
|
||||||
self,
|
self,
|
||||||
health_data: Dict[str, Any],
|
health_data: Dict[str, Any],
|
||||||
patterns: Dict[str, Any],
|
patterns: Dict[str, Any],
|
||||||
user_profile: Dict[str, Any]
|
user_profile: Dict[str, Any]
|
||||||
) -> str:
|
) -> str:
|
||||||
"""构建情绪建议生成的prompt
|
"""构建情绪建议生成的prompt
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
health_data: 情绪健康数据
|
health_data: 情绪健康数据
|
||||||
patterns: 情绪模式分析结果
|
patterns: 情绪模式分析结果
|
||||||
user_profile: 用户画像数据
|
user_profile: 用户画像数据
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: LLM prompt
|
str: LLM prompt
|
||||||
"""
|
"""
|
||||||
from app.core.memory.utils.prompt.prompt_utils import (
|
from app.core.memory.utils.prompt.prompt_utils import (
|
||||||
render_emotion_suggestions_prompt,
|
render_emotion_suggestions_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = await render_emotion_suggestions_prompt(
|
prompt = await render_emotion_suggestions_prompt(
|
||||||
health_data=health_data,
|
health_data=health_data,
|
||||||
patterns=patterns,
|
patterns=patterns,
|
||||||
user_profile=user_profile
|
user_profile=user_profile
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def _get_default_suggestions(self, health_data: Dict[str, Any]) -> EmotionSuggestionsResponse:
|
def _get_default_suggestions(self, health_data: Dict[str, Any]) -> EmotionSuggestionsResponse:
|
||||||
"""获取默认建议(当LLM调用失败时使用)
|
"""获取默认建议(当LLM调用失败时使用)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
health_data: 情绪健康数据
|
health_data: 情绪健康数据
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
EmotionSuggestionsResponse: 默认建议
|
EmotionSuggestionsResponse: 默认建议
|
||||||
"""
|
"""
|
||||||
health_score = health_data.get('health_score', 0)
|
health_score = health_data.get('health_score', 0)
|
||||||
|
|
||||||
if health_score >= 80:
|
if health_score >= 80:
|
||||||
summary = "您的情绪健康状况优秀,请继续保持积极的生活态度。"
|
summary = "您的情绪健康状况优秀,请继续保持积极的生活态度。"
|
||||||
elif health_score >= 60:
|
elif health_score >= 60:
|
||||||
@@ -664,7 +667,7 @@ class EmotionAnalyticsService:
|
|||||||
summary = "您的情绪健康需要关注,建议采取一些改善措施。"
|
summary = "您的情绪健康需要关注,建议采取一些改善措施。"
|
||||||
else:
|
else:
|
||||||
summary = "您的情绪健康需要重点关注,建议寻求专业帮助。"
|
summary = "您的情绪健康需要重点关注,建议寻求专业帮助。"
|
||||||
|
|
||||||
suggestions = [
|
suggestions = [
|
||||||
EmotionSuggestion(
|
EmotionSuggestion(
|
||||||
type="emotion_balance",
|
type="emotion_balance",
|
||||||
@@ -700,54 +703,54 @@ class EmotionAnalyticsService:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
return EmotionSuggestionsResponse(
|
return EmotionSuggestionsResponse(
|
||||||
health_summary=summary,
|
health_summary=summary,
|
||||||
suggestions=suggestions
|
suggestions=suggestions
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_cached_suggestions(
|
async def get_cached_suggestions(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
db: Session,
|
db: Session,
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""从 Redis 缓存获取个性化情绪建议
|
"""从 Redis 缓存获取个性化情绪建议
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 宿主ID(用户组ID)
|
end_user_id: 宿主ID(用户组ID)
|
||||||
db: 数据库会话(保留参数以保持接口兼容性)
|
db: 数据库会话(保留参数以保持接口兼容性)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 缓存的建议数据,如果不存在或已过期返回 None
|
Dict: 缓存的建议数据,如果不存在或已过期返回 None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from app.cache.memory.emotion_memory import EmotionMemoryCache
|
from app.cache.memory.emotion_memory import EmotionMemoryCache
|
||||||
|
|
||||||
logger.info(f"尝试从 Redis 缓存获取情绪建议: user={end_user_id}")
|
logger.info(f"尝试从 Redis 缓存获取情绪建议: user={end_user_id}")
|
||||||
|
|
||||||
# 从 Redis 获取缓存
|
# 从 Redis 获取缓存
|
||||||
cached_data = await EmotionMemoryCache.get_emotion_suggestions(end_user_id)
|
cached_data = await EmotionMemoryCache.get_emotion_suggestions(end_user_id)
|
||||||
|
|
||||||
if cached_data is None:
|
if cached_data is None:
|
||||||
logger.info(f"用户 {end_user_id} 的建议缓存不存在或已过期")
|
logger.info(f"用户 {end_user_id} 的建议缓存不存在或已过期")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
logger.info(f"成功从 Redis 缓存获取建议: user={end_user_id}")
|
logger.info(f"成功从 Redis 缓存获取建议: user={end_user_id}")
|
||||||
return cached_data
|
return cached_data
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从 Redis 缓存获取建议失败: {str(e)}", exc_info=True)
|
logger.error(f"从 Redis 缓存获取建议失败: {str(e)}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def save_suggestions_cache(
|
async def save_suggestions_cache(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
suggestions_data: Dict[str, Any],
|
suggestions_data: Dict[str, Any],
|
||||||
db: Session,
|
db: Session,
|
||||||
expires_hours: int = 24
|
expires_hours: int = 24
|
||||||
) -> None:
|
) -> None:
|
||||||
"""保存建议到 Redis 缓存
|
"""保存建议到 Redis 缓存
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 宿主ID(用户组ID)
|
end_user_id: 宿主ID(用户组ID)
|
||||||
suggestions_data: 建议数据
|
suggestions_data: 建议数据
|
||||||
@@ -756,24 +759,24 @@ class EmotionAnalyticsService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from app.cache.memory.emotion_memory import EmotionMemoryCache
|
from app.cache.memory.emotion_memory import EmotionMemoryCache
|
||||||
|
|
||||||
logger.info(f"保存建议到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时")
|
logger.info(f"保存建议到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时")
|
||||||
|
|
||||||
# 计算过期时间(秒)
|
# 计算过期时间(秒)
|
||||||
expire_seconds = expires_hours * 3600
|
expire_seconds = expires_hours * 3600
|
||||||
|
|
||||||
# 保存到 Redis
|
# 保存到 Redis
|
||||||
success = await EmotionMemoryCache.set_emotion_suggestions(
|
success = await EmotionMemoryCache.set_emotion_suggestions(
|
||||||
user_id=end_user_id,
|
user_id=end_user_id,
|
||||||
suggestions_data=suggestions_data,
|
suggestions_data=suggestions_data,
|
||||||
expire=expire_seconds
|
expire=expire_seconds
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
logger.info(f"建议缓存保存成功: user={end_user_id}")
|
logger.info(f"建议缓存保存成功: user={end_user_id}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"建议缓存保存失败: user={end_user_id}")
|
logger.warning(f"建议缓存保存失败: user={end_user_id}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存建议缓存失败: {str(e)}", exc_info=True)
|
logger.error(f"保存建议缓存失败: {str(e)}", exc_info=True)
|
||||||
# 不抛出异常,缓存失败不应影响主流程
|
# 不抛出异常,缓存失败不应影响主流程
|
||||||
@@ -4,7 +4,7 @@ import uuid
|
|||||||
from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated
|
from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, AIMessageChunk
|
||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph, START, END
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
@@ -727,9 +727,12 @@ class HandoffsService:
|
|||||||
|
|
||||||
# 提取响应
|
# 提取响应
|
||||||
response_content = ""
|
response_content = ""
|
||||||
|
total_tokens = 0
|
||||||
for msg in result.get("messages", []):
|
for msg in result.get("messages", []):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
response_content = msg.content
|
response_content = msg.content
|
||||||
|
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||||
|
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||||
break
|
break
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -737,7 +740,12 @@ class HandoffsService:
|
|||||||
"active_agent": result.get("active_agent"),
|
"active_agent": result.get("active_agent"),
|
||||||
"response": response_content,
|
"response": response_content,
|
||||||
"message_count": len(result.get("messages", [])),
|
"message_count": len(result.get("messages", [])),
|
||||||
"handoff_count": result.get("handoff_count", 0)
|
"handoff_count": result.get("handoff_count", 0),
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": total_tokens
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
@@ -830,6 +838,12 @@ class HandoffsService:
|
|||||||
|
|
||||||
# 捕获 LLM 结束事件,输出收集到的工具调用
|
# 捕获 LLM 结束事件,输出收集到的工具调用
|
||||||
elif kind == "on_chat_model_end":
|
elif kind == "on_chat_model_end":
|
||||||
|
output_message = event.get("data", {}).get("output", {})
|
||||||
|
if isinstance(output_message, AIMessageChunk):
|
||||||
|
response_meta = output_message.response_metadata if hasattr(output_message, 'response_metadata') else None
|
||||||
|
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||||
|
0) if response_meta else 0
|
||||||
|
yield f"event: sub_usage\ndata: {json.dumps({"total_tokens": total_tokens}, ensure_ascii=False)}\n\n"
|
||||||
if collected_tool_calls:
|
if collected_tool_calls:
|
||||||
# 找到参数最完整的 transfer 工具调用
|
# 找到参数最完整的 transfer 工具调用
|
||||||
best_tc = None
|
best_tc = None
|
||||||
|
|||||||
@@ -334,7 +334,9 @@ class MemoryAgentService:
|
|||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||||
elif msg['role'] == 'assistant':
|
elif msg['role'] == 'assistant':
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
langchain_messages.append(AIMessage(content=msg['content']))
|
||||||
|
print(100*'-')
|
||||||
|
print(langchain_messages)
|
||||||
|
print(100*'-')
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {
|
initial_state = {
|
||||||
"messages": langchain_messages,
|
"messages": langchain_messages,
|
||||||
|
|||||||
@@ -53,7 +53,10 @@ def get_workspace_end_users(
|
|||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
current_user: User
|
current_user: User
|
||||||
) -> List[EndUser]:
|
) -> List[EndUser]:
|
||||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
|
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
|
||||||
|
|
||||||
|
返回结果按 updated_at 从新到旧排序(NULL 值排在最后)
|
||||||
|
"""
|
||||||
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -68,9 +71,14 @@ def get_workspace_end_users(
|
|||||||
app_ids = [app.id for app in apps_orm]
|
app_ids = [app.id for app in apps_orm]
|
||||||
|
|
||||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||||
|
# 按 updated_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||||
from app.models.end_user_model import EndUser as EndUserModel
|
from app.models.end_user_model import EndUser as EndUserModel
|
||||||
|
from sqlalchemy import desc, nullslast
|
||||||
end_users_orm = db.query(EndUserModel).filter(
|
end_users_orm = db.query(EndUserModel).filter(
|
||||||
EndUserModel.app_id.in_(app_ids)
|
EndUserModel.app_id.in_(app_ids)
|
||||||
|
).order_by(
|
||||||
|
nullslast(desc(EndUserModel.updated_at)),
|
||||||
|
desc(EndUserModel.id)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
# 转换为 Pydantic 模型(只在需要时转换)
|
# 转换为 Pydantic 模型(只在需要时转换)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|||||||
from app.models.app_model import App
|
from app.models.app_model import App
|
||||||
from app.models.app_release_model import AppRelease
|
from app.models.app_release_model import AppRelease
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
@@ -88,51 +89,44 @@ class WorkspaceAppService:
|
|||||||
|
|
||||||
for release in app_releases:
|
for release in app_releases:
|
||||||
memory_content = self._extract_memory_content(release.config)
|
memory_content = self._extract_memory_content(release.config)
|
||||||
|
|
||||||
|
|
||||||
if memory_content and memory_content in processed_configs:
|
if memory_content and memory_content in processed_configs:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
release_info = {
|
release_info = {
|
||||||
"app_id": str(release.app_id),
|
"app_id": str(release.app_id),
|
||||||
"config": memory_content
|
"config": memory_content
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if memory_content:
|
if memory_content:
|
||||||
processed_configs.add(memory_content)
|
processed_configs.add(memory_content)
|
||||||
memory_config_info = self._get_memory_config(memory_content)
|
memory_config_info = self._get_memory_config(memory_content)
|
||||||
|
|
||||||
if memory_config_info:
|
if memory_config_info:
|
||||||
if not any(dc["config_id"] == memory_config_info["config_id"] for dc in app_info["memory_configs"]):
|
if not any(dc["config_id"] == memory_config_info["config_id"] for dc in app_info["memory_configs"]):
|
||||||
app_info["memory_configs"].append(memory_config_info)
|
app_info["memory_configs"].append(memory_config_info)
|
||||||
|
|
||||||
app_info["releases"].append(release_info)
|
app_info["releases"].append(release_info)
|
||||||
|
|
||||||
def _extract_memory_content(self, config: Any) -> str:
|
def _extract_memory_content(self, config: Any) -> str:
|
||||||
"""Extract memory_comtent from config"""
|
"""Extract memory_comtent from config"""
|
||||||
if not config or not isinstance(config, dict):
|
if not config or not isinstance(config, dict):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
memory_obj = config.get('memory')
|
memory_obj = config.get('memory')
|
||||||
if memory_obj and isinstance(memory_obj, dict):
|
if memory_obj and isinstance(memory_obj, dict):
|
||||||
return memory_obj.get('memory_content')
|
return memory_obj.get('memory_content')
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_memory_config(self, memory_content: str) -> Dict[str, Any]:
|
def _get_memory_config(self, memory_content: str) -> Dict[str, Any]:
|
||||||
"""Retrieve memory_config information based on memory_content"""
|
"""Retrieve memory_config information based on memory_content"""
|
||||||
try:
|
try:
|
||||||
memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, int(memory_content))
|
memory_content = resolve_config_id(memory_content, self.db)
|
||||||
|
memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, (memory_content))
|
||||||
|
|
||||||
# memory_config_query, memory_config_params = MemoryConfigRepository.build_select_reflection(memory_content)
|
|
||||||
# memory_config_result = self.db.execute(text(memory_config_query), memory_config_params).fetchone()
|
|
||||||
# if memory_config_result is None:
|
|
||||||
# return None
|
|
||||||
|
|
||||||
if memory_config_result:
|
if memory_config_result:
|
||||||
return {
|
return {
|
||||||
"config_id": memory_config_result.config_id,
|
"config_id": memory_content,
|
||||||
"enable_self_reflexion": memory_config_result.enable_self_reflexion,
|
"enable_self_reflexion": memory_config_result.enable_self_reflexion,
|
||||||
"iteration_period": memory_config_result.iteration_period,
|
"iteration_period": memory_config_result.iteration_period,
|
||||||
"reflexion_range": memory_config_result.reflexion_range,
|
"reflexion_range": memory_config_result.reflexion_range,
|
||||||
@@ -144,20 +138,22 @@ class WorkspaceAppService:
|
|||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"查询memory_config失败,memory_content: {memory_content}, 错误: {str(e)}")
|
api_logger.warning(f"查询memory_config失败,memory_content: {memory_content}, 错误: {str(e)}")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _process_end_users(self, app: App, app_info: Dict[str, Any]) -> None:
|
def _process_end_users(self, app: App, app_info: Dict[str, Any]) -> None:
|
||||||
"""Processing end-user information for applications"""
|
"""Processing end-user information for applications"""
|
||||||
end_users = self.db.query(EndUser).filter(EndUser.app_id == app.id).all()
|
end_users = self.db.query(EndUser).filter(EndUser.app_id == app.id).all()
|
||||||
|
|
||||||
for end_user in end_users:
|
for end_user in end_users:
|
||||||
end_user_info = {
|
end_user_info = {
|
||||||
"id": str(end_user.id),
|
"id": str(end_user.id),
|
||||||
"app_id": str(end_user.app_id)
|
"app_id": str(end_user.app_id)
|
||||||
}
|
}
|
||||||
app_info["end_users"].append(end_user_info)
|
app_info["end_users"].append(end_user_info)
|
||||||
|
print(100*'-')
|
||||||
|
print(app_info)
|
||||||
|
|
||||||
def get_end_user_reflection_time(self, end_user_id: str) -> Optional[Any]:
|
def get_end_user_reflection_time(self, end_user_id: str) -> Optional[Any]:
|
||||||
"""
|
"""
|
||||||
Read the reflection time of end users
|
Read the reflection time of end users
|
||||||
@@ -176,7 +172,7 @@ class WorkspaceAppService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"读取用户反思时间失败,end_user_id: {end_user_id}, 错误: {str(e)}")
|
api_logger.error(f"读取用户反思时间失败,end_user_id: {end_user_id}, 错误: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_end_user_reflection_time(self, end_user_id: str) -> bool:
|
def update_end_user_reflection_time(self, end_user_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Update the reflection time of end users to the current time
|
Update the reflection time of end users to the current time
|
||||||
@@ -189,7 +185,7 @@ class WorkspaceAppService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||||
if end_user:
|
if end_user:
|
||||||
end_user.reflection_time = datetime.now()
|
end_user.reflection_time = datetime.now()
|
||||||
@@ -207,7 +203,7 @@ class WorkspaceAppService:
|
|||||||
|
|
||||||
class MemoryReflectionService:
|
class MemoryReflectionService:
|
||||||
"""Memory reflection service category"""
|
"""Memory reflection service category"""
|
||||||
|
|
||||||
def __init__(self,db: Session = Depends(get_db)):
|
def __init__(self,db: Session = Depends(get_db)):
|
||||||
self.db=db
|
self.db=db
|
||||||
|
|
||||||
@@ -252,22 +248,22 @@ class MemoryReflectionService:
|
|||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"config_data": config_data
|
"config_data": config_data
|
||||||
}
|
}
|
||||||
|
|
||||||
async def start_reflection_from_data(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]:
|
async def start_reflection_from_data(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Starting Reflection from Configuration Data
|
Starting Reflection from Configuration Data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_data: Configure data dictionary, including reflective configuration information
|
config_data: Configure data dictionary, including reflective configuration information
|
||||||
end_user_id: end_user_id
|
end_user_id: end_user_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Reflect on the execution results
|
Reflect on the execution results
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
config_id = config_data.get("config_id")
|
config_id = config_data.get("config_id")
|
||||||
api_logger.info(f"从配置数据启动反思,config_id: {config_id}, end_user_id: {end_user_id}")
|
api_logger.info(f"从配置数据启动反思,config_id: {config_id}, end_user_id: {end_user_id}")
|
||||||
|
|
||||||
|
|
||||||
if not config_data.get("enable_self_reflexion", False):
|
if not config_data.get("enable_self_reflexion", False):
|
||||||
return {
|
return {
|
||||||
@@ -277,7 +273,7 @@ class MemoryReflectionService:
|
|||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"config_data": config_data
|
"config_data": config_data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
config_data_id=config_data['config_id']
|
config_data_id=config_data['config_id']
|
||||||
reflection_config=WorkspaceAppService(self.db)._get_memory_config(config_data_id)
|
reflection_config=WorkspaceAppService(self.db)._get_memory_config(config_data_id)
|
||||||
@@ -290,7 +286,7 @@ class MemoryReflectionService:
|
|||||||
# 检查是否需要执行反思
|
# 检查是否需要执行反思
|
||||||
should_execute = False
|
should_execute = False
|
||||||
hours_diff = 0
|
hours_diff = 0
|
||||||
|
|
||||||
if current_reflection_time is None:
|
if current_reflection_time is None:
|
||||||
# 首次执行反思
|
# 首次执行反思
|
||||||
should_execute = True
|
should_execute = True
|
||||||
@@ -302,11 +298,11 @@ class MemoryReflectionService:
|
|||||||
reflection_time = datetime.fromisoformat(current_reflection_time)
|
reflection_time = datetime.fromisoformat(current_reflection_time)
|
||||||
else:
|
else:
|
||||||
reflection_time = current_reflection_time
|
reflection_time = current_reflection_time
|
||||||
|
|
||||||
current_time = datetime.now()
|
current_time = datetime.now()
|
||||||
time_diff = current_time - reflection_time
|
time_diff = current_time - reflection_time
|
||||||
hours_diff = int(time_diff.total_seconds() / 3600)
|
hours_diff = int(time_diff.total_seconds() / 3600)
|
||||||
|
|
||||||
# 检查是否达到反思周期
|
# 检查是否达到反思周期
|
||||||
if hours_diff >= iteration_period:
|
if hours_diff >= iteration_period:
|
||||||
should_execute = True
|
should_execute = True
|
||||||
@@ -316,7 +312,7 @@ class MemoryReflectionService:
|
|||||||
except (ValueError, TypeError) as e:
|
except (ValueError, TypeError) as e:
|
||||||
api_logger.warning(f"解析反思时间失败: {e},将执行反思")
|
api_logger.warning(f"解析反思时间失败: {e},将执行反思")
|
||||||
should_execute = True
|
should_execute = True
|
||||||
|
|
||||||
if should_execute:
|
if should_execute:
|
||||||
api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时")
|
api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时")
|
||||||
# 3. 执行反思引擎
|
# 3. 执行反思引擎
|
||||||
@@ -349,7 +345,7 @@ class MemoryReflectionService:
|
|||||||
"next_reflection_in_hours": iteration_period - hours_diff
|
"next_reflection_in_hours": iteration_period - hours_diff
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
config_id = config_data.get("config_id", "unknown")
|
config_id = config_data.get("config_id", "unknown")
|
||||||
api_logger.error(f"启动反思失败,config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}")
|
api_logger.error(f"启动反思失败,config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}")
|
||||||
@@ -360,7 +356,7 @@ class MemoryReflectionService:
|
|||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"config_data": config_data
|
"config_data": config_data
|
||||||
}
|
}
|
||||||
|
|
||||||
def _create_reflection_config_from_data(self, config_data: Dict[str, Any]) -> ReflectionConfig:
|
def _create_reflection_config_from_data(self, config_data: Dict[str, Any]) -> ReflectionConfig:
|
||||||
"""Create reflective configuration objects from configuration data"""
|
"""Create reflective configuration objects from configuration data"""
|
||||||
|
|
||||||
@@ -368,12 +364,12 @@ class MemoryReflectionService:
|
|||||||
if reflexion_range_value is None or reflexion_range_value == "":
|
if reflexion_range_value is None or reflexion_range_value == "":
|
||||||
reflexion_range_value = "partial"
|
reflexion_range_value = "partial"
|
||||||
reflexion_range = ReflectionRange(reflexion_range_value)
|
reflexion_range = ReflectionRange(reflexion_range_value)
|
||||||
|
|
||||||
baseline_value = config_data.get("baseline")
|
baseline_value = config_data.get("baseline")
|
||||||
if baseline_value is None or baseline_value == "":
|
if baseline_value is None or baseline_value == "":
|
||||||
baseline_value = "TIME"
|
baseline_value = "TIME"
|
||||||
baseline = ReflectionBaseline(baseline_value)
|
baseline = ReflectionBaseline(baseline_value)
|
||||||
|
|
||||||
# iteration_period =
|
# iteration_period =
|
||||||
iteration_period = config_data.get("iteration_period", 24)
|
iteration_period = config_data.get("iteration_period", 24)
|
||||||
if isinstance(iteration_period, str):
|
if isinstance(iteration_period, str):
|
||||||
@@ -381,7 +377,6 @@ class MemoryReflectionService:
|
|||||||
iteration_period = int(iteration_period)
|
iteration_period = int(iteration_period)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
iteration_period = 24 # 默认24小时
|
iteration_period = 24 # 默认24小时
|
||||||
|
|
||||||
return ReflectionConfig(
|
return ReflectionConfig(
|
||||||
enabled=config_data.get("enable_self_reflexion", False),
|
enabled=config_data.get("enable_self_reflexion", False),
|
||||||
iteration_period=str(iteration_period), # ReflectionConfig期望字符串
|
iteration_period=str(iteration_period), # ReflectionConfig期望字符串
|
||||||
|
|||||||
@@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
if not params.rerank_id:
|
if not params.rerank_id:
|
||||||
params.rerank_id = configs.get('rerank')
|
params.rerank_id = configs.get('rerank')
|
||||||
|
|
||||||
|
# reflection_model_id 和 emotion_model_id 默认与 llm_id 一致
|
||||||
|
if not params.reflection_model_id:
|
||||||
|
params.reflection_model_id = params.llm_id
|
||||||
|
if not params.emotion_model_id:
|
||||||
|
params.emotion_model_id = params.llm_id
|
||||||
|
|
||||||
config = MemoryConfigRepository.create(self.db, params)
|
config = MemoryConfigRepository.create(self.db, params)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
return {"affected": 1, "config_id": config.config_id}
|
return {"affected": 1, "config_id": config.config_id}
|
||||||
@@ -177,11 +183,11 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
|
|
||||||
# --- Read All ---
|
# --- Read All ---
|
||||||
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||||
configs = MemoryConfigRepository.get_all(self.db, workspace_id)
|
results = MemoryConfigRepository.get_all(self.db, workspace_id)
|
||||||
|
|
||||||
# 将 ORM 对象转换为字典列表
|
# 将 ORM 对象转换为字典列表
|
||||||
data_list = []
|
data_list = []
|
||||||
for config in configs:
|
for config, scene_name in results:
|
||||||
# 安全地转换 user_id 为 int
|
# 安全地转换 user_id 为 int
|
||||||
config_id_old = None
|
config_id_old = None
|
||||||
if config.config_id_old:
|
if config.config_id_old:
|
||||||
@@ -203,6 +209,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
"end_user_id": config.end_user_id,
|
"end_user_id": config.end_user_id,
|
||||||
"config_id_old": config_id_old,
|
"config_id_old": config_id_old,
|
||||||
"apply_id": config.apply_id,
|
"apply_id": config.apply_id,
|
||||||
|
"scene_id": str(config.scene_id) if config.scene_id else None,
|
||||||
|
"scene_name": scene_name, # 新增:场景名称
|
||||||
"llm_id": config.llm_id,
|
"llm_id": config.llm_id,
|
||||||
"embedding_id": config.embedding_id,
|
"embedding_id": config.embedding_id,
|
||||||
"rerank_id": config.rerank_id,
|
"rerank_id": config.rerank_id,
|
||||||
@@ -628,10 +636,9 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]:
|
|||||||
if m < 1:
|
if m < 1:
|
||||||
latest_relative = "刚刚"
|
latest_relative = "刚刚"
|
||||||
elif m < 60:
|
elif m < 60:
|
||||||
latest_relative = f"{m}分钟前"
|
latest_relative = "一会前"
|
||||||
else:
|
else:
|
||||||
h = int(m // 60)
|
latest_relative = "较早前"
|
||||||
latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前"
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -347,7 +347,9 @@ class ModelConfigService:
|
|||||||
"is_public": model_data.is_public,
|
"is_public": model_data.is_public,
|
||||||
"is_composite": True
|
"is_composite": True
|
||||||
}
|
}
|
||||||
|
if "load_balance_strategy" in model_data.model_fields_set:
|
||||||
|
model_config_data["load_balance_strategy"] = model_data.load_balance_strategy
|
||||||
|
|
||||||
model = ModelConfigRepository.create(db, model_config_data)
|
model = ModelConfigRepository.create(db, model_config_data)
|
||||||
db.flush()
|
db.flush()
|
||||||
|
|
||||||
@@ -380,7 +382,7 @@ class ModelConfigService:
|
|||||||
for model_config in api_key.model_configs:
|
for model_config in api_key.model_configs:
|
||||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||||
config_type = model_config.type
|
config_type = model_config.type
|
||||||
request_type = model_data.type
|
request_type = existing_model.type
|
||||||
|
|
||||||
if not (config_type == request_type or
|
if not (config_type == request_type or
|
||||||
(config_type in compatible_types and request_type in compatible_types)):
|
(config_type in compatible_types and request_type in compatible_types)):
|
||||||
@@ -391,12 +393,14 @@ class ModelConfigService:
|
|||||||
|
|
||||||
# 更新基本信息
|
# 更新基本信息
|
||||||
existing_model.name = model_data.name
|
existing_model.name = model_data.name
|
||||||
existing_model.type = model_data.type
|
# existing_model.type = model_data.type
|
||||||
existing_model.logo = model_data.logo
|
existing_model.logo = model_data.logo
|
||||||
existing_model.description = model_data.description
|
existing_model.description = model_data.description
|
||||||
existing_model.config = model_data.config
|
existing_model.config = model_data.config
|
||||||
existing_model.is_active = model_data.is_active
|
existing_model.is_active = model_data.is_active
|
||||||
existing_model.is_public = model_data.is_public
|
existing_model.is_public = model_data.is_public
|
||||||
|
if "load_balance_strategy" in model_data.model_fields_set:
|
||||||
|
existing_model.load_balance_strategy = model_data.load_balance_strategy
|
||||||
|
|
||||||
# 更新 API Keys 关联
|
# 更新 API Keys 关联
|
||||||
existing_model.api_keys.clear()
|
existing_model.api_keys.clear()
|
||||||
@@ -453,9 +457,11 @@ class ModelApiKeyService:
|
|||||||
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
|
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create_api_key_by_provider(db: Session, data: model_schema.ModelApiKeyCreateByProvider) -> List[ModelApiKey]:
|
async def create_api_key_by_provider(db: Session, data: model_schema.ModelApiKeyCreateByProvider) -> tuple[
|
||||||
|
list[Any], list[Any]]:
|
||||||
"""根据provider为多个ModelConfig创建API Key"""
|
"""根据provider为多个ModelConfig创建API Key"""
|
||||||
created_keys = []
|
created_keys = []
|
||||||
|
failed_models = [] # 记录验证失败的模型
|
||||||
|
|
||||||
for model_config_id in data.model_config_ids:
|
for model_config_id in data.model_config_ids:
|
||||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||||
@@ -501,10 +507,9 @@ class ModelApiKeyService:
|
|||||||
test_message="Hello"
|
test_message="Hello"
|
||||||
)
|
)
|
||||||
if not validation_result["valid"]:
|
if not validation_result["valid"]:
|
||||||
raise BusinessException(
|
# 记录验证失败的模型,但不抛出异常
|
||||||
f"模型配置验证失败: {validation_result['error']}",
|
failed_models.append(model_name)
|
||||||
BizCode.INVALID_PARAMETER
|
continue
|
||||||
)
|
|
||||||
|
|
||||||
# 创建API Key
|
# 创建API Key
|
||||||
api_key_data = ModelApiKeyCreate(
|
api_key_data = ModelApiKeyCreate(
|
||||||
@@ -526,7 +531,7 @@ class ModelApiKeyService:
|
|||||||
for key in created_keys:
|
for key in created_keys:
|
||||||
db.refresh(key)
|
db.refresh(key)
|
||||||
|
|
||||||
return created_keys
|
return created_keys, failed_models
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
|
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
|
||||||
@@ -684,6 +689,9 @@ class ModelBaseService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_model_base(db: Session, data: model_schema.ModelBaseCreate):
|
def create_model_base(db: Session, data: model_schema.ModelBaseCreate):
|
||||||
|
existing = ModelBaseRepository.get_by_name_and_provider(db, data.name, data.provider)
|
||||||
|
if existing:
|
||||||
|
raise BusinessException("模型已存在", BizCode.DUPLICATE_NAME)
|
||||||
model_base = ModelBaseRepository.create(db, data.model_dump())
|
model_base = ModelBaseRepository.create(db, data.model_dump())
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(model_base)
|
db.refresh(model_base)
|
||||||
|
|||||||
@@ -280,14 +280,22 @@ class MultiAgentOrchestrator:
|
|||||||
|
|
||||||
# 4. 提取子 Agent 的 conversation_id(用于多轮对话)
|
# 4. 提取子 Agent 的 conversation_id(用于多轮对话)
|
||||||
sub_conversation_id = None
|
sub_conversation_id = None
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
if isinstance(results, dict):
|
if isinstance(results, dict):
|
||||||
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
|
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
|
||||||
|
# 提取 token 信息
|
||||||
|
usage = results.get("usage", {}) or results.get("result", {}).get("usage", {})
|
||||||
|
total_tokens += usage.get("total_tokens", 0)
|
||||||
elif isinstance(results, list) and results:
|
elif isinstance(results, list) and results:
|
||||||
for item in results:
|
for item in results:
|
||||||
if "result" in item:
|
if "result" in item:
|
||||||
sub_conversation_id = item["result"].get("conversation_id")
|
sub_conversation_id = item["result"].get("conversation_id")
|
||||||
if sub_conversation_id:
|
if sub_conversation_id:
|
||||||
break
|
break
|
||||||
|
# 累加每个子 Agent 的 token
|
||||||
|
usage = item.get("usage", {}) or item.get("result", {}).get("usage", {})
|
||||||
|
total_tokens += usage.get("total_tokens", 0)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"多 Agent 任务完成",
|
"多 Agent 任务完成",
|
||||||
@@ -301,9 +309,15 @@ class MultiAgentOrchestrator:
|
|||||||
return {
|
return {
|
||||||
"message": final_result,
|
"message": final_result,
|
||||||
"conversation_id": sub_conversation_id,
|
"conversation_id": sub_conversation_id,
|
||||||
|
"mode": OrchestrationMode.SUPERVISOR,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"strategy": routing_decision.get("collaboration_strategy", "single"),
|
"strategy": routing_decision.get("collaboration_strategy", "single"),
|
||||||
"sub_results": results
|
"sub_results": results,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": total_tokens
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1552,10 +1566,12 @@ class MultiAgentOrchestrator:
|
|||||||
return {
|
return {
|
||||||
"message": result.get("response", ""),
|
"message": result.get("response", ""),
|
||||||
"conversation_id": result.get("conversation_id"),
|
"conversation_id": result.get("conversation_id"),
|
||||||
|
"mode": OrchestrationMode.COLLABORATION,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"strategy": "collaboration",
|
"strategy": "collaboration",
|
||||||
"active_agent": result.get("active_agent"),
|
"active_agent": result.get("active_agent"),
|
||||||
"sub_results": result
|
"sub_results": result,
|
||||||
|
"usage": result.get("usage")
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""多 Agent 配置管理服务"""
|
"""多 Agent 配置管理服务"""
|
||||||
import uuid
|
import uuid
|
||||||
|
import json
|
||||||
from typing import Optional, List, Tuple, Any, Annotated
|
from typing import Optional, List, Tuple, Any, Annotated
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
@@ -427,6 +428,23 @@ class MultiAgentService:
|
|||||||
memory=getattr(request, 'memory', True) # 记忆功能参数
|
memory=getattr(request, 'memory', True) # 记忆功能参数
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await self._save_conversation_message(
|
||||||
|
conversation_id=request.conversation_id,
|
||||||
|
user_message=request.message,
|
||||||
|
assistant_message=result.get("message", ""),
|
||||||
|
app_id=app_id,
|
||||||
|
user_id=request.user_id,
|
||||||
|
meta_data={
|
||||||
|
"mode": result.get("mode"),
|
||||||
|
"elapsed_time": result.get("elapsed_time"),
|
||||||
|
"usage": result.get("usage", {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
})
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def run_stream(
|
async def run_stream(
|
||||||
@@ -451,11 +469,14 @@ class MultiAgentService:
|
|||||||
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
|
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
|
||||||
|
|
||||||
if not config.is_active:
|
if not config.is_active:
|
||||||
raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED)
|
raise BusinessException("多 Agent 配置已禁用", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
# 2. 创建编排器
|
# 2. 创建编排器
|
||||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||||
|
|
||||||
|
full_content = ""
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
# 3. 流式执行任务
|
# 3. 流式执行任务
|
||||||
async for event in orchestrator.execute_stream(
|
async for event in orchestrator.execute_stream(
|
||||||
message=request.message,
|
message=request.message,
|
||||||
@@ -468,7 +489,88 @@ class MultiAgentService:
|
|||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
):
|
):
|
||||||
yield event
|
if "sub_usage" in event:
|
||||||
|
if "data:" in event:
|
||||||
|
try:
|
||||||
|
data_line = event.split("data: ", 1)[1].strip()
|
||||||
|
data = json.loads(data_line)
|
||||||
|
if "total_tokens" in data:
|
||||||
|
total_tokens += data["total_tokens"]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
yield event
|
||||||
|
if "data:" in event:
|
||||||
|
try:
|
||||||
|
data_line = event.split("data: ", 1)[1].strip()
|
||||||
|
data = json.loads(data_line)
|
||||||
|
if "content" in data:
|
||||||
|
full_content += data["content"]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await self._save_conversation_message(
|
||||||
|
conversation_id=request.conversation_id,
|
||||||
|
user_message=request.message,
|
||||||
|
assistant_message=full_content,
|
||||||
|
app_id=app_id,
|
||||||
|
user_id=request.user_id,
|
||||||
|
meta_data={
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": total_tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _save_conversation_message(
|
||||||
|
self,
|
||||||
|
conversation_id: uuid.UUID,
|
||||||
|
user_message: str,
|
||||||
|
assistant_message: str,
|
||||||
|
meta_data: dict,
|
||||||
|
app_id: Optional[uuid.UUID] = None,
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""保存会话消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: 会话ID
|
||||||
|
user_message: 用户消息
|
||||||
|
assistant_message: AI 回复消息
|
||||||
|
meta_data: 元数据(包括 token 消耗)
|
||||||
|
app_id: 应用ID
|
||||||
|
user_id: 用户ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from app.services.conversation_service import ConversationService
|
||||||
|
|
||||||
|
conversation_service = ConversationService(self.db)
|
||||||
|
|
||||||
|
conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="user",
|
||||||
|
content=user_message
|
||||||
|
)
|
||||||
|
conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="assistant",
|
||||||
|
content=assistant_message,
|
||||||
|
meta_data=meta_data
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"保存多 Agent 会话消息",
|
||||||
|
extra={
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"user_message_length": len(user_message),
|
||||||
|
"assistant_message_length": len(assistant_message)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("保存会话消息失败", extra={"error": str(e)})
|
||||||
|
|
||||||
# def add_sub_agent(
|
# def add_sub_agent(
|
||||||
# self,
|
# self,
|
||||||
|
|||||||
1162
api/app/services/ontology_service.py
Normal file
1162
api/app/services/ontology_service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, AsyncGenerator
|
from typing import Any, AsyncGenerator
|
||||||
@@ -16,9 +17,10 @@ from app.models.prompt_optimizer_model import (
|
|||||||
PromptOptimizerSession,
|
PromptOptimizerSession,
|
||||||
RoleType
|
RoleType
|
||||||
)
|
)
|
||||||
from app.repositories.model_repository import ModelConfigRepository
|
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
|
||||||
from app.repositories.prompt_optimizer_repository import (
|
from app.repositories.prompt_optimizer_repository import (
|
||||||
PromptOptimizerSessionRepository
|
PromptOptimizerSessionRepository,
|
||||||
|
PromptReleaseRepository
|
||||||
)
|
)
|
||||||
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
|
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
|
||||||
|
|
||||||
@@ -28,6 +30,8 @@ logger = get_business_logger()
|
|||||||
class PromptOptimizerService:
|
class PromptOptimizerService:
|
||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
self.optim_repo = PromptOptimizerSessionRepository(self.db)
|
||||||
|
self.release_repo = PromptReleaseRepository(self.db)
|
||||||
|
|
||||||
def get_model_config(
|
def get_model_config(
|
||||||
self,
|
self,
|
||||||
@@ -78,10 +82,12 @@ class PromptOptimizerService:
|
|||||||
Returns:
|
Returns:
|
||||||
PromptOptimzerSession: The newly created prompt optimization session.
|
PromptOptimzerSession: The newly created prompt optimization session.
|
||||||
"""
|
"""
|
||||||
session = PromptOptimizerSessionRepository(self.db).create_session(
|
session = self.optim_repo.create_session(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user_id
|
user_id=user_id
|
||||||
)
|
)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(session)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def get_session_message_history(
|
def get_session_message_history(
|
||||||
@@ -106,7 +112,7 @@ class PromptOptimizerService:
|
|||||||
- role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'.
|
- role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'.
|
||||||
- content (str): The content of the message.
|
- content (str): The content of the message.
|
||||||
"""
|
"""
|
||||||
history = PromptOptimizerSessionRepository(self.db).get_session_history(
|
history = self.optim_repo.get_session_history(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=user_id
|
user_id=user_id
|
||||||
)
|
)
|
||||||
@@ -168,7 +174,8 @@ class PromptOptimizerService:
|
|||||||
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
|
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
|
||||||
|
|
||||||
# Create LLM instance
|
# Create LLM instance
|
||||||
api_config: ModelApiKey = model_config.api_keys[0]
|
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
|
||||||
|
api_config: ModelApiKey = api_keys[0] if api_keys else None
|
||||||
llm = RedBearLLM(RedBearModelConfig(
|
llm = RedBearLLM(RedBearModelConfig(
|
||||||
model_name=api_config.model_name,
|
model_name=api_config.model_name,
|
||||||
provider=api_config.provider,
|
provider=api_config.provider,
|
||||||
@@ -176,11 +183,12 @@ class PromptOptimizerService:
|
|||||||
base_url=api_config.api_base
|
base_url=api_config.api_base
|
||||||
), type=ModelType(model_config.type))
|
), type=ModelType(model_config.type))
|
||||||
try:
|
try:
|
||||||
with open('app/services/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f:
|
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||||
|
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||||
opt_system_prompt = f.read()
|
opt_system_prompt = f.read()
|
||||||
rendered_system_message = Template(opt_system_prompt).render()
|
rendered_system_message = Template(opt_system_prompt).render()
|
||||||
|
|
||||||
with open('app/services/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f:
|
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
|
||||||
opt_user_prompt = f.read()
|
opt_user_prompt = f.read()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||||
@@ -295,4 +303,165 @@ class PromptOptimizerService:
|
|||||||
role=role,
|
role=role,
|
||||||
content=content
|
content=content
|
||||||
)
|
)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(message)
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
def save_prompt(
|
||||||
|
self,
|
||||||
|
tenant_id: uuid.UUID,
|
||||||
|
session_id: uuid.UUID,
|
||||||
|
title: str,
|
||||||
|
prompt: str
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Create and save a new prompt release for a given session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id (uuid.UUID): The ID of the tenant owning the prompt.
|
||||||
|
session_id (uuid.UUID): The ID of the session to associate with this prompt.
|
||||||
|
title (str): The title of the prompt release.
|
||||||
|
prompt (str): The content of the prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary containing:
|
||||||
|
- id (UUID): The unique ID of the created prompt release.
|
||||||
|
- session_id (UUID): The session ID linked to the release.
|
||||||
|
- title (str): The title of the prompt.
|
||||||
|
- prompt (str): The prompt content.
|
||||||
|
- created_at (int): Timestamp (in milliseconds) of when the prompt was created.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: If a prompt release already exists for the given session.
|
||||||
|
"""
|
||||||
|
session = self.optim_repo.get_session_by_id(session_id)
|
||||||
|
if session is None or session.tenant_id != tenant_id:
|
||||||
|
raise BusinessException(
|
||||||
|
"Session does not exist or the current user has no access",
|
||||||
|
BizCode.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.release_repo.get_prompt_by_session_id(session_id):
|
||||||
|
raise BusinessException(
|
||||||
|
"A release already exists for the current session",
|
||||||
|
BizCode.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_obj = self.release_repo.create_prompt_release(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
title=title,
|
||||||
|
session_id=session_id,
|
||||||
|
prompt=prompt
|
||||||
|
)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(prompt_obj)
|
||||||
|
return {
|
||||||
|
"id": prompt_obj.id,
|
||||||
|
"session_id": prompt_obj.session_id,
|
||||||
|
"title": prompt_obj.title,
|
||||||
|
"prompt": prompt_obj.prompt,
|
||||||
|
"created_at": int(prompt_obj.created_at.timestamp() * 1000)
|
||||||
|
}
|
||||||
|
|
||||||
|
def delete_prompt(
|
||||||
|
self,
|
||||||
|
tenant_id: uuid.UUID,
|
||||||
|
prompt_id: uuid.UUID
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Soft delete a prompt release by prompt_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id (uuid.UUID): Tenant identifier.
|
||||||
|
prompt_id (uuid.UUID): Prompt identifier.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: If the prompt does not exist or already deleted.
|
||||||
|
"""
|
||||||
|
prompt_obj = self.release_repo.get_prompt_by_id(prompt_id)
|
||||||
|
if not prompt_obj or prompt_obj.is_delete:
|
||||||
|
raise BusinessException(
|
||||||
|
"Prompt does not exist or has already been deleted",
|
||||||
|
BizCode.NOT_FOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
if prompt_obj.tenant_id != tenant_id:
|
||||||
|
raise BusinessException(
|
||||||
|
"No permission to delete this prompt",
|
||||||
|
BizCode.FORBIDDEN
|
||||||
|
)
|
||||||
|
|
||||||
|
self.release_repo.soft_delete_prompt(prompt_obj)
|
||||||
|
self.db.commit()
|
||||||
|
logger.info(f"Prompt soft deleted, prompt_id={prompt_id}, tenant_id={tenant_id}")
|
||||||
|
|
||||||
|
def get_release_list(
|
||||||
|
self,
|
||||||
|
tenant_id: uuid.UUID,
|
||||||
|
page: int,
|
||||||
|
page_size: int,
|
||||||
|
filter_keyword: str | None = None
|
||||||
|
) -> dict[str, int | list[Any]]:
|
||||||
|
"""
|
||||||
|
Get paginated list of prompt releases with optional filter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id (uuid.UUID): Tenant identifier.
|
||||||
|
page (int): Page number (starting from 1).
|
||||||
|
page_size (int): Number of items per page.
|
||||||
|
filter_keyword (str | None): Optional keyword to filter by title.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Contains total count, pagination info, and list of releases.
|
||||||
|
"""
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
|
# Get total count and releases based on filter
|
||||||
|
if filter_keyword:
|
||||||
|
total = self.release_repo.count_prompts_by_keyword(tenant_id, filter_keyword)
|
||||||
|
releases = self.release_repo.search_prompts_paginated(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
keyword=filter_keyword,
|
||||||
|
offset=offset,
|
||||||
|
limit=page_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
total = self.release_repo.count_prompts(tenant_id)
|
||||||
|
releases = self.release_repo.get_prompts_paginated(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
offset=offset,
|
||||||
|
limit=page_size
|
||||||
|
)
|
||||||
|
|
||||||
|
items = []
|
||||||
|
for release in releases:
|
||||||
|
# Get first user message from session
|
||||||
|
first_message = self.optim_repo.get_first_user_message(
|
||||||
|
session_id=release.session_id
|
||||||
|
)
|
||||||
|
|
||||||
|
items.append({
|
||||||
|
"id": release.id,
|
||||||
|
"title": release.title,
|
||||||
|
"prompt": release.prompt,
|
||||||
|
"created_at": int(release.created_at.timestamp() * 1000),
|
||||||
|
"first_message": first_message
|
||||||
|
})
|
||||||
|
|
||||||
|
log_msg = f"Retrieved {len(items)} prompt releases, page={page}, tenant_id={tenant_id}"
|
||||||
|
if filter_keyword:
|
||||||
|
log_msg += f", filter='{filter_keyword}'"
|
||||||
|
logger.info(log_msg)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"page": {
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"page_size": page_size,
|
||||||
|
"hasnext": page * page_size < total
|
||||||
|
},
|
||||||
|
"keyword": filter_keyword,
|
||||||
|
"items": items
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user