Compare commits
1 Commits
release/v0
...
v0.1.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b13b4a949 |
@@ -1,25 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import TypeDecorator, JSON
|
||||
|
||||
|
||||
class PydanticType(TypeDecorator):
|
||||
impl = JSON
|
||||
|
||||
def __init__(self, pydantic_model: type[BaseModel]):
|
||||
super().__init__()
|
||||
self.model = pydantic_model
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
# 入库:Model -> dict
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, self.model):
|
||||
return value.dict()
|
||||
return value # 已经是 dict 也放行
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
# 出库:dict -> Model
|
||||
if value is None:
|
||||
return None
|
||||
# return self.model.parse_obj(value) # pydantic v1
|
||||
return self.model.model_validate(value) # pydantic v2
|
||||
@@ -85,8 +85,6 @@ health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
|
||||
@@ -105,13 +103,6 @@ beat_schedule_config = {
|
||||
"schedule": memory_cache_regeneration_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"run-forgetting-cycle": {
|
||||
"task": "app.tasks.run_forgetting_cycle_task",
|
||||
"schedule": forgetting_cycle_schedule,
|
||||
"kwargs": {
|
||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
|
||||
@@ -4,47 +4,37 @@
|
||||
认证方式: JWT Token
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
|
||||
from . import (
|
||||
api_key_controller,
|
||||
app_controller,
|
||||
auth_controller,
|
||||
chunk_controller,
|
||||
document_controller,
|
||||
emotion_config_controller,
|
||||
emotion_controller,
|
||||
file_controller,
|
||||
home_page_controller,
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
memory_agent_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_episodic_controller,
|
||||
memory_explicit_controller,
|
||||
memory_forget_controller,
|
||||
memory_reflection_controller,
|
||||
memory_short_term_controller,
|
||||
memory_storage_controller,
|
||||
model_controller,
|
||||
multi_agent_controller,
|
||||
prompt_optimizer_controller,
|
||||
public_share_controller,
|
||||
release_share_controller,
|
||||
setup_controller,
|
||||
task_controller,
|
||||
test_controller,
|
||||
tool_controller,
|
||||
upload_controller,
|
||||
user_controller,
|
||||
user_memory_controllers,
|
||||
workflow_controller,
|
||||
auth_controller,
|
||||
workspace_controller,
|
||||
memory_forget_controller,
|
||||
home_page_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_working_controller,
|
||||
setup_controller,
|
||||
file_controller,
|
||||
document_controller,
|
||||
knowledge_controller,
|
||||
chunk_controller,
|
||||
knowledgeshare_controller,
|
||||
app_controller,
|
||||
upload_controller,
|
||||
memory_agent_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_storage_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_reflection_controller,
|
||||
api_key_controller,
|
||||
release_share_controller,
|
||||
public_share_controller,
|
||||
multi_agent_controller,
|
||||
workflow_controller,
|
||||
emotion_controller,
|
||||
emotion_config_controller,
|
||||
prompt_optimizer_controller,
|
||||
tool_controller,
|
||||
)
|
||||
from . import user_memory_controllers
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
manager_router = APIRouter()
|
||||
@@ -69,8 +59,6 @@ manager_router.include_router(memory_agent_controller.router)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(memory_storage_controller.router)
|
||||
manager_router.include_router(user_memory_controllers.router)
|
||||
manager_router.include_router(memory_episodic_controller.router)
|
||||
manager_router.include_router(memory_explicit_controller.router)
|
||||
manager_router.include_router(api_key_controller.router)
|
||||
manager_router.include_router(release_share_controller.router)
|
||||
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
||||
@@ -81,12 +69,6 @@ manager_router.include_router(emotion_controller.router)
|
||||
manager_router.include_router(emotion_config_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
manager_router.include_router(memory_reflection_controller.router)
|
||||
manager_router.include_router(memory_short_term_controller.router)
|
||||
manager_router.include_router(tool_controller.router)
|
||||
manager_router.include_router(memory_forget_controller.router)
|
||||
manager_router.include_router(home_page_controller.router)
|
||||
manager_router.include_router(implicit_memory_controller.router)
|
||||
manager_router.include_router(memory_perceptual_controller.router)
|
||||
manager_router.include_router(memory_working_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -11,16 +11,15 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models import User
|
||||
from app.models.app_model import AppType
|
||||
from app.models.app_model import AppType, App
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas import app_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||
from app.services import app_service, workspace_service
|
||||
from app.services.agent_config_helper import enrich_agent_config
|
||||
from app.services.app_service import AppService
|
||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
@@ -30,9 +29,9 @@ logger = get_business_logger()
|
||||
@router.post("", summary="创建应用(可选创建 Agent 配置)")
|
||||
@cur_workspace_access_guard()
|
||||
def create_app(
|
||||
payload: app_schema.AppCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
payload: app_schema.AppCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
app = app_service.create_app(db, user_id=current_user.id, workspace_id=workspace_id, data=payload)
|
||||
@@ -42,34 +41,22 @@ def create_app(
|
||||
@router.get("", summary="应用列表(分页)")
|
||||
@cur_workspace_access_guard()
|
||||
def list_apps(
|
||||
type: str | None = None,
|
||||
visibility: str | None = None,
|
||||
status: str | None = None,
|
||||
search: str | None = None,
|
||||
include_shared: bool = True,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
ids: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
type: str | None = None,
|
||||
visibility: str | None = None,
|
||||
status: str | None = None,
|
||||
search: str | None = None,
|
||||
include_shared: bool = True,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""列出应用
|
||||
|
||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
if ids is not None:
|
||||
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
return success(data=items)
|
||||
|
||||
# 正常分页查询
|
||||
items_orm, total = app_service.list_apps(
|
||||
db,
|
||||
workspace_id=workspace_id,
|
||||
@@ -82,17 +69,18 @@ def list_apps(
|
||||
pagesize=pagesize,
|
||||
)
|
||||
|
||||
# 使用 AppService 的转换方法来设置 is_shared 字段
|
||||
service = app_service.AppService(db)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@router.get("/{app_id}", summary="获取应用详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用详细信息
|
||||
|
||||
@@ -111,10 +99,10 @@ def get_app(
|
||||
@router.put("/{app_id}", summary="更新应用基本信息")
|
||||
@cur_workspace_access_guard()
|
||||
def update_app(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AppUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AppUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
app = app_service.update_app(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
@@ -124,9 +112,9 @@ def update_app(
|
||||
@router.delete("/{app_id}", summary="删除应用")
|
||||
@cur_workspace_access_guard()
|
||||
def delete_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""删除应用
|
||||
|
||||
@@ -153,10 +141,10 @@ def delete_app(
|
||||
@router.post("/{app_id}/copy", summary="复制应用")
|
||||
@cur_workspace_access_guard()
|
||||
def copy_app(
|
||||
app_id: uuid.UUID,
|
||||
new_name: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
new_name: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""复制应用(包括基础信息和配置)
|
||||
|
||||
@@ -190,10 +178,10 @@ def copy_app(
|
||||
@router.put("/{app_id}/config", summary="更新 Agent 配置")
|
||||
@cur_workspace_access_guard()
|
||||
def update_agent_config(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AgentConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AgentConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.update_agent_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
@@ -204,9 +192,9 @@ def update_agent_config(
|
||||
@router.get("/{app_id}/config", summary="获取 Agent 配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_agent_config(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
@@ -218,10 +206,10 @@ def get_agent_config(
|
||||
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
||||
@cur_workspace_access_guard()
|
||||
def publish_app(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.PublishRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.PublishRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
release = app_service.publish(
|
||||
@@ -229,7 +217,7 @@ def publish_app(
|
||||
app_id=app_id,
|
||||
publisher_id=current_user.id,
|
||||
workspace_id=workspace_id,
|
||||
version_name=payload.version_name,
|
||||
version_name = payload.version_name,
|
||||
release_notes=payload.release_notes
|
||||
)
|
||||
return success(data=app_schema.AppRelease.model_validate(release))
|
||||
@@ -238,9 +226,9 @@ def publish_app(
|
||||
@router.get("/{app_id}/release", summary="获取当前发布版本")
|
||||
@cur_workspace_access_guard()
|
||||
def get_current_release(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
release = app_service.get_current_release(db, app_id=app_id, workspace_id=workspace_id)
|
||||
@@ -252,9 +240,9 @@ def get_current_release(
|
||||
@router.get("/{app_id}/releases", summary="列出历史发布版本(倒序)")
|
||||
@cur_workspace_access_guard()
|
||||
def list_releases(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
releases = app_service.list_releases(db, app_id=app_id, workspace_id=workspace_id)
|
||||
@@ -265,10 +253,10 @@ def list_releases(
|
||||
@router.post("/{app_id}/rollback/{version}", summary="回滚到指定版本")
|
||||
@cur_workspace_access_guard()
|
||||
def rollback(
|
||||
app_id: uuid.UUID,
|
||||
version: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
version: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
release = app_service.rollback(db, app_id=app_id, version=version, workspace_id=workspace_id)
|
||||
@@ -278,10 +266,10 @@ def rollback(
|
||||
@router.post("/{app_id}/share", summary="分享应用到其他工作空间")
|
||||
@cur_workspace_access_guard()
|
||||
def share_app(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AppShareCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AppShareCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""分享应用到其他工作空间
|
||||
|
||||
@@ -306,10 +294,10 @@ def share_app(
|
||||
@router.delete("/{app_id}/share/{target_workspace_id}", summary="取消应用分享")
|
||||
@cur_workspace_access_guard()
|
||||
def unshare_app(
|
||||
app_id: uuid.UUID,
|
||||
target_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
target_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""取消应用分享
|
||||
|
||||
@@ -330,9 +318,9 @@ def unshare_app(
|
||||
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
||||
@cur_workspace_access_guard()
|
||||
def list_app_shares(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""列出应用的所有分享记录
|
||||
|
||||
@@ -349,15 +337,14 @@ def list_app_shares(
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
return success(data=data)
|
||||
|
||||
|
||||
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||
@cur_workspace_access_guard()
|
||||
async def draft_run(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.DraftRunRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.DraftRunRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None
|
||||
):
|
||||
"""
|
||||
试运行 Agent,使用当前的草稿配置(未发布的配置)
|
||||
@@ -374,7 +361,7 @@ async def draft_run(
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None:
|
||||
if storage_type is None:
|
||||
storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
@@ -384,9 +371,10 @@ async def draft_run(
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
from app.services.app_service import AppService
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
@@ -406,22 +394,13 @@ async def draft_run(
|
||||
# 只读操作,允许访问共享应用
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
other_id=str(current_user.id),
|
||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
# 处理会话ID(创建或验证)
|
||||
conversation_id = await draft_service._ensure_conversation(
|
||||
conversation_id=payload.conversation_id,
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=payload.user_id
|
||||
)
|
||||
conversation_id=payload.conversation_id,
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=payload.user_id
|
||||
)
|
||||
payload.conversation_id = conversation_id
|
||||
|
||||
if app.type == AppType.AGENT:
|
||||
@@ -445,16 +424,17 @@ async def draft_run(
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
|
||||
|
||||
async for event in draft_service.run_stream(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -526,7 +506,7 @@ async def draft_run(
|
||||
multi_agent_request = MultiAgentRunRequest(
|
||||
message=payload.message,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables or {},
|
||||
use_llm_routing=True # 默认启用 LLM 路由
|
||||
)
|
||||
@@ -548,10 +528,10 @@ async def draft_run(
|
||||
|
||||
# 调用多智能体服务的流式方法
|
||||
async for event in multiservice.run_stream(
|
||||
app_id=app_id,
|
||||
request=multi_agent_request,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
app_id=app_id,
|
||||
request=multi_agent_request,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
|
||||
):
|
||||
yield event
|
||||
@@ -591,7 +571,7 @@ async def draft_run(
|
||||
data=result,
|
||||
msg="多 Agent 任务执行成功"
|
||||
)
|
||||
elif app.type == AppType.WORKFLOW: # 工作流
|
||||
elif app.type == AppType.WORKFLOW: #工作流
|
||||
config = workflow_service.check_config(app_id)
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
@@ -612,18 +592,17 @@ async def draft_run(
|
||||
data: <json_data>
|
||||
"""
|
||||
import json
|
||||
|
||||
|
||||
# 调用工作流服务的流式方法
|
||||
async for event in workflow_service.run_stream(
|
||||
app_id=app_id,
|
||||
payload=payload,
|
||||
config=config,
|
||||
workspace_id=current_user.current_workspace_id
|
||||
config=config
|
||||
):
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
@@ -648,7 +627,7 @@ async def draft_run(
|
||||
}
|
||||
)
|
||||
|
||||
result = await workflow_service.run(app_id, payload, config, current_user.current_workspace_id)
|
||||
result = await workflow_service.run(app_id, payload,config)
|
||||
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
@@ -663,13 +642,14 @@ async def draft_run(
|
||||
)
|
||||
|
||||
|
||||
|
||||
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
|
||||
@cur_workspace_access_guard()
|
||||
async def draft_run_compare(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.DraftRunCompareRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.DraftRunCompareRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
多模型对比试运行
|
||||
@@ -694,7 +674,7 @@ async def draft_run_compare(
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None:
|
||||
if storage_type is None:
|
||||
storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
@@ -703,7 +683,7 @@ async def draft_run_compare(
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
logger.info(
|
||||
@@ -748,23 +728,9 @@ async def draft_run_compare(
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
|
||||
|
||||
# 获取 agent_cfg.model_parameters,如果是 ModelParameters 对象则转为字典
|
||||
agent_model_params = agent_cfg.model_parameters
|
||||
if hasattr(agent_model_params, 'model_dump'):
|
||||
agent_model_params = agent_model_params.model_dump()
|
||||
elif not isinstance(agent_model_params, dict):
|
||||
agent_model_params = {}
|
||||
|
||||
# 获取 model_item.model_parameters,如果是 ModelParameters 对象则转为字典
|
||||
item_model_params = model_item.model_parameters
|
||||
if hasattr(item_model_params, 'model_dump'):
|
||||
item_model_params = item_model_params.model_dump()
|
||||
elif not isinstance(item_model_params, dict):
|
||||
item_model_params = {}
|
||||
|
||||
merged_parameters = {
|
||||
**(agent_model_params or {}),
|
||||
**(item_model_params or {})
|
||||
**(agent_cfg.model_parameters or {}),
|
||||
**(model_item.model_parameters or {})
|
||||
}
|
||||
|
||||
model_configs.append({
|
||||
@@ -781,19 +747,19 @@ async def draft_run_compare(
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
async for event in draft_service.run_compare_stream(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -855,15 +821,15 @@ async def get_workflow_config(
|
||||
# 配置总是存在(不存在时返回默认模板)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@router.put("/{app_id}/workflow", summary="更新 Workflow 配置")
|
||||
@cur_workspace_access_guard()
|
||||
async def update_workflow_config(
|
||||
app_id: uuid.UUID,
|
||||
payload: WorkflowConfigUpdate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
app_id: uuid.UUID,
|
||||
payload: WorkflowConfigUpdate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@@ -1,28 +1,24 @@
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import knowledge_model, knowledgeshare_model
|
||||
from app.models.document_model import Document
|
||||
from app.models.user_model import User
|
||||
from app.models.document_model import Document
|
||||
from app.models import knowledge_model, knowledgeshare_model
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.schemas import chunk_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -145,7 +141,7 @@ async def get_preview_chunks(
|
||||
}
|
||||
}
|
||||
api_logger.info(f"Querying the document block preview list successful: total={total}, returned={len(chunks)} records")
|
||||
return success(data=jsonable_encoder(result), msg="Querying the document block preview list succeeded")
|
||||
return success(data=result, msg="Querying the document block preview list succeeded")
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/chunks", response_model=ApiResponse)
|
||||
@@ -203,7 +199,7 @@ async def get_chunks(
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=jsonable_encoder(result), msg="Query of document chunk list succeeded")
|
||||
return success(data=result, msg="Query of document chunk list succeeded")
|
||||
|
||||
|
||||
@router.post("/{kb_id}/{document_id}/chunk", response_model=ApiResponse)
|
||||
@@ -264,7 +260,7 @@ async def create_chunk(
|
||||
db_document.chunk_num += 1
|
||||
db.commit()
|
||||
|
||||
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
|
||||
return success(data=chunk, msg="Document chunk creation successful")
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
@@ -291,7 +287,7 @@ async def get_chunk(
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
total, items = vector_service.get_by_segment(doc_id=doc_id)
|
||||
if total:
|
||||
return success(data=jsonable_encoder(items[0]), msg="Document chunk query successful")
|
||||
return success(data=items[0], msg="Document chunk query successful")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -328,7 +324,7 @@ async def update_chunk(
|
||||
chunk = items[0]
|
||||
chunk.page_content = content
|
||||
vector_service.update_by_segment(chunk)
|
||||
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
|
||||
return success(data=chunk, msg="The document chunk has been successfully updated")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -393,41 +389,36 @@ async def retrieve_chunks(
|
||||
knowledge_model.Knowledge.chunk_num > 0,
|
||||
knowledge_model.Knowledge.status == 1
|
||||
]
|
||||
private_items = knowledge_service.get_chunked_knowledgeids(
|
||||
existing_ids = knowledge_service.get_chunded_knowledgeids(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
private_kb_ids = [item[0] for item in private_items]
|
||||
private_workspace_ids = [item[1] for item in private_items]
|
||||
filters = [
|
||||
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
|
||||
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
|
||||
knowledge_model.Knowledge.chunk_num > 0,
|
||||
knowledge_model.Knowledge.status == 1
|
||||
]
|
||||
items = knowledge_service.get_chunked_knowledgeids(
|
||||
share_ids = knowledge_service.get_chunded_knowledgeids(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
if items:
|
||||
if share_ids:
|
||||
filters = [
|
||||
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
|
||||
]
|
||||
share_items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
|
||||
items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
share_kb_ids = [item[0] for item in share_items]
|
||||
share_workspace_ids = [item[1] for item in share_items]
|
||||
private_kb_ids.extend(share_kb_ids)
|
||||
private_workspace_ids.extend(share_workspace_ids)
|
||||
if not private_kb_ids:
|
||||
existing_ids.extend(items)
|
||||
if not existing_ids:
|
||||
return success(data=[], msg="retrieval successful")
|
||||
kb_id = private_kb_ids[0]
|
||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in private_kb_ids]
|
||||
kb_id = existing_ids[0]
|
||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
|
||||
indices = ",".join(uuid_strs)
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
@@ -457,21 +448,4 @@ async def retrieve_chunks(
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
|
||||
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
||||
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
||||
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
||||
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
)
|
||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
rs.insert(0, doc)
|
||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
@@ -1,26 +1,23 @@
|
||||
import datetime
|
||||
import os
|
||||
from typing import Optional
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.controllers import file_controller
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import document_model
|
||||
from app.models.user_model import User
|
||||
from app.models import document_model
|
||||
from app.schemas import document_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import document_service, file_service, knowledge_service
|
||||
|
||||
from app.controllers import file_controller
|
||||
from app.celery_app import celery_app
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -109,7 +106,7 @@ async def get_documents(
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=jsonable_encoder(result), msg="Query of document list succeeded")
|
||||
return success(data=result, msg="Query of document list succeeded")
|
||||
|
||||
|
||||
@router.post("/document", response_model=ApiResponse)
|
||||
@@ -127,7 +124,7 @@ async def create_document(
|
||||
api_logger.debug(f"Start creating a document: {create_data.file_name}")
|
||||
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
|
||||
api_logger.info(f"Document created successfully: {db_document.file_name} (ID: {db_document.id})")
|
||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="Document creation successful")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="Document creation successful")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Document creation failed: {create_data.file_name} - {str(e)}")
|
||||
raise
|
||||
@@ -156,7 +153,7 @@ async def get_document(
|
||||
)
|
||||
|
||||
api_logger.info(f"Document query successful: {db_document.file_name} (ID: {db_document.id})")
|
||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="Successfully obtained document information")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="Successfully obtained document information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -224,7 +221,7 @@ async def update_document(
|
||||
)
|
||||
|
||||
# 5. Return the updated document
|
||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="Document information updated successfully")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="Document information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{document_id}", response_model=ApiResponse)
|
||||
|
||||
@@ -18,7 +18,6 @@ from app.models.user_model import User
|
||||
from app.schemas.emotion_schema import (
|
||||
EmotionHealthRequest,
|
||||
EmotionSuggestionsRequest,
|
||||
EmotionGenerateSuggestionsRequest,
|
||||
EmotionTagsRequest,
|
||||
EmotionWordcloudRequest,
|
||||
)
|
||||
@@ -31,7 +30,7 @@ from sqlalchemy.orm import Session
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/emotion-memory",
|
||||
prefix="/memory/emotion",
|
||||
tags=["Emotion Analysis"],
|
||||
dependencies=[Depends(get_current_user)] # 所有路由都需要认证
|
||||
)
|
||||
@@ -199,7 +198,7 @@ async def get_emotion_suggestions(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取个性化情绪建议(从缓存读取)
|
||||
"""获取个性化情绪建议
|
||||
|
||||
Args:
|
||||
request: 包含 group_id 和可选的 config_id
|
||||
@@ -207,72 +206,7 @@ async def get_emotion_suggestions(
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
缓存的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"config_id": request.config_id
|
||||
}
|
||||
)
|
||||
|
||||
# 从缓存获取建议
|
||||
data = await emotion_service.get_cached_suggestions(
|
||||
end_user_id=request.group_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if data is None:
|
||||
# 缓存不存在或已过期
|
||||
api_logger.info(
|
||||
f"用户 {request.group_id} 的建议缓存不存在或已过期",
|
||||
extra={"group_id": request.group_id}
|
||||
)
|
||||
return fail(
|
||||
BizCode.RESOURCE_NOT_FOUND,
|
||||
"建议缓存不存在或已过期,请调用 /generate_suggestions 接口生成新建议",
|
||||
None
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议获取成功(缓存)",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="个性化建议获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取个性化建议失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取个性化建议失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/generate_suggestions", response_model=ApiResponse)
|
||||
async def generate_emotion_suggestions(
|
||||
request: EmotionGenerateSuggestionsRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""生成个性化情绪建议(调用LLM并缓存)
|
||||
|
||||
Args:
|
||||
request: 包含 group_id、可选的 config_id 和 force_refresh
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
新生成的个性化情绪建议响应
|
||||
个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
# 验证 config_id(如果提供)
|
||||
@@ -300,44 +234,36 @@ async def generate_emotion_suggestions(
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID验证失败", str(e))
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求生成个性化情绪建议",
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"config_id": config_id
|
||||
}
|
||||
)
|
||||
|
||||
# 调用服务层生成建议
|
||||
# 调用服务层
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.group_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
# 保存到缓存
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=request.group_id,
|
||||
suggestions_data=data,
|
||||
db=db,
|
||||
expires_hours=24
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议生成成功",
|
||||
"个性化建议获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="个性化建议生成成功")
|
||||
return success(data=data, msg="个性化建议获取成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"生成个性化建议失败: {str(e)}",
|
||||
f"获取个性化建议失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"生成个性化建议失败: {str(e)}"
|
||||
detail=f"获取个性化建议失败: {str(e)}"
|
||||
)
|
||||
|
||||
@@ -1,25 +1,22 @@
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import file_model
|
||||
from app.models.user_model import User
|
||||
from app.models import file_model
|
||||
from app.schemas import file_schema, document_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import file_service, document_service
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -96,11 +93,11 @@ async def get_files(
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
|
||||
return success(data=result, msg="Query of file list succeeded")
|
||||
|
||||
|
||||
@router.post("/folder", response_model=ApiResponse)
|
||||
async def create_folder(
|
||||
def create_folder(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
folder_name: str = '/',
|
||||
@@ -124,7 +121,7 @@ async def create_folder(
|
||||
)
|
||||
db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user)
|
||||
api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
|
||||
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
|
||||
return success(data=file_schema.File.model_validate(db_file), msg="Folder creation successful")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
|
||||
raise
|
||||
@@ -210,7 +207,7 @@ async def upload_file(
|
||||
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
|
||||
|
||||
api_logger.info(f"File upload successfully: {file.filename} (file_id: {db_file.id}, document_id: {db_document.id})")
|
||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="File upload successful")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="File upload successful")
|
||||
|
||||
|
||||
@router.post("/customtext", response_model=ApiResponse)
|
||||
@@ -291,7 +288,7 @@ async def custom_text(
|
||||
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
||||
|
||||
api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
|
||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
||||
|
||||
|
||||
@router.get("/{file_id}", response_model=Any)
|
||||
@@ -365,7 +362,7 @@ async def update_file(
|
||||
# 2. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the file fields: {file_id}")
|
||||
updated_fields = []
|
||||
for field, value in update_data.dict(exclude_unset=True).items():
|
||||
for field, value in update_data.items():
|
||||
if hasattr(db_file, field):
|
||||
old_value = getattr(db_file, field)
|
||||
if old_value != value:
|
||||
@@ -390,7 +387,7 @@ async def update_file(
|
||||
)
|
||||
|
||||
# 4. Return the updated file
|
||||
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
|
||||
return success(data=file_schema.File.model_validate(db_file), msg="File information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{file_id}", response_model=ApiResponse)
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.home_page_service import HomePageService
|
||||
|
||||
router = APIRouter(prefix="/home-page", tags=["Home Page"])
|
||||
|
||||
@router.get("/statistics", response_model=ApiResponse)
|
||||
def get_home_statistics(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取首页统计数据"""
|
||||
statistics = HomePageService.get_home_statistics(db, current_user.tenant_id)
|
||||
return success(data=statistics, msg="统计数据获取成功")
|
||||
|
||||
@router.get("/workspaces", response_model=ApiResponse)
|
||||
def get_workspace_list(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取工作空间列表"""
|
||||
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
|
||||
return success(data=workspace_list, msg="工作空间列表获取成功")
|
||||
|
||||
@router.get("/version", response_model=ApiResponse)
|
||||
def get_system_version():
|
||||
"""获取系统版本号+说明"""
|
||||
current_version = settings.SYSTEM_VERSION
|
||||
version_info = HomePageService.load_version_introduction(current_version)
|
||||
return success(
|
||||
data={
|
||||
"version": current_version,
|
||||
"introduction": version_info.get("introduction"),
|
||||
"introduction_en": version_info.get("introduction_en")
|
||||
},
|
||||
msg="系统版本获取成功"
|
||||
)
|
||||
@@ -1,431 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import (
|
||||
cur_workspace_access_guard,
|
||||
get_current_user,
|
||||
)
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.implicit_memory_schema import GenerateProfileRequest
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/implicit-memory",
|
||||
tags=["Implicit Memory"],
|
||||
)
|
||||
|
||||
|
||||
def handle_implicit_memory_error(e: Exception, operation: str, user_id: str = None) -> dict:
|
||||
"""
|
||||
Centralized error handling for implicit memory operations.
|
||||
|
||||
Args:
|
||||
e: The exception that occurred
|
||||
operation: Description of the operation that failed
|
||||
user_id: Optional user ID for logging context
|
||||
|
||||
Returns:
|
||||
Standardized error response
|
||||
"""
|
||||
error_context = f"user_id={user_id}" if user_id else "unknown user"
|
||||
|
||||
if isinstance(e, ValueError):
|
||||
if "user" in str(e).lower() and "not found" in str(e).lower():
|
||||
api_logger.warning(f"Invalid user ID for {operation}: {error_context}")
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的用户ID", str(e))
|
||||
elif "insufficient" in str(e).lower() or "no data" in str(e).lower():
|
||||
api_logger.warning(f"Insufficient data for {operation}: {error_context}")
|
||||
return fail(BizCode.INSUFFICIENT_DATA, "数据不足,无法进行分析", str(e))
|
||||
else:
|
||||
api_logger.warning(f"Invalid parameters for {operation}: {error_context}")
|
||||
return fail(BizCode.INVALID_FILTER_PARAMS, "无效的参数", str(e))
|
||||
|
||||
elif isinstance(e, KeyError):
|
||||
api_logger.warning(f"Missing required data for {operation}: {error_context}")
|
||||
return fail(BizCode.INSUFFICIENT_DATA, "缺少必要的数据", str(e))
|
||||
|
||||
elif isinstance(e, (ConnectionError, TimeoutError)):
|
||||
api_logger.error(f"Service unavailable for {operation}: {error_context}")
|
||||
return fail(BizCode.SERVICE_UNAVAILABLE, "服务暂时不可用", str(e))
|
||||
|
||||
elif "analysis" in str(e).lower() or "llm" in str(e).lower():
|
||||
api_logger.error(f"Analysis failed for {operation}: {error_context}", exc_info=True)
|
||||
return fail(BizCode.ANALYSIS_FAILED, "分析处理失败", str(e))
|
||||
|
||||
elif "storage" in str(e).lower() or "database" in str(e).lower():
|
||||
api_logger.error(f"Storage error for {operation}: {error_context}", exc_info=True)
|
||||
return fail(BizCode.PROFILE_STORAGE_ERROR, "数据存储失败", str(e))
|
||||
|
||||
else:
|
||||
api_logger.error(f"Unexpected error for {operation}: {error_context}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, f"{operation}失败", str(e))
|
||||
|
||||
|
||||
def validate_user_id(user_id: str) -> None:
|
||||
"""
|
||||
Validate user ID format and constraints.
|
||||
|
||||
Args:
|
||||
user_id: User ID to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If user ID is invalid
|
||||
"""
|
||||
if not user_id or not user_id.strip():
|
||||
raise ValueError("User ID cannot be empty")
|
||||
|
||||
if len(user_id.strip()) < 1:
|
||||
raise ValueError("User ID is too short")
|
||||
|
||||
|
||||
def validate_date_range(start_date: Optional[datetime], end_date: Optional[datetime]) -> None:
|
||||
"""
|
||||
Validate date range parameters.
|
||||
|
||||
Args:
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
|
||||
Raises:
|
||||
ValueError: If date range is invalid
|
||||
"""
|
||||
if (start_date and not end_date) or (end_date and not start_date):
|
||||
raise ValueError("Both start_date and end_date must be provided together")
|
||||
|
||||
if start_date and end_date and start_date >= end_date:
|
||||
raise ValueError("start_date must be before end_date")
|
||||
|
||||
if start_date and start_date > datetime.now():
|
||||
raise ValueError("start_date cannot be in the future")
|
||||
|
||||
|
||||
def validate_confidence_threshold(threshold: float) -> None:
|
||||
"""
|
||||
Validate confidence threshold parameter.
|
||||
|
||||
Args:
|
||||
threshold: Confidence threshold to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If threshold is invalid
|
||||
"""
|
||||
if not 0.0 <= threshold <= 1.0:
|
||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
@router.get("/preferences/{user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_preference_tags(
|
||||
user_id: str,
|
||||
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
|
||||
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
|
||||
start_date: Optional[datetime] = Query(None, description="Filter start date"),
|
||||
end_date: Optional[datetime] = Query(None, description="Filter end date"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Get user preference tags from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
confidence_threshold: Minimum confidence score (0.0-1.0)
|
||||
tag_category: Optional category filter
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
|
||||
Returns:
|
||||
List of preference tags from cache
|
||||
"""
|
||||
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.RESOURCE_NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
|
||||
None
|
||||
)
|
||||
|
||||
# Extract preferences from cache
|
||||
preferences = cached_profile.get("preferences", [])
|
||||
|
||||
# Apply filters (client-side filtering on cached data)
|
||||
filtered_preferences = []
|
||||
for pref in preferences:
|
||||
# Filter by confidence threshold
|
||||
if confidence_threshold is not None and pref.get("confidence_score", 0) < confidence_threshold:
|
||||
continue
|
||||
|
||||
# Filter by category if specified
|
||||
if tag_category and pref.get("category") != tag_category:
|
||||
continue
|
||||
|
||||
# Filter by date range if specified
|
||||
if start_date or end_date:
|
||||
created_at_ts = pref.get("created_at")
|
||||
if created_at_ts:
|
||||
created_at = datetime.fromtimestamp(created_at_ts / 1000)
|
||||
if start_date and created_at < start_date:
|
||||
continue
|
||||
if end_date and created_at > end_date:
|
||||
continue
|
||||
|
||||
filtered_preferences.append(pref)
|
||||
|
||||
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)")
|
||||
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
|
||||
|
||||
|
||||
@router.get("/portrait/{user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_dimension_portrait(
|
||||
user_id: str,
|
||||
include_history: bool = Query(False, description="Include historical trends"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Get user's four-dimension personality portrait from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
include_history: Whether to include historical trend data (ignored for cached data)
|
||||
|
||||
Returns:
|
||||
Four-dimension personality portrait from cache
|
||||
"""
|
||||
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.RESOURCE_NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
|
||||
None
|
||||
)
|
||||
|
||||
# Extract portrait from cache
|
||||
portrait = cached_profile.get("portrait", {})
|
||||
|
||||
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)")
|
||||
return success(data=portrait, msg="四维画像获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "四维画像获取", user_id)
|
||||
|
||||
|
||||
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_interest_area_distribution(
|
||||
user_id: str,
|
||||
include_trends: bool = Query(False, description="Include trend analysis"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Get user's interest area distribution from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
include_trends: Whether to include trend analysis data (ignored for cached data)
|
||||
|
||||
Returns:
|
||||
Interest area distribution from cache
|
||||
"""
|
||||
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.RESOURCE_NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
|
||||
None
|
||||
)
|
||||
|
||||
# Extract interest areas from cache
|
||||
interest_areas = cached_profile.get("interest_areas", {})
|
||||
|
||||
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)")
|
||||
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
|
||||
|
||||
|
||||
@router.get("/habits/{user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_behavior_habits(
|
||||
user_id: str,
|
||||
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
|
||||
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
|
||||
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Get user's behavioral habits from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
confidence_level: Filter by confidence level (high, medium, low)
|
||||
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
|
||||
time_period: Filter by time period (current, past)
|
||||
|
||||
Returns:
|
||||
List of behavioral habits from cache
|
||||
"""
|
||||
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.RESOURCE_NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
|
||||
None
|
||||
)
|
||||
|
||||
# Extract habits from cache
|
||||
habits = cached_profile.get("habits", [])
|
||||
|
||||
# Apply filters (client-side filtering on cached data)
|
||||
filtered_habits = []
|
||||
for habit in habits:
|
||||
# Filter by confidence level
|
||||
if confidence_level:
|
||||
confidence_mapping = {
|
||||
"high": 85,
|
||||
"medium": 50,
|
||||
"low": 20
|
||||
}
|
||||
numerical_confidence = confidence_mapping.get(confidence_level.lower())
|
||||
if habit.get("confidence_level", 0) < numerical_confidence:
|
||||
continue
|
||||
|
||||
# Filter by frequency pattern
|
||||
if frequency_pattern and habit.get("frequency_pattern") != frequency_pattern:
|
||||
continue
|
||||
|
||||
# Filter by time period
|
||||
if time_period:
|
||||
is_current = habit.get("is_current", True)
|
||||
if time_period.lower() == "current" and not is_current:
|
||||
continue
|
||||
elif time_period.lower() == "past" and is_current:
|
||||
continue
|
||||
|
||||
filtered_habits.append(habit)
|
||||
|
||||
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)")
|
||||
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "行为习惯获取", user_id)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/generate_profile", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def generate_implicit_memory_profile(
|
||||
request: GenerateProfileRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Generate complete user profile (all 4 modules) and cache it.
|
||||
|
||||
Args:
|
||||
request: Generate profile request with end_user_id
|
||||
db: Database session
|
||||
current_user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
Complete user profile with all modules
|
||||
"""
|
||||
end_user_id = request.end_user_id
|
||||
api_logger.info(f"Generate profile requested for user: {end_user_id}")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Generate complete profile (calls LLM for all 4 modules)
|
||||
api_logger.info(f"开始生成完整用户画像: user={end_user_id}")
|
||||
profile_data = await service.generate_complete_profile(user_id=end_user_id)
|
||||
|
||||
# Save to cache
|
||||
await service.save_profile_cache(
|
||||
end_user_id=end_user_id,
|
||||
profile_data=profile_data,
|
||||
db=db,
|
||||
expires_hours=168 # 7 days
|
||||
)
|
||||
|
||||
api_logger.info(f"用户画像生成并缓存成功: user={end_user_id}")
|
||||
|
||||
# Add metadata
|
||||
profile_data["end_user_id"] = end_user_id
|
||||
profile_data["cached"] = False
|
||||
|
||||
return success(data=profile_data, msg="用户画像生成成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"生成用户画像失败: user={end_user_id}, error={str(e)}", exc_info=True)
|
||||
return handle_implicit_memory_error(e, "用户画像生成", end_user_id)
|
||||
@@ -1,29 +1,26 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.common import settings
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.nlp import rag_tokenizer, search
|
||||
from app.core.rag.prompts.generator import graph_entity_types
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import knowledge_model
|
||||
from app.models.user_model import User
|
||||
from app.models import knowledge_model, document_model, file_model
|
||||
from app.schemas import knowledge_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import knowledge_service, document_service
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.prompts.generator import graph_entity_types
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.nlp import rag_tokenizer, search
|
||||
from app.core.rag.common import settings
|
||||
from app.celery_app import celery_app
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -50,45 +47,6 @@ def get_parser_types():
|
||||
return success(msg="Successfully obtained the knowledge parser type", data=list(knowledge_model.ParserType))
|
||||
|
||||
|
||||
@router.get("/knowledge_graph_entity_types", response_model=ApiResponse)
|
||||
async def get_knowledge_graph_entity_types(
|
||||
llm_id: uuid.UUID,
|
||||
scenario: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
get knowledge graph entity types based on llm_id
|
||||
"""
|
||||
api_logger.info(f"Obtain details of the knowledge graph: llm_id={llm_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the model exists
|
||||
api_logger.debug(f"Check whether the model exists: {llm_id}")
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=llm_id)
|
||||
|
||||
if not config:
|
||||
api_logger.warning(
|
||||
f"The model does not exist or you do not have permission to access it: llm_id={llm_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The model does not exist or you do not have permission to access it"
|
||||
)
|
||||
# 2. Prepare to configure chat_mdl information
|
||||
chat_model = Base(
|
||||
key=config.api_keys[0].api_key,
|
||||
model_name=config.api_keys[0].model_name,
|
||||
base_url=config.api_keys[0].api_base
|
||||
)
|
||||
response = graph_entity_types(chat_model, scenario)
|
||||
return success(data=response, msg="Successfully obtained knowledge graph entity types")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"get knowledge graph entity types failed: llm_id={llm_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/knowledges", response_model=ApiResponse)
|
||||
async def get_knowledges(
|
||||
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id"),
|
||||
@@ -172,7 +130,7 @@ async def get_knowledges(
|
||||
"has_next": True if page*pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=jsonable_encoder(result), msg="Query of knowledge base list successful")
|
||||
return success(data=result, msg="Query of knowledge base list successful")
|
||||
|
||||
|
||||
@router.post("/knowledge", response_model=ApiResponse)
|
||||
@@ -198,7 +156,7 @@ async def create_knowledge(
|
||||
)
|
||||
db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=create_data, current_user=current_user)
|
||||
api_logger.info(f"The knowledge base has been successfully created: {db_knowledge.name} (ID: {db_knowledge.id})")
|
||||
return success(data=jsonable_encoder(knowledge_schema.Knowledge.model_validate(db_knowledge)), msg="The knowledge base has been successfully created")
|
||||
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the knowledge base failed: {create_data.name} - {str(e)}")
|
||||
raise
|
||||
@@ -227,7 +185,7 @@ async def get_knowledge(
|
||||
)
|
||||
|
||||
api_logger.info(f"Knowledge base query successful: {db_knowledge.name} (ID: {db_knowledge.id})")
|
||||
return success(data=jsonable_encoder(knowledge_schema.Knowledge.model_validate(db_knowledge)), msg="Successfully obtained knowledge base information")
|
||||
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="Successfully obtained knowledge base information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -244,7 +202,7 @@ async def update_knowledge(
|
||||
):
|
||||
api_logger.info(f"Update knowledge base request: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
db_knowledge = await _update_knowledge(knowledge_id=knowledge_id, update_data=update_data, db=db, current_user=current_user)
|
||||
return success(data=jsonable_encoder(knowledge_schema.Knowledge.model_validate(db_knowledge)), msg="The knowledge base information has been successfully updated")
|
||||
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base information has been successfully updated")
|
||||
|
||||
|
||||
async def _update_knowledge(
|
||||
@@ -421,7 +379,7 @@ async def delete_knowledge_graph(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
delete knowledge graph
|
||||
Soft-delete knowledge graph
|
||||
"""
|
||||
api_logger.info(f"Request to delete knowledge graph: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
|
||||
@@ -484,3 +442,42 @@ async def rebuild_knowledge_graph(
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{knowledge_id}/knowledge_graph_entity_types", response_model=ApiResponse)
|
||||
async def get_knowledge_graph_entity_types(
|
||||
knowledge_id: uuid.UUID,
|
||||
scenario: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
get knowledge graph entity types based on knowledge_id
|
||||
"""
|
||||
api_logger.info(f"Obtain details of the knowledge graph: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the knowledge base exists
|
||||
api_logger.debug(f"Check whether the knowledge base exists: {knowledge_id}")
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
|
||||
|
||||
if not db_knowledge:
|
||||
api_logger.warning(
|
||||
f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or you do not have permission to access it"
|
||||
)
|
||||
# 2. Prepare to configure chat_mdl information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
response = graph_entity_types(chat_model, scenario)
|
||||
return success(data=response, msg="Successfully obtained knowledge graph entity types")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"get knowledge graph entity types failed: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from app.repositories.end_user_repository import update_end_user_other_name
|
||||
import uuid
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.memory_agent_schema import End_User_Information
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.app_schema import App as AppSchema
|
||||
|
||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -99,8 +102,7 @@ async def get_workspace_end_users(
|
||||
"""
|
||||
获取工作空间的宿主列表
|
||||
|
||||
返回格式与原 memory_list 接口中的 end_users 字段相同,
|
||||
并包含每个用户的记忆配置信息(memory_config_id 和 memory_config_name)
|
||||
返回格式与原 memory_list 接口中的 end_users 字段相同
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 获取当前空间类型
|
||||
@@ -111,17 +113,6 @@ async def get_workspace_end_users(
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次)
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
memory_configs_map = {}
|
||||
if end_user_ids:
|
||||
try:
|
||||
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||
# 失败时使用空字典,不影响其他数据返回
|
||||
|
||||
result = []
|
||||
for end_user in end_users:
|
||||
memory_num = {}
|
||||
@@ -132,25 +123,10 @@ async def get_workspace_end_users(
|
||||
memory_num = {
|
||||
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
|
||||
}
|
||||
|
||||
# 从批量查询结果中获取配置信息
|
||||
user_id = str(end_user.id)
|
||||
memory_config_info = memory_configs_map.get(user_id, {
|
||||
"memory_config_id": None,
|
||||
"memory_config_name": None
|
||||
})
|
||||
|
||||
# 只保留需要的字段,移除 error 字段(如果有)
|
||||
memory_config = {
|
||||
"memory_config_id": memory_config_info.get("memory_config_id"),
|
||||
"memory_config_name": memory_config_info.get("memory_config_name")
|
||||
}
|
||||
|
||||
result.append(
|
||||
{
|
||||
'end_user': end_user,
|
||||
'memory_num': memory_num,
|
||||
'memory_config': memory_config
|
||||
'end_user':end_user,
|
||||
'memory_num':memory_num
|
||||
}
|
||||
)
|
||||
|
||||
@@ -489,6 +465,7 @@ async def dashboard_data(
|
||||
if storage_type is None:
|
||||
storage_type = 'neo4j'
|
||||
|
||||
user_rag_memory_id = None
|
||||
|
||||
# 根据 storage_type 决定返回哪个数据对象
|
||||
# 如果是 'rag',neo4j_data 为 null;否则 rag_data 为 null
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
"""
|
||||
情景记忆相关的控制器
|
||||
包含情景记忆总览和详情查询接口
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_episodic_schema import (
|
||||
EpisodicMemoryOverviewRequest,
|
||||
EpisodicMemoryDetailsRequest,
|
||||
)
|
||||
from app.services.memory_episodic_service import memory_episodic_service
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/episodic-memory",
|
||||
tags=["Episodic Memory"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/overview", response_model=ApiResponse)
|
||||
async def get_episodic_memory_overview_api(
|
||||
request: EpisodicMemoryOverviewRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取情景记忆总览
|
||||
|
||||
返回指定用户的所有情景记忆列表,包括标题和创建时间。
|
||||
支持通过时间范围、情景类型和标题关键词进行筛选。
|
||||
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆总览但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 验证参数
|
||||
valid_time_ranges = ["all", "today", "this_week", "this_month"]
|
||||
valid_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"]
|
||||
|
||||
if request.time_range not in valid_time_ranges:
|
||||
return fail(BizCode.INVALID_PARAMETER, f"无效的时间范围参数,可选值:{', '.join(valid_time_ranges)}")
|
||||
|
||||
if request.episodic_type not in valid_episodic_types:
|
||||
return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{', '.join(valid_episodic_types)}")
|
||||
|
||||
# 处理 title_keyword(去除首尾空格)
|
||||
title_keyword = request.title_keyword.strip() if request.title_keyword else None
|
||||
|
||||
api_logger.info(
|
||||
f"情景记忆总览查询请求: end_user_id={request.end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}, time_range={request.time_range}, episodic_type={request.episodic_type}, "
|
||||
f"title_keyword={title_keyword}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用Service层方法
|
||||
result = await memory_episodic_service.get_episodic_memory_overview(
|
||||
request.end_user_id, request.time_range, request.episodic_type, title_keyword
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取情景记忆总览: end_user_id={request.end_user_id}, "
|
||||
f"total={result['total']}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"情景记忆总览查询失败: end_user_id={request.end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "情景记忆总览查询失败", str(e))
|
||||
|
||||
|
||||
@router.post("/details", response_model=ApiResponse)
|
||||
async def get_episodic_memory_details_api(
|
||||
request: EpisodicMemoryDetailsRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取情景记忆详情
|
||||
|
||||
返回指定情景记忆的详细信息,包括涉及对象、情景类型、内容记录和情绪。
|
||||
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆详情但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"情景记忆详情查询请求: end_user_id={request.end_user_id}, summary_id={request.summary_id}, "
|
||||
f"user={current_user.username}, workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用Service层方法
|
||||
result = await memory_episodic_service.get_episodic_memory_details(
|
||||
end_user_id=request.end_user_id,
|
||||
summary_id=request.summary_id
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取情景记忆详情: end_user_id={request.end_user_id}, summary_id={request.summary_id}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
# 处理情景记忆不存在的情况
|
||||
api_logger.warning(f"情景记忆不存在: end_user_id={request.end_user_id}, summary_id={request.summary_id}, error={str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "情景记忆不存在", str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"情景记忆详情查询失败: end_user_id={request.end_user_id}, summary_id={request.summary_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "情景记忆详情查询失败", str(e))
|
||||
@@ -1,115 +0,0 @@
|
||||
"""
|
||||
显性记忆控制器
|
||||
|
||||
处理显性记忆相关的API接口,包括情景记忆和语义记忆的查询。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.memory_explicit_service import MemoryExplicitService
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_explicit_schema import (
|
||||
ExplicitMemoryOverviewRequest,
|
||||
ExplicitMemoryDetailsRequest,
|
||||
)
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Initialize service
|
||||
memory_explicit_service = MemoryExplicitService()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/explicit-memory",
|
||||
tags=["Explicit Memory"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/overview", response_model=ApiResponse)
|
||||
async def get_explicit_memory_overview_api(
|
||||
request: ExplicitMemoryOverviewRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取显性记忆总览
|
||||
|
||||
返回指定用户的所有显性记忆列表,包括标题、完整内容、创建时间和情绪信息。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询显性记忆总览但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"显性记忆总览查询请求: end_user_id={request.end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用Service层方法
|
||||
result = await memory_explicit_service.get_explicit_memory_overview(
|
||||
request.end_user_id
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取显性记忆总览: end_user_id={request.end_user_id}, "
|
||||
f"total={result['total']}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"显性记忆总览查询失败: end_user_id={request.end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
|
||||
|
||||
|
||||
@router.post("/details", response_model=ApiResponse)
|
||||
async def get_explicit_memory_details_api(
|
||||
request: ExplicitMemoryDetailsRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取显性记忆详情
|
||||
|
||||
根据 memory_id 返回情景记忆或语义记忆的详细信息。
|
||||
- 情景记忆:包括标题、内容、情绪、创建时间
|
||||
- 语义记忆:包括名称、核心定义、详细笔记、创建时间
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询显性记忆详情但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"显性记忆详情查询请求: end_user_id={request.end_user_id}, memory_id={request.memory_id}, "
|
||||
f"user={current_user.username}, workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用Service层方法
|
||||
result = await memory_explicit_service.get_explicit_memory_details(
|
||||
end_user_id=request.end_user_id,
|
||||
memory_id=request.memory_id
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取显性记忆详情: end_user_id={request.end_user_id}, memory_id={request.memory_id}, "
|
||||
f"memory_type={result.get('memory_type')}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
# 处理记忆不存在的情况
|
||||
api_logger.warning(f"显性记忆不存在: end_user_id={request.end_user_id}, memory_id={request.memory_id}, error={str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "显性记忆不存在", str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"显性记忆详情查询失败: end_user_id={request.end_user_id}, memory_id={request.memory_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "显性记忆详情查询失败", str(e))
|
||||
@@ -1,363 +0,0 @@
|
||||
"""
|
||||
遗忘引擎控制器模块
|
||||
|
||||
本模块提供遗忘引擎的 REST API 接口,包括:
|
||||
1. 手动触发遗忘周期
|
||||
2. 获取和更新配置
|
||||
3. 获取统计信息
|
||||
4. 获取遗忘曲线数据
|
||||
|
||||
所有接口都需要用户认证,并自动关联到当前工作空间。
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ForgettingTriggerRequest,
|
||||
ForgettingConfigResponse,
|
||||
ForgettingConfigUpdateRequest,
|
||||
ForgettingStatsResponse,
|
||||
ForgettingReportResponse,
|
||||
ForgettingCurveRequest,
|
||||
ForgettingCurveResponse,
|
||||
ForgettingCurvePoint,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/forget-memory",
|
||||
tags=["Memory Forgetting Engine"],
|
||||
dependencies=[Depends(get_current_user)] # 所有路由都需要认证
|
||||
)
|
||||
|
||||
# 初始化服务
|
||||
forget_service = MemoryForgetService()
|
||||
|
||||
|
||||
# ==================== API 端点 ====================
|
||||
|
||||
@router.post("/trigger", response_model=ApiResponse)
|
||||
async def trigger_forgetting_cycle(
|
||||
payload: ForgettingTriggerRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
手动触发遗忘周期
|
||||
|
||||
执行一次完整的遗忘周期,识别并融合低激活值节点。
|
||||
|
||||
Args:
|
||||
payload: 触发请求参数
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含遗忘报告的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = payload.end_user_id # 从 payload 中获取 end_user_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试触发遗忘周期但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 通过 end_user_id 获取关联的 config_id
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求触发遗忘周期: "
|
||||
f"end_user_id={end_user_id}, config_id={config_id}, max_batch={payload.max_merge_batch_size}, "
|
||||
f"min_days={payload.min_days_since_access}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层执行遗忘周期
|
||||
report = await forget_service.trigger_forgetting_cycle(
|
||||
db=db,
|
||||
group_id=end_user_id, # 服务层方法的参数名是 group_id
|
||||
max_merge_batch_size=payload.max_merge_batch_size,
|
||||
min_days_since_access=payload.min_days_since_access,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingReportResponse(**report)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="遗忘周期执行成功")
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.warning(f"遗忘周期执行被拒绝: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "RuntimeError")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"触发遗忘周期失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "触发遗忘周期失败", str(e))
|
||||
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
async def read_forgetting_config(
|
||||
config_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取遗忘引擎配置
|
||||
|
||||
读取指定配置ID的遗忘引擎参数。
|
||||
|
||||
Args:
|
||||
config_id: 配置ID
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含配置信息的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层读取配置
|
||||
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingConfigResponse(**config)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"配置不存在: config_id={config_id}, 错误: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"配置不存在: {config_id}", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"读取遗忘引擎配置失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
|
||||
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse)
|
||||
async def update_forgetting_config(
|
||||
payload: ForgettingConfigUpdateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新遗忘引擎配置
|
||||
|
||||
更新指定配置ID的遗忘引擎参数。
|
||||
|
||||
Args:
|
||||
payload: 配置更新请求
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含更新结果的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 构建更新字段字典(排除 None 值和 config_id)
|
||||
update_data = {
|
||||
key: value
|
||||
for key, value in payload.model_dump(exclude_none=True).items()
|
||||
if key != 'config_id'
|
||||
}
|
||||
|
||||
# 调用服务层更新配置
|
||||
config = forget_service.update_forgetting_config(
|
||||
db=db,
|
||||
config_id=payload.config_id,
|
||||
update_fields=update_data
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingConfigResponse(**config)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="更新成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"配置不存在: config_id={payload.config_id}, 错误: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"更新遗忘引擎配置失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ApiResponse)
|
||||
async def get_forgetting_stats(
|
||||
group_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取遗忘引擎统计信息
|
||||
|
||||
返回知识层节点统计、激活值分布等信息。
|
||||
|
||||
Args:
|
||||
group_id: 组ID(即 end_user_id,可选)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含统计信息的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 如果提供了 group_id,通过它获取 config_id
|
||||
config_id = None
|
||||
if group_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
|
||||
f"group_id={group_id}, config_id={config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取统计信息
|
||||
stats = await forget_service.get_forgetting_stats(
|
||||
db=db,
|
||||
group_id=group_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingStatsResponse(**stats)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取遗忘引擎统计失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||
|
||||
|
||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||
async def get_forgetting_curve(
|
||||
request: ForgettingCurveRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取遗忘曲线数据
|
||||
|
||||
生成遗忘曲线数据用于可视化,模拟记忆激活值随时间的衰减。
|
||||
|
||||
Args:
|
||||
request: 遗忘曲线请求参数
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含遗忘曲线数据的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘曲线: "
|
||||
f"importance_score={request.importance_score}, days={request.days}, config_id={request.config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层生成遗忘曲线
|
||||
result = await forget_service.get_forgetting_curve(
|
||||
db=db,
|
||||
importance_score=request.importance_score,
|
||||
days=request.days,
|
||||
config_id=request.config_id
|
||||
)
|
||||
|
||||
# 转换为响应格式
|
||||
curve_points = [
|
||||
ForgettingCurvePoint(**point)
|
||||
for point in result['curve_data']
|
||||
]
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingCurveResponse(
|
||||
curve_data=curve_points,
|
||||
config=result['config']
|
||||
)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取遗忘曲线失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘曲线失败", str(e))
|
||||
@@ -1,255 +0,0 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.models.memory_perceptual_model import PerceptualType
|
||||
from app.schemas.memory_perceptual_schema import (
|
||||
PerceptualQuerySchema,
|
||||
PerceptualFilter
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/perceptual",
|
||||
tags=["Perceptual Memory System"],
|
||||
dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/count", response_model=ApiResponse)
|
||||
def get_memory_count(
|
||||
group_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve perceptual memory statistics for a user group.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group (usually end_user_id in this context)
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Response containing memory count statistics
|
||||
"""
|
||||
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
count_stats = service.get_memory_count(group_id)
|
||||
|
||||
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
|
||||
|
||||
return success(
|
||||
data=count_stats,
|
||||
msg="Memory statistics retrieved successfully"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch memory statistics",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
|
||||
def get_last_visual_memory(
|
||||
group_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent VISION-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest visual memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
visual_memory = service.get_latest_visual_memory(group_id)
|
||||
|
||||
if visual_memory is None:
|
||||
api_logger.info(f"No visual memory found: group_id={group_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No visual memory available"
|
||||
)
|
||||
|
||||
api_logger.info(f"Latest visual memory retrieved successfully: file={visual_memory.get('file_name')}")
|
||||
|
||||
return success(
|
||||
data=visual_memory,
|
||||
msg="Latest visual memory retrieved successfully"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest visual memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
|
||||
def get_last_memory_listen(
|
||||
group_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent AUDIO-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest audio memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
audio_memory = service.get_latest_audio_memory(group_id)
|
||||
|
||||
if audio_memory is None:
|
||||
api_logger.info(f"No audio memory found: group_id={group_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No audio memory available"
|
||||
)
|
||||
|
||||
api_logger.info(f"Latest audio memory retrieved successfully: file={audio_memory.get('file_name')}")
|
||||
|
||||
return success(
|
||||
data=audio_memory,
|
||||
msg="Latest audio memory retrieved successfully"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest audio memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_text", response_model=ApiResponse)
|
||||
def get_last_text_memory(
|
||||
group_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent TEXT-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest text memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
|
||||
|
||||
try:
|
||||
# 调用服务层获取最近的文本记忆
|
||||
service = MemoryPerceptualService(db)
|
||||
text_memory = service.get_latest_text_memory(group_id)
|
||||
|
||||
if text_memory is None:
|
||||
api_logger.info(f"No text memory found: group_id={group_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No text memory available"
|
||||
)
|
||||
|
||||
api_logger.info(f"Latest text memory retrieved successfully: file={text_memory.get('file_name')}")
|
||||
|
||||
return success(
|
||||
data=text_memory,
|
||||
msg="Latest text memory retrieved successfully"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest text memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/timeline", response_model=ApiResponse)
|
||||
def get_memory_time_line(
|
||||
group_id: uuid.UUID,
|
||||
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve a timeline of perceptual memories for a user group.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
perceptual_type: Optional filter for perceptual type
|
||||
page: Page number for pagination
|
||||
page_size: Number of items per page
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Timeline data of perceptual memories
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Fetching perceptual memory timeline: user={current_user.username}, "
|
||||
f"group_id={group_id}, type={perceptual_type}, page={page}"
|
||||
)
|
||||
|
||||
try:
|
||||
query = PerceptualQuerySchema(
|
||||
filter=PerceptualFilter(type=perceptual_type),
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
service = MemoryPerceptualService(db)
|
||||
timeline_data = service.get_time_line(group_id, query)
|
||||
|
||||
api_logger.info(
|
||||
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
|
||||
f"returned={len(timeline_data.memories)}"
|
||||
)
|
||||
|
||||
return success(
|
||||
data=timeline_data.model_dump(),
|
||||
msg="Perceptual memory timeline retrieved successfully"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
|
||||
f"error={str(e)}"
|
||||
)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch perceptual memory timeline",
|
||||
)
|
||||
@@ -1,43 +0,0 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
|
||||
from app.services.memory_storage_service import search_entity
|
||||
from app.services.memory_short_service import ShortService,LongService
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/short",
|
||||
tags=["Memory"],
|
||||
)
|
||||
@router.get("/short_term")
|
||||
async def short_term_configs(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 获取短期记忆数据
|
||||
short_term=ShortService(end_user_id)
|
||||
short_result=short_term.get_short_databasets()
|
||||
short_count=short_term.get_short_count()
|
||||
|
||||
long_term=LongService(end_user_id)
|
||||
long_result=long_term.get_long_databasets()
|
||||
|
||||
entity_result = await search_entity(end_user_id)
|
||||
result = {
|
||||
'short_term': short_result,
|
||||
'long_term': long_result,
|
||||
'entity': entity_result.get('num', 0),
|
||||
"retrieval_number":short_count,
|
||||
"long_term_number":len(long_result)
|
||||
}
|
||||
|
||||
return success(data=result, msg="短期记忆系统数据获取成功")
|
||||
@@ -1,3 +1,4 @@
|
||||
import datetime
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional
|
||||
@@ -8,7 +9,12 @@ from app.core.memory.utils.self_reflexion_utils import self_reflexion
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.user_model import User
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
)
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
@@ -16,6 +22,8 @@ from app.schemas.memory_storage_schema import (
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
GenerateCacheRequest,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_storage_service import (
|
||||
@@ -230,8 +238,28 @@ def update_config_extracted(
|
||||
|
||||
|
||||
# --- Forget config params ---
|
||||
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
||||
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
||||
@router.post("/update_config_forget", response_model=ApiResponse) # 更新遗忘引擎配置参数(固定路径)
|
||||
def update_config_forget(
|
||||
payload: ConfigUpdateForget,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = svc.update_forget(payload)
|
||||
return success(data=result, msg="更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Update config forget failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
|
||||
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
@@ -255,6 +283,28 @@ def read_config_extracted(
|
||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||
|
||||
@router.get("/read_config_forget", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_forget(
|
||||
config_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = svc.get_forget(ConfigKey(config_id=config_id))
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read config forget failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/work",
|
||||
tags=["Working Memory System"],
|
||||
dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/count", response_model=ApiResponse)
|
||||
def get_memory_count(
|
||||
group_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/{group_id}/conversations", response_model=ApiResponse)
|
||||
def get_conversations(
|
||||
group_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Retrieve all conversations for the current user in a specific group.
|
||||
|
||||
Args:
|
||||
group_id (UUID): The group identifier.
|
||||
current_user (User, optional): The authenticated user.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains a list of conversation IDs.
|
||||
|
||||
Notes:
|
||||
- Initializes the ConversationService with the current DB session.
|
||||
- Returns only conversation IDs for lightweight response.
|
||||
- Logs can be added to trace requests in production.
|
||||
"""
|
||||
conversation_service = ConversationService(db)
|
||||
conversations = conversation_service.get_user_conversations(
|
||||
group_id
|
||||
)
|
||||
return success(data=[
|
||||
{
|
||||
"id": conversation.id,
|
||||
"title": conversation.title
|
||||
} for conversation in conversations
|
||||
], msg="get conversations success")
|
||||
|
||||
|
||||
@router.get("/{group_id}/messages", response_model=ApiResponse)
|
||||
def get_messages(
|
||||
conversation_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Retrieve the message history for a specific conversation.
|
||||
|
||||
Args:
|
||||
conversation_id (UUID): The ID of the conversation to fetch messages from.
|
||||
current_user (User, optional): The authenticated user.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains the list of messages in the conversation.
|
||||
|
||||
Notes:
|
||||
- Uses ConversationService to fetch messages.
|
||||
- Consider paginating results if message history is large.
|
||||
- Logging can be added for audit and debugging.
|
||||
"""
|
||||
conversation_service = ConversationService(db)
|
||||
messages_obj = conversation_service.get_messages(
|
||||
conversation_id,
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
"created_at": int(message.created_at.timestamp() * 1000),
|
||||
}
|
||||
for message in messages_obj
|
||||
]
|
||||
return success(data=messages, msg="get conversation history success")
|
||||
|
||||
|
||||
@router.get("/{group_id}/detail", response_model=ApiResponse)
|
||||
async def get_conversation_detail(
|
||||
conversation_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Retrieve detailed information about a specific conversation.
|
||||
|
||||
This endpoint will fetch the conversation detail for the user. If the detail
|
||||
does not exist or is outdated, it will trigger the LLM to generate a new summary.
|
||||
|
||||
Args:
|
||||
conversation_id (UUID): The ID of the conversation.
|
||||
current_user (User, optional): The authenticated user making the request.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains the conversation detail serialized as a dictionary.
|
||||
|
||||
Notes:
|
||||
- Uses async ConversationService to fetch or generate the conversation detail.
|
||||
- Handles workspace and user-specific context automatically.
|
||||
- Logging and exception handling should be implemented for production monitoring.
|
||||
"""
|
||||
conversation_service = ConversationService(db)
|
||||
detail = await conversation_service.get_conversation_detail(
|
||||
user=current_user,
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=current_user.current_workspace_id
|
||||
)
|
||||
return success(data=detail.model_dump(), msg="get conversation detail success")
|
||||
@@ -74,7 +74,7 @@ def get_multi_agent_configs(
|
||||
"app_id": str(app_id),
|
||||
"default_model_config_id": None,
|
||||
"model_parameters": None,
|
||||
"orchestration_mode": "supervisor",
|
||||
"orchestration_mode": "conditional",
|
||||
"sub_agents": [],
|
||||
"routing_rules": [],
|
||||
"execution_config": {
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import uuid
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
@@ -72,12 +70,12 @@ def get_prompt_session(
|
||||
SessionMessage(role=role, content=content)
|
||||
for role, content in history
|
||||
]
|
||||
|
||||
|
||||
result = SessionHistoryResponse(
|
||||
session_id=session_id,
|
||||
messages=messages
|
||||
)
|
||||
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@@ -106,32 +104,35 @@ async def get_prompt_opt(
|
||||
ApiResponse: Contains the optimized prompt, description, and a list of variables.
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
|
||||
async def event_generator():
|
||||
yield "event:start\ndata: {}\n\n"
|
||||
try:
|
||||
async for chunk in service.optimize_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
model_id=data.model_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"event:error\ndata: {json.dumps(
|
||||
{"error": str(e)}
|
||||
)}\n\n"
|
||||
yield "event:end\ndata: {}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
service.create_message(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
role=RoleType.USER,
|
||||
content=data.message
|
||||
)
|
||||
opt_result = await service.optimize_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
model_id=data.model_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message
|
||||
)
|
||||
service.create_message(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
role=RoleType.ASSISTANT,
|
||||
content=opt_result.desc
|
||||
)
|
||||
variables = service.parser_prompt_variables(opt_result.prompt)
|
||||
result = {
|
||||
"prompt": opt_result.prompt,
|
||||
"desc": opt_result.desc,
|
||||
"variables": variables
|
||||
}
|
||||
result_schema = OptimizePromptResponse.model_validate(result)
|
||||
return success(data=result_schema)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -18,8 +17,6 @@ from app.services.auth_service import create_access_token
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
logger = get_business_logger()
|
||||
@@ -268,8 +265,7 @@ def get_conversation(
|
||||
async def chat(
|
||||
payload: conversation_schema.ChatRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""发送消息并获取回复
|
||||
|
||||
@@ -311,7 +307,7 @@ async def chat(
|
||||
other_id=other_id,
|
||||
original_user_id=user_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
|
||||
appid=share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
@@ -365,9 +361,6 @@ async def chat(
|
||||
config = release.config or {}
|
||||
if not config.get("sub_agents"):
|
||||
raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
# Multi-Agent 类型:验证多 Agent 配置
|
||||
pass
|
||||
else:
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
@@ -396,96 +389,15 @@ async def chat(
|
||||
|
||||
if app_type == AppType.AGENT:
|
||||
# 流式返回
|
||||
agent_config = agent_config_4_app_release(release)
|
||||
|
||||
if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id= str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=payload.web_search,
|
||||
config=agent_config,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
# 非流式返回
|
||||
# result = await service.chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
result = await app_chat_service.agnet_chat(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=agent_config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# config = workflow_config_4_app_release(release)
|
||||
config = multi_agent_config_4_app_release(release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.multi_agent_chat_stream(
|
||||
|
||||
async for event in service.chat_stream(
|
||||
share_token=share_token,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
@@ -503,89 +415,37 @@ async def chat(
|
||||
}
|
||||
)
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.multi_agent_chat(
|
||||
|
||||
# 非流式返回
|
||||
result = await service.chat(
|
||||
share_token=share_token,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# 多 Agent 流式返回
|
||||
# if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.multi_agent_chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
|
||||
# # 多 Agent 非流式返回
|
||||
# result = await service.multi_agent_chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
|
||||
config = workflow_config_4_app_release(release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
|
||||
async for event in service.multi_agent_chat_stream(
|
||||
share_token=share_token,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data, default=str, ensure_ascii=False)}\n\n"
|
||||
yield sse_message
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -598,34 +458,22 @@ async def chat(
|
||||
)
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.workflow_chat(
|
||||
|
||||
result = await service.multi_agent_chat(
|
||||
share_token=share_token,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"has_response": "response" in result if isinstance(result, dict) else False
|
||||
}
|
||||
)
|
||||
return success(
|
||||
data=result,
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
pass
|
||||
|
||||
@@ -4,17 +4,14 @@
|
||||
认证方式: API Key
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
|
||||
from . import app_api_controller, rag_api_controller, memory_api_controller
|
||||
|
||||
# 创建 V1 API 路由器
|
||||
service_router = APIRouter()
|
||||
|
||||
# 注册子路由
|
||||
service_router.include_router(app_api_controller.router)
|
||||
service_router.include_router(rag_api_knowledge_controller.router)
|
||||
service_router.include_router(rag_api_document_controller.router)
|
||||
service_router.include_router(rag_api_file_controller.router)
|
||||
service_router.include_router(rag_api_chunk_controller.router)
|
||||
service_router.include_router(rag_api_controller.router)
|
||||
service_router.include_router(memory_api_controller.router)
|
||||
|
||||
__all__ = ["service_router"]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""App 服务接口 - 基于 API Key 认证"""
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, Body
|
||||
@@ -22,7 +21,7 @@ from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release
|
||||
from app.services.app_service import get_app_service, AppService
|
||||
|
||||
router = APIRouter(prefix="/app", tags=["V1 - App API"])
|
||||
@@ -138,10 +137,10 @@ async def chat(
|
||||
|
||||
if app_type == AppType.AGENT:
|
||||
|
||||
# print("="*50)
|
||||
# print(app.current_release.default_model_config_id)
|
||||
print("="*50)
|
||||
print(app.current_release.default_model_config_id)
|
||||
agent_config = agent_config_4_app_release(app.current_release)
|
||||
# print(agent_config.default_model_config_id)
|
||||
print(agent_config.default_model_config_id)
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
@@ -154,8 +153,7 @@ async def chat(
|
||||
config=agent_config,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -179,13 +177,12 @@ async def chat(
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# 多 Agent 流式返回
|
||||
config = multi_agent_config_4_app_release(app.current_release)
|
||||
config = dict_to_multi_agent_config(app.current_release.config,app.id)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.multi_agent_chat_stream(
|
||||
@@ -197,8 +194,8 @@ async def chat(
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -214,6 +211,7 @@ async def chat(
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.multi_agent_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
@@ -228,29 +226,22 @@ async def chat(
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
# 多 Agent 流式返回
|
||||
config = workflow_config_4_app_release(app.current_release)
|
||||
config = dict_to_workflow_config(app.current_release.config,app.id)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.app_id,
|
||||
workspace_id=workspace_id
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -262,34 +253,23 @@ async def chat(
|
||||
}
|
||||
)
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
# 非流式返回
|
||||
result = await app_chat_service.workflow_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.app_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"has_response": "response" in result if isinstance(result, dict) else False
|
||||
}
|
||||
)
|
||||
return success(
|
||||
data=result,
|
||||
msg="工作流任务执行成功"
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,221 +0,0 @@
|
||||
"""RAG 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import chunk_controller
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.models.chunk import QAChunk
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas import chunk_schema
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import api_key_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/chunks", tags=["V1 - RAG API"])
|
||||
api_logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/previewchunks", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_preview_chunks(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content")
|
||||
):
|
||||
"""
|
||||
Paged query document block preview list
|
||||
- Support filtering by document_id
|
||||
- Support keyword search for segmented content
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.get_preview_chunks(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
keywords=keywords,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/chunks", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_chunks(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content")
|
||||
):
|
||||
"""
|
||||
Paged query document chunk list
|
||||
- Support filtering by document_id
|
||||
- Support keyword search for segmented content
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.get_chunks(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
keywords=keywords,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/{kb_id}/{document_id}/chunk", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def create_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
content: Union[str, QAChunk] = Body(..., description="Content can be either a string or a QAChunk object"),
|
||||
):
|
||||
"""
|
||||
create chunk
|
||||
"""
|
||||
body = await request.json()
|
||||
create_data = chunk_schema.ChunkCreate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.create_chunk(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
create_data=create_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
doc_id: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Retrieve document chunk information based on doc_id
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.get_chunk(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
doc_id=doc_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.put("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def update_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
doc_id: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
content: Union[str, QAChunk] = Body(..., description="Content can be either a string or a QAChunk object"),
|
||||
):
|
||||
"""
|
||||
Update document chunk content
|
||||
"""
|
||||
body = await request.json()
|
||||
update_data = chunk_schema.ChunkUpdate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.update_chunk(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
doc_id=doc_id,
|
||||
update_data=update_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.delete("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def delete_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
doc_id: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
delete document chunk
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.delete_chunk(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
doc_id=doc_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/retrieve_type", response_model=ApiResponse)
|
||||
def get_retrieve_types():
|
||||
return success(msg="Successfully obtained the retrieval type", data=list(chunk_schema.RetrieveType))
|
||||
|
||||
|
||||
@router.post("/retrieval", response_model=Any, status_code=status.HTTP_200_OK)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def retrieve_chunks(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
query: str = Body(..., description="question"),
|
||||
):
|
||||
"""
|
||||
retrieve chunk
|
||||
"""
|
||||
body = await request.json()
|
||||
retrieve_data = chunk_schema.ChunkRetrieve(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.retrieve_chunks(retrieve_data=retrieve_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
16
api/app/controllers/service/rag_api_controller.py
Normal file
16
api/app/controllers/service/rag_api_controller.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""RAG 服务接口 - 基于 API Key 认证"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
router = APIRouter(prefix="/knowledge", tags=["V1 - RAG API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_knowledge():
|
||||
"""列出可访问的知识库(占位)"""
|
||||
return success(data=[], msg="RAG API - Coming Soon")
|
||||
@@ -1,172 +0,0 @@
|
||||
"""RAG 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import document_controller
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db
|
||||
from app.schemas import document_schema
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import api_key_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/documents", tags=["V1 - RAG API"])
|
||||
api_logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("/{kb_id}/documents", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_documents(
|
||||
kb_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id when type is Folder"),
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
|
||||
document_ids: Optional[str] = Query(None, description="document ids, separated by commas")
|
||||
):
|
||||
"""
|
||||
Paged query document list
|
||||
- Support filtering by kb_id and parent_id
|
||||
- Support keyword search for file names
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await document_controller.get_documents(kb_id=kb_id,
|
||||
parent_id=parent_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
keywords=keywords,
|
||||
document_ids=document_ids,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/document", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def create_document(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
kb_id: uuid.UUID = Body(..., description="kb id"),
|
||||
file_name: str = Body(..., description="file name"),
|
||||
):
|
||||
"""
|
||||
create document
|
||||
"""
|
||||
body = await request.json()
|
||||
create_data = document_schema.DocumentCreate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await document_controller.create_document(create_data=create_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{document_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_document(
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Retrieve document information based on document_id
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await document_controller.get_document(document_id=document_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.put("/{document_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def update_document(
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
file_name: str = Body(None, description="file name (optional)"),
|
||||
):
|
||||
"""
|
||||
Update document information
|
||||
"""
|
||||
body = await request.json()
|
||||
update_data = document_schema.DocumentUpdate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await document_controller.update_document(document_id=document_id,
|
||||
update_data=update_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.delete("/{document_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def delete_document(
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Delete document
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await document_controller.delete_document(document_id=document_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/{document_id}/chunks", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def parse_documents(
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
parse document
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await document_controller.parse_documents(document_id=document_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
@@ -1,198 +0,0 @@
|
||||
"""RAG 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request, Query, File, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import file_controller
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db
|
||||
from app.schemas import file_schema
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import api_key_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/files", tags=["V1 - RAG API"])
|
||||
api_logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{parent_id}/files", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_files(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
|
||||
):
|
||||
"""
|
||||
Paged query file list
|
||||
- Support filtering by kb_id and parent_id
|
||||
- Support keyword search for file names
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id=api_key_auth.workspace_id
|
||||
|
||||
return await file_controller.get_files(kb_id=kb_id,
|
||||
parent_id=parent_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
keywords=keywords,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/folder", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def create_folder(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
folder_name: str = '/'
|
||||
):
|
||||
"""
|
||||
Create a new folder
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await file_controller.create_folder(kb_id=kb_id,
|
||||
parent_id=parent_id,
|
||||
folder_name=folder_name,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/file", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def upload_file(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
file: UploadFile = File(...),
|
||||
):
|
||||
"""
|
||||
upload file
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await file_controller.upload_file(kb_id=kb_id,
|
||||
parent_id=parent_id,
|
||||
file=file,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/customtext", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def custom_text(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
title: str = Body(..., description="title"),
|
||||
content: str = Body(..., description="content"),
|
||||
):
|
||||
"""
|
||||
custom text
|
||||
"""
|
||||
body = await request.json()
|
||||
create_data = file_schema.CustomTextFileCreate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await file_controller.custom_text(kb_id=kb_id,
|
||||
parent_id=parent_id,
|
||||
create_data=create_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{file_id}", response_model=Any)
|
||||
async def get_file(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Download the file based on the file_id
|
||||
- Query file information from the database
|
||||
- Construct the file path and check if it exists
|
||||
- Return a FileResponse to download the file
|
||||
"""
|
||||
return await file_controller.get_file(file_id=file_id,
|
||||
db=db)
|
||||
|
||||
|
||||
@router.put("/{file_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def update_file(
|
||||
file_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
file_name: str = Body(None, description="file name (optional)"),
|
||||
):
|
||||
"""
|
||||
Update file information (such as file name)
|
||||
- Only specified fields such as file_name are allowed to be modified
|
||||
"""
|
||||
body = await request.json()
|
||||
update_data = file_schema.FileUpdate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await file_controller.update_file(file_id=file_id,
|
||||
update_data=update_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.delete("/{file_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def delete_file(
|
||||
file_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Delete a file or folder
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await file_controller.delete_file(file_id=file_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
@@ -1,248 +0,0 @@
|
||||
"""RAG 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
from typing import Optional, Dict
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.controllers import knowledge_controller
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.models import knowledge_model
|
||||
from app.schemas import knowledge_schema
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import api_key_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/knowledges", tags=["V1 - RAG API"])
|
||||
api_logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("/knowledgetype", response_model=ApiResponse)
|
||||
def get_knowledge_types():
|
||||
return success(msg="Successfully obtained the knowledge type", data=list(knowledge_model.KnowledgeType))
|
||||
|
||||
|
||||
@router.get("/permissiontype", response_model=ApiResponse)
|
||||
def get_permission_types():
|
||||
return success(msg="Successfully obtained the knowledge permission type", data=list(knowledge_model.PermissionType))
|
||||
|
||||
|
||||
@router.get("/parsertype", response_model=ApiResponse)
|
||||
def get_parser_types():
|
||||
return success(msg="Successfully obtained the knowledge parser type", data=list(knowledge_model.ParserType))
|
||||
|
||||
|
||||
@router.get("/knowledge_graph_entity_types", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_knowledge_graph_entity_types(
|
||||
llm_id: uuid.UUID,
|
||||
scenario: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
get knowledge graph entity types based on llm_id
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.get_knowledge_graph_entity_types(llm_id=llm_id,
|
||||
scenario=scenario,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/knowledges", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_knowledges(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id"),
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (knowledge base name)"),
|
||||
kb_ids: Optional[str] = Query(None, description="Knowledge base ids, separated by commas")
|
||||
):
|
||||
"""
|
||||
Query the knowledge base list in pages
|
||||
- Support filtering by parent_id
|
||||
- Support keyword search for knowledge base names
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.get_knowledges(parent_id=parent_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
keywords=keywords,
|
||||
kb_ids=kb_ids,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/knowledge", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def create_knowledge(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
name: str = Body(..., description="KB name"),
|
||||
):
|
||||
"""
|
||||
create knowledge
|
||||
"""
|
||||
body = await request.json()
|
||||
create_data = knowledge_schema.KnowledgeCreate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.create_knowledge(create_data=create_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{knowledge_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Retrieve knowledge base information based on knowledge_id
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.get_knowledge(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.put("/{knowledge_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def update_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
name: str = Body(None, description="KB name (optional)"),
|
||||
):
|
||||
body = await request.json()
|
||||
update_data = knowledge_schema.KnowledgeUpdate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.update_knowledge(knowledge_id=knowledge_id,
|
||||
update_data=update_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.delete("/{knowledge_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def delete_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Soft-delete knowledge base
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.delete_knowledge(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{knowledge_id}/knowledge_graph", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_knowledge_graph(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Retrieve knowledge_graph base information based on knowledge_id
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.get_knowledge_graph(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.delete("/{knowledge_id}/knowledge_graph", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def delete_knowledge_graph(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
delete knowledge graph
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.delete_knowledge_graph(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/{knowledge_id}/knowledge_graph", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def rebuild_knowledge_graph(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
rebuild knowledge graph
|
||||
"""
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.rebuild_knowledge_graph(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
from fastapi import APIRouter, Depends, status, HTTPException, Body, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi import APIRouter, Depends, status, Query, HTTPException
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.models import RedBearLLM, RedBearRerank
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
from app.db import get_db
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.app_schema import AppChatRequest
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.services.handoffs_service import get_handoffs_service_for_app, reset_handoffs_service_cache
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelApiKey, ModelProvider, ModelType
|
||||
from app.models.user_model import User
|
||||
from app.schemas import model_schema
|
||||
from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -27,8 +28,6 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
# ==================== 原有测试接口 ====================
|
||||
|
||||
@router.get("/llm/{model_id}", response_model=ApiResponse)
|
||||
def test_llm(
|
||||
model_id: uuid.UUID,
|
||||
@@ -51,6 +50,7 @@ def test_llm(
|
||||
template = """Question: {question}
|
||||
|
||||
Answer: Let's think step by step."""
|
||||
# ChatPromptTemplate
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
chain = prompt | llm
|
||||
answer = chain.invoke({"question": "What is LangChain?"})
|
||||
@@ -80,13 +80,13 @@ def test_embedding(
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
embeddings = model.embed_documents(data)
|
||||
print(embeddings)
|
||||
query = "我想找一个适合学习的地方。"
|
||||
@@ -114,123 +114,13 @@ def test_rerank(
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
query = "最近哪家咖啡店评价最好?"
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
scores = model.rerank(query=query, documents=data, top_n=3)
|
||||
print(scores)
|
||||
return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores})
|
||||
|
||||
|
||||
# ==================== Handoffs 测试接口 ====================
|
||||
|
||||
@router.post("/handoffs/{app_id}")
|
||||
async def test_handoffs(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
request: AppChatRequest = Body(...),
|
||||
current_user=Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""测试 Agent Handoffs 功能
|
||||
|
||||
演示 LangGraph 实现的多 Agent 协作和动态切换
|
||||
|
||||
- 从数据库 multi_agent_config 获取 Agent 配置
|
||||
- 根据用户问题自动切换到合适的 Agent
|
||||
- 使用 conversation_id 保持会话状态
|
||||
- 通过 stream 参数控制是否流式输出
|
||||
|
||||
事件类型(流式):
|
||||
- start: 开始执行
|
||||
- agent: 当前 Agent 信息
|
||||
- message: 流式消息内容
|
||||
- handoff: Agent 切换事件
|
||||
- end: 执行结束
|
||||
- error: 错误信息
|
||||
"""
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 获取或创建会话
|
||||
conversation_service = ConversationService(db)
|
||||
|
||||
if request.conversation_id:
|
||||
# 验证会话存在
|
||||
conversation = conversation_service.get_conversation(uuid.UUID(request.conversation_id))
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
conversation_id = str(conversation.id)
|
||||
else:
|
||||
# 创建新会话
|
||||
conversation = conversation_service.create_or_get_conversation(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=request.user_id,
|
||||
is_draft=True
|
||||
)
|
||||
conversation_id = str(conversation.id)
|
||||
|
||||
# 根据 stream 参数决定返回方式
|
||||
if request.stream:
|
||||
# 流式返回
|
||||
service = get_handoffs_service_for_app(app_id, db, streaming=True)
|
||||
return StreamingResponse(
|
||||
service.chat_stream(
|
||||
message=request.message,
|
||||
conversation_id=conversation_id
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 非流式返回
|
||||
service = get_handoffs_service_for_app(app_id, db, streaming=False)
|
||||
result = await service.chat(
|
||||
message=request.message,
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
return success(data=result, msg="Handoffs 测试成功")
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Handoffs 测试失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/handoffs/{app_id}/agents", response_model=ApiResponse)
|
||||
def get_handoff_agents(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user)
|
||||
):
|
||||
"""获取应用的 Handoff Agent 列表"""
|
||||
try:
|
||||
service = get_handoffs_service_for_app(app_id, db, streaming=False)
|
||||
agents = service.get_agents()
|
||||
return success(data={"agents": agents}, msg="获取 Agent 列表成功")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取 Agent 列表失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/handoffs/{app_id}/reset")
|
||||
def reset_handoff_service(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
current_user=Depends(get_current_user)
|
||||
):
|
||||
"""重置指定应用的 Handoff 服务缓存"""
|
||||
reset_handoffs_service_cache(app_id)
|
||||
return success(msg="Handoff 服务已重置")
|
||||
|
||||
@@ -60,22 +60,6 @@ async def list_tools(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{tool_id}/methods", response_model=ApiResponse)
|
||||
async def get_tool_methods(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""获取工具的所有方法"""
|
||||
try:
|
||||
methods = await service.get_tool_methods(tool_id, current_user.tenant_id)
|
||||
if methods is None:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
return success(data=methods, msg="获取工具方法成功")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{tool_id}", response_model=ApiResponse)
|
||||
async def get_tool(
|
||||
tool_id: str,
|
||||
@@ -175,8 +159,7 @@ async def execute_tool(
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
timeout=request.timeout
|
||||
)
|
||||
if not result.success:
|
||||
raise HTTPException(status_code=400, detail=result["error"])
|
||||
|
||||
return success(
|
||||
data={
|
||||
"success": result.success,
|
||||
@@ -215,8 +198,8 @@ async def sync_mcp_tools(
|
||||
"""同步MCP工具列表"""
|
||||
try:
|
||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||
if not result.get("success", False):
|
||||
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
||||
if result["success"] is False:
|
||||
raise HTTPException(status_code=404, detail=result["message"])
|
||||
return success(data=result, msg="MCP工具列表同步完成")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -11,16 +11,14 @@ from app.db import get_db
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.api_key_utils import timestamp_to_datetime
|
||||
from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_node_statistics,
|
||||
analytics_memory_types,
|
||||
analytics_graph_data,
|
||||
)
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
@@ -43,27 +41,24 @@ router = APIRouter(
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str,
|
||||
end_user_id: str, # 使用 end_user_id
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的记忆洞察报告
|
||||
|
||||
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
|
||||
如需生成新的洞察报告,请使用专门的生成接口。
|
||||
"""
|
||||
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
) -> dict:
|
||||
"""获取缓存的记忆洞察报告"""
|
||||
api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
||||
|
||||
|
||||
if result["is_cached"]:
|
||||
# 缓存存在,返回缓存数据
|
||||
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
else:
|
||||
# 缓存不存在,返回提示消息
|
||||
api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="数据尚未生成")
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e))
|
||||
@@ -71,27 +66,24 @@ async def get_memory_insight_report_api(
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str,
|
||||
end_user_id: str, # 使用 end_user_id
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的用户摘要
|
||||
|
||||
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
|
||||
如需生成新的用户摘要,请使用专门的生成接口。
|
||||
"""
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
) -> dict:
|
||||
"""获取缓存的用户摘要"""
|
||||
api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
|
||||
|
||||
|
||||
if result["is_cached"]:
|
||||
# 缓存存在,返回缓存数据
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
else:
|
||||
# 缓存不存在,返回提示消息
|
||||
api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="数据尚未生成")
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e))
|
||||
@@ -105,35 +97,35 @@ async def generate_cache_api(
|
||||
) -> dict:
|
||||
"""
|
||||
手动触发缓存生成
|
||||
|
||||
|
||||
- 如果提供 end_user_id,只为该用户生成
|
||||
- 如果不提供,为当前工作空间的所有用户生成
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
group_id = request.end_user_id
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
||||
f"end_user_id={group_id if group_id else '全部用户'}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if group_id:
|
||||
# 为单个用户生成
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
|
||||
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
|
||||
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
|
||||
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
"end_user_id": group_id,
|
||||
@@ -141,7 +133,7 @@ async def generate_cache_api(
|
||||
"summary_success": summary_result["success"],
|
||||
"errors": []
|
||||
}
|
||||
|
||||
|
||||
# 收集错误信息
|
||||
if not insight_result["success"]:
|
||||
result["errors"].append({
|
||||
@@ -153,29 +145,29 @@ async def generate_cache_api(
|
||||
"type": "summary",
|
||||
"error": summary_result.get("error")
|
||||
})
|
||||
|
||||
|
||||
# 记录结果
|
||||
if result["insight_success"] and result["summary_success"]:
|
||||
api_logger.info(f"成功为用户 {group_id} 生成缓存")
|
||||
else:
|
||||
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
|
||||
|
||||
|
||||
return success(data=result, msg="生成完成")
|
||||
|
||||
|
||||
else:
|
||||
# 为整个工作空间生成
|
||||
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
|
||||
|
||||
|
||||
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id)
|
||||
|
||||
|
||||
# 记录统计信息
|
||||
api_logger.info(
|
||||
f"工作空间 {workspace_id} 批量生成完成: "
|
||||
f"总数={result['total_users']}, 成功={result['successful']}, 失败={result['failed']}"
|
||||
)
|
||||
|
||||
|
||||
return success(data=result, msg="批量生成完成")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"缓存生成失败: user={current_user.username}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "缓存生成失败", str(e))
|
||||
@@ -188,18 +180,18 @@ async def get_node_statistics_api(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 调用新的记忆类型统计函数
|
||||
result = await analytics_memory_types(db, end_user_id)
|
||||
|
||||
|
||||
# 计算总数用于日志
|
||||
total_count = sum(item["count"] for item in result)
|
||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
@@ -219,31 +211,31 @@ async def get_graph_data_api(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询图数据但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
# 参数验证
|
||||
if limit > 1000:
|
||||
limit = 1000
|
||||
api_logger.warning("limit 参数超过最大值,已调整为 1000")
|
||||
|
||||
|
||||
if depth > 3:
|
||||
depth = 3
|
||||
api_logger.warning("depth 参数超过最大值,已调整为 3")
|
||||
|
||||
|
||||
# 解析 node_types 参数
|
||||
node_types_list = None
|
||||
if node_types:
|
||||
node_types_list = [t.strip() for t in node_types.split(",") if t.strip()]
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"图数据查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}, node_types={node_types_list}, limit={limit}, depth={depth}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
result = await analytics_graph_data(
|
||||
db=db,
|
||||
@@ -253,19 +245,19 @@ async def get_graph_data_api(
|
||||
depth=depth,
|
||||
center_node_id=center_node_id
|
||||
)
|
||||
|
||||
|
||||
# 检查是否有错误消息
|
||||
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
||||
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}")
|
||||
return success(data=result, msg=result.get("message", "查询成功"))
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取图数据: end_user_id={end_user_id}, "
|
||||
f"nodes={result['statistics']['total_nodes']}, "
|
||||
f"edges={result['statistics']['total_edges']}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"图数据查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||
@@ -278,25 +270,25 @@ async def get_end_user_profile(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 查询终端用户
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
|
||||
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
|
||||
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
@@ -308,10 +300,10 @@ async def get_end_user_profile(
|
||||
hire_date=end_user.hire_date,
|
||||
updatetime_profile=end_user.updatetime_profile
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
|
||||
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功")
|
||||
|
||||
return success(data=profile_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
|
||||
@@ -325,56 +317,49 @@ async def update_end_user_profile(
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户的基本信息
|
||||
|
||||
|
||||
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
||||
所有字段都是可选的,只更新提供的字段。
|
||||
|
||||
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = profile_update.end_user_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 查询终端用户
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
|
||||
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
|
||||
|
||||
# 更新字段(只更新提供的字段,排除 end_user_id)
|
||||
# 允许 None 值来重置字段(如 hire_date)
|
||||
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
||||
|
||||
# 特殊处理 hire_date:如果提供了时间戳,转换为 DateTime
|
||||
if 'hire_date' in update_data:
|
||||
hire_date_timestamp = update_data['hire_date']
|
||||
if hire_date_timestamp is not None:
|
||||
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
|
||||
# 如果是 None,保持 None(允许清空)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(end_user, field, value)
|
||||
|
||||
|
||||
# 更新 updated_at 时间戳
|
||||
end_user.updated_at = datetime.datetime.now()
|
||||
|
||||
# 更新 updatetime_profile 为当前时间
|
||||
end_user.updatetime_profile = datetime.datetime.now()
|
||||
|
||||
|
||||
# 更新 updatetime_profile 为当前时间戳(毫秒)
|
||||
current_timestamp = int(datetime.datetime.now().timestamp() * 1000)
|
||||
end_user.updatetime_profile = current_timestamp
|
||||
|
||||
# 提交更改
|
||||
db.commit()
|
||||
db.refresh(end_user)
|
||||
|
||||
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
@@ -386,51 +371,11 @@ async def update_end_user_profile(
|
||||
hire_date=end_user.hire_date,
|
||||
updatetime_profile=end_user.updatetime_profile
|
||||
)
|
||||
|
||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
|
||||
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功")
|
||||
|
||||
|
||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}, updatetime_profile={current_timestamp}")
|
||||
return success(data=profile_data.model_dump(), msg="更新成功")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
MemoryEntity = MemoryEntityService(id, label)
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server()
|
||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||
async def memory_space_relationship_evolution(id: str, label: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
||||
|
||||
# 获取情绪数据
|
||||
emotion = MemoryEmotion(id, label)
|
||||
emotion_result = await emotion.get_emotion()
|
||||
|
||||
# 获取交互数据
|
||||
interaction = MemoryInteraction(id, label)
|
||||
interaction_result = await interaction.get_interaction_frequency()
|
||||
|
||||
# 关闭连接
|
||||
await emotion.close()
|
||||
await interaction.close()
|
||||
|
||||
result = {
|
||||
"emotion": emotion_result,
|
||||
"interaction": interaction_result
|
||||
}
|
||||
|
||||
api_logger.info(f"关系演变查询成功: id={id}, table={label}")
|
||||
return success(data=result, msg="关系演变")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"关系演变查询失败: id={id}, table={label}, error={str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "关系演变查询失败", str(e))
|
||||
|
||||
@@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"])
|
||||
@router.post("/{app_id}/workflow")
|
||||
@cur_workspace_access_guard()
|
||||
async def create_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
config: WorkflowConfigCreate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
config: WorkflowConfigCreate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""创建工作流配置
|
||||
|
||||
@@ -96,7 +96,6 @@ async def create_workflow_config(
|
||||
msg=f"创建工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# @router.get("/{app_id}/workflow")
|
||||
# async def get_workflow_config(
|
||||
@@ -200,10 +199,10 @@ async def create_workflow_config(
|
||||
|
||||
@router.delete("/{app_id}/workflow")
|
||||
async def delete_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""删除工作流配置
|
||||
|
||||
@@ -244,11 +243,11 @@ async def delete_workflow_config(
|
||||
|
||||
@router.post("/{app_id}/workflow/validate")
|
||||
async def validate_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||
):
|
||||
"""验证工作流配置
|
||||
|
||||
@@ -313,12 +312,12 @@ async def validate_workflow_config(
|
||||
|
||||
@router.get("/{app_id}/workflow/executions")
|
||||
async def get_workflow_executions(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0
|
||||
):
|
||||
"""获取工作流执行记录列表
|
||||
|
||||
@@ -366,10 +365,10 @@ async def get_workflow_executions(
|
||||
|
||||
@router.get("/workflow/executions/{execution_id}")
|
||||
async def get_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""获取工作流执行详情
|
||||
|
||||
@@ -418,14 +417,16 @@ async def get_workflow_execution(
|
||||
)
|
||||
|
||||
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
@router.post("/{app_id}/workflow/run")
|
||||
async def run_workflow(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
request: WorkflowExecutionRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
request: WorkflowExecutionRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""执行工作流
|
||||
|
||||
@@ -486,22 +487,22 @@ async def run_workflow(
|
||||
"""
|
||||
try:
|
||||
async for event in await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
):
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
# event: <type>
|
||||
# data: <json>
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||
# 发送错误事件
|
||||
@@ -553,10 +554,10 @@ async def run_workflow(
|
||||
|
||||
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||
async def cancel_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""取消工作流执行
|
||||
|
||||
@@ -601,7 +602,7 @@ async def cancel_workflow_execution(
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||
return fail(code=e.code, msg=e.message)
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
|
||||
@@ -11,16 +11,10 @@ import os
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
@@ -165,13 +159,11 @@ class LangChainAgent:
|
||||
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
messagss_list=[]
|
||||
retrieved_content=[]
|
||||
for messages in history:
|
||||
query = messages.get("Query")
|
||||
aimessages = messages.get("Answer")
|
||||
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
retrieved_content.append({query: aimessages})
|
||||
return messagss_list,retrieved_content
|
||||
return messagss_list
|
||||
|
||||
|
||||
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
|
||||
@@ -211,6 +203,7 @@ class LangChainAgent:
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.db import get_db
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
@@ -228,26 +221,11 @@ class LangChainAgent:
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
|
||||
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
history_term_memory = history_term_memory_result[0]
|
||||
db_for_memory = next(get_db())
|
||||
history_term_memory=await self.term_memory_redis_read(end_user_id)
|
||||
if memory_flag:
|
||||
if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
history_term_memory = ';'.join(history_term_memory)
|
||||
retrieved_content = history_term_memory_result[1]
|
||||
print(retrieved_content)
|
||||
# 为长期记忆操作获取新的数据库连接
|
||||
try:
|
||||
repo = LongTermMemoryRepository(db_for_memory)
|
||||
repo.upsert(end_user_id, retrieved_content)
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||
raise
|
||||
finally:
|
||||
db_for_memory.close()
|
||||
|
||||
history_term_memory=';'.join(history_term_memory)
|
||||
logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
|
||||
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
|
||||
try:
|
||||
@@ -336,6 +314,10 @@ class LangChainAgent:
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.db import get_db
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
@@ -347,24 +329,14 @@ class LangChainAgent:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
history_term_memory = history_term_memory_result[0]
|
||||
history_term_memory = await self.term_memory_redis_read(end_user_id)
|
||||
if memory_flag:
|
||||
if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
history_term_memory = ';'.join(history_term_memory)
|
||||
retrieved_content = history_term_memory_result[1]
|
||||
db_for_memory = next(get_db())
|
||||
try:
|
||||
repo = LongTermMemoryRepository(db_for_memory)
|
||||
repo.upsert(end_user_id, retrieved_content)
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
|
||||
history_term_memory, actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write to long term memory: {e}")
|
||||
finally:
|
||||
db_for_memory.close()
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
|
||||
history_term_memory, actual_config_id)
|
||||
|
||||
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
|
||||
try:
|
||||
|
||||
@@ -3,7 +3,7 @@ import secrets
|
||||
from typing import Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.api_key_model import ApiKeyType
|
||||
from app.schemas.api_key_schema import ApiKeyType
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
@@ -7,18 +7,17 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Settings:
|
||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||
# API Keys Configuration
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
|
||||
|
||||
|
||||
# Neo4j Configuration (记忆系统数据库)
|
||||
NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687")
|
||||
NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
|
||||
NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "")
|
||||
|
||||
|
||||
# Database configuration (Postgres)
|
||||
DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1")
|
||||
DB_PORT: int = int(os.getenv("DB_PORT", "5432"))
|
||||
@@ -38,7 +37,7 @@ class Settings:
|
||||
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
|
||||
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
||||
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
||||
|
||||
|
||||
# ElasticSearch configuration
|
||||
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
||||
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
|
||||
@@ -49,7 +48,7 @@ class Settings:
|
||||
ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000"))
|
||||
ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true"
|
||||
ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10"))
|
||||
|
||||
|
||||
# Xinference configuration
|
||||
XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1")
|
||||
|
||||
@@ -58,17 +57,17 @@ class Settings:
|
||||
LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true"
|
||||
LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "")
|
||||
LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "")
|
||||
|
||||
|
||||
# LLM Request Configuration
|
||||
LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0"))
|
||||
LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2"))
|
||||
|
||||
|
||||
# JWT Token Configuration
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random")
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
|
||||
|
||||
|
||||
# Single Sign-On configuration
|
||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||
|
||||
@@ -87,19 +86,19 @@ class Settings:
|
||||
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
|
||||
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
|
||||
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
|
||||
|
||||
|
||||
# Server Configuration
|
||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||
|
||||
# ========================================================================
|
||||
# Internal Configuration (not in .env, used by application code)
|
||||
# ========================================================================
|
||||
|
||||
|
||||
# Superuser settings (internal defaults)
|
||||
FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com")
|
||||
FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin")
|
||||
FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password")
|
||||
|
||||
|
||||
# Generic File Upload (internal)
|
||||
GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads")
|
||||
ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true"
|
||||
@@ -124,7 +123,7 @@ class Settings:
|
||||
LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5"))
|
||||
LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true"
|
||||
LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true"
|
||||
|
||||
|
||||
# Sensitive Data Filtering
|
||||
ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true"
|
||||
|
||||
@@ -143,6 +142,7 @@ class Settings:
|
||||
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
|
||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||
|
||||
|
||||
# Celery configuration (internal)
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
@@ -150,27 +150,21 @@ class Settings:
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
|
||||
|
||||
# Tool Management Configuration
|
||||
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
||||
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
|
||||
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
|
||||
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
||||
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module output files.
|
||||
@@ -185,7 +179,7 @@ class Settings:
|
||||
if filename:
|
||||
return str(base_path / filename)
|
||||
return str(base_path)
|
||||
|
||||
|
||||
def ensure_memory_output_dir(self) -> None:
|
||||
"""
|
||||
Ensure the memory output directory exists.
|
||||
|
||||
@@ -82,13 +82,6 @@ class BizCode(IntEnum):
|
||||
MEMORY_WRITE_FAILED = 9501
|
||||
MEMORY_READ_FAILED = 9502
|
||||
MEMORY_CONFIG_NOT_FOUND = 9503
|
||||
|
||||
# Implicit Memory API(96xx)
|
||||
INVALID_USER_ID = 9601
|
||||
INSUFFICIENT_DATA = 9602
|
||||
INVALID_FILTER_PARAMS = 9603
|
||||
ANALYSIS_FAILED = 9604
|
||||
PROFILE_STORAGE_ERROR = 9605
|
||||
|
||||
# 系统(100xx)
|
||||
INTERNAL_ERROR = 10001
|
||||
@@ -110,24 +103,24 @@ HTTP_MAPPING = {
|
||||
BizCode.TOKEN_EXPIRED: 401,
|
||||
BizCode.TOKEN_BLACKLISTED: 401,
|
||||
BizCode.FORBIDDEN: 403,
|
||||
BizCode.TENANT_NOT_FOUND: 400,
|
||||
BizCode.TENANT_NOT_FOUND: 404,
|
||||
BizCode.WORKSPACE_NO_ACCESS: 403,
|
||||
BizCode.NOT_FOUND: 400,
|
||||
BizCode.NOT_FOUND: 404,
|
||||
BizCode.USER_NOT_FOUND: 200,
|
||||
BizCode.WORKSPACE_NOT_FOUND: 400,
|
||||
BizCode.MODEL_NOT_FOUND: 400,
|
||||
BizCode.KNOWLEDGE_NOT_FOUND: 400,
|
||||
BizCode.DOCUMENT_NOT_FOUND: 400,
|
||||
BizCode.FILE_NOT_FOUND: 400,
|
||||
BizCode.APP_NOT_FOUND: 400,
|
||||
BizCode.RELEASE_NOT_FOUND: 400,
|
||||
BizCode.WORKSPACE_NOT_FOUND: 404,
|
||||
BizCode.MODEL_NOT_FOUND: 404,
|
||||
BizCode.KNOWLEDGE_NOT_FOUND: 404,
|
||||
BizCode.DOCUMENT_NOT_FOUND: 404,
|
||||
BizCode.FILE_NOT_FOUND: 404,
|
||||
BizCode.APP_NOT_FOUND: 404,
|
||||
BizCode.RELEASE_NOT_FOUND: 404,
|
||||
BizCode.DUPLICATE_NAME: 409,
|
||||
BizCode.RESOURCE_ALREADY_EXISTS: 409,
|
||||
BizCode.VERSION_ALREADY_EXISTS: 409,
|
||||
BizCode.STATE_CONFLICT: 409,
|
||||
BizCode.PUBLISH_FAILED: 500,
|
||||
BizCode.NO_DRAFT_TO_PUBLISH: 400,
|
||||
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,
|
||||
BizCode.ROLLBACK_TARGET_NOT_FOUND: 404,
|
||||
BizCode.APP_TYPE_NOT_SUPPORTED: 400,
|
||||
BizCode.AGENT_CONFIG_MISSING: 400,
|
||||
BizCode.SHARE_DISABLED: 403,
|
||||
@@ -166,13 +159,6 @@ HTTP_MAPPING = {
|
||||
BizCode.MEMORY_READ_FAILED: 500,
|
||||
BizCode.MEMORY_CONFIG_NOT_FOUND: 400,
|
||||
|
||||
# Implicit Memory API 错误码映射
|
||||
BizCode.INVALID_USER_ID: 400,
|
||||
BizCode.INSUFFICIENT_DATA: 400,
|
||||
BizCode.INVALID_FILTER_PARAMS: 400,
|
||||
BizCode.ANALYSIS_FAILED: 500,
|
||||
BizCode.PROFILE_STORAGE_ERROR: 500,
|
||||
|
||||
BizCode.INTERNAL_ERROR: 500,
|
||||
BizCode.DB_ERROR: 500,
|
||||
BizCode.SERVICE_UNAVAILABLE: 503,
|
||||
|
||||
@@ -106,32 +106,28 @@ class SearchService:
|
||||
limit: int = 15,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config: "MemoryConfig" = None,
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search with two-stage ranking.
|
||||
|
||||
Stage 1: Filter by content relevance (BM25 + Embedding)
|
||||
Stage 2: Rerank by activation values (ACTR)
|
||||
Execute hybrid search and return clean content.
|
||||
|
||||
Args:
|
||||
group_id: Group identifier for filtering
|
||||
group_id: Group identifier for filtering results
|
||||
question: Search query text
|
||||
limit: Max results per category (default: 15)
|
||||
search_type: "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||
include: Result types (default: ["statements", "chunks", "entities", "summaries"])
|
||||
rerank_alpha: BM25 weight (default: 0.6)
|
||||
activation_boost_factor: Activation impact on memory strength (default: 0.8)
|
||||
output_path: JSON output path (default: "search_results.json")
|
||||
return_raw_results: Return full metadata (default: False)
|
||||
memory_config: MemoryConfig for embedding model
|
||||
limit: Maximum number of results to return (default: 5)
|
||||
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||
include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"])
|
||||
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
|
||||
output_path: Path to save search results (default: "search_results.json")
|
||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||
memory_config: MemoryConfig object for embedding model. Falls back to self.memory_config if not provided.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results)
|
||||
Tuple of (clean_content, cleaned_query, raw_results)
|
||||
raw_results is None if return_raw_results=False
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
@@ -155,7 +151,6 @@ class SearchService:
|
||||
output_path=output_path,
|
||||
memory_config=config,
|
||||
rerank_alpha=rerank_alpha,
|
||||
activation_boost_factor=activation_boost_factor,
|
||||
)
|
||||
|
||||
# Extract results based on search type and include parameter
|
||||
|
||||
@@ -425,9 +425,15 @@ async def Input_Summary(
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
search_service = get_context_resource(ctx, "search_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
sessionid = sessionid.replace('call_id_', '')
|
||||
@@ -533,11 +539,31 @@ async def Input_Summary(
|
||||
)
|
||||
retrieve_info, question, raw_results = "", query, []
|
||||
|
||||
# Return retrieved information directly without LLM processing
|
||||
# Use the raw retrieved info as the answer
|
||||
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
|
||||
|
||||
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Retrieve_Summary_prompt.jinja2',
|
||||
operation_name='input_summary',
|
||||
query=query,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=RetrieveSummaryResponse
|
||||
)
|
||||
aimessages = structured.data.query_answer or "信息不足,无法回答"
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary: response_structured failed, using default answer: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
|
||||
@@ -5,16 +5,19 @@ This module provides analytics and insights for the memory system.
|
||||
|
||||
Available functions:
|
||||
- get_hot_memory_tags: Get hot memory tags by frequency
|
||||
- MemoryInsight: Generate memory insight reports
|
||||
- get_recent_activity_stats: Get recent activity statistics
|
||||
|
||||
Note: MemoryInsight and generate_user_summary have been moved to
|
||||
app.services.user_memory_service for better architecture.
|
||||
- generate_user_summary: Generate user summary
|
||||
"""
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
|
||||
__all__ = [
|
||||
"get_hot_memory_tags",
|
||||
"MemoryInsight",
|
||||
"get_recent_activity_stats",
|
||||
"generate_user_summary",
|
||||
]
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Implicit Memory Module
|
||||
|
||||
This module provides behavior analysis capabilities that build comprehensive user profiles
|
||||
by analyzing memory summary nodes from Neo4j. It creates detailed user portraits across
|
||||
multiple dimensions, tracks interest distributions, and identifies behavioral habits.
|
||||
"""
|
||||
@@ -1 +0,0 @@
|
||||
"""Analyzers package for implicit memory analysis components."""
|
||||
@@ -1,271 +0,0 @@
|
||||
"""Dimension Analyzer for Implicit Memory System
|
||||
|
||||
This module implements LLM-based personality dimension analysis from user memory summaries.
|
||||
It analyzes four key dimensions: creativity, aesthetic, technology, and literature,
|
||||
providing percentage scores with evidence and reasoning.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
DimensionPortrait,
|
||||
DimensionScore,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DimensionData(BaseModel):
|
||||
"""Individual dimension analysis data."""
|
||||
percentage: float = Field(ge=0.0, le=100.0)
|
||||
evidence: List[str] = Field(default_factory=list)
|
||||
reasoning: str = ""
|
||||
confidence_level: int = 50 # Default to medium confidence
|
||||
|
||||
|
||||
class DimensionAnalysisResponse(BaseModel):
|
||||
"""Response model for dimension analysis."""
|
||||
dimensions: Dict[str, DimensionData] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class DimensionAnalyzer:
|
||||
"""Analyzes user memory summaries to extract personality dimensions."""
|
||||
|
||||
# Define the four dimensions we analyze
|
||||
DIMENSIONS = ["creativity", "aesthetic", "technology", "literature"]
|
||||
|
||||
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
|
||||
"""Initialize the dimension analyzer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
llm_model_id: Optional LLM model ID to use for analysis
|
||||
"""
|
||||
self.db = db
|
||||
self.llm_model_id = llm_model_id
|
||||
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
|
||||
|
||||
async def analyze_dimensions(
|
||||
self,
|
||||
user_id: str,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
existing_portrait: Optional[DimensionPortrait] = None
|
||||
) -> DimensionPortrait:
|
||||
"""Analyze user summaries to extract personality dimensions.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_summaries: List of user-specific memory summaries
|
||||
existing_portrait: Optional existing portrait for incremental updates
|
||||
|
||||
Returns:
|
||||
Dimension portrait with four personality dimensions
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for user {user_id}")
|
||||
return self._create_empty_portrait(user_id)
|
||||
|
||||
try:
|
||||
logger.info(f"Analyzing dimensions for user {user_id} with {len(user_summaries)} summaries")
|
||||
|
||||
# Use the LLM client wrapper for analysis
|
||||
response = await self._llm_client.analyze_dimensions(
|
||||
user_summaries=user_summaries,
|
||||
user_id=user_id,
|
||||
model_id=self.llm_model_id
|
||||
)
|
||||
|
||||
# Create dimension scores
|
||||
dimension_scores = {}
|
||||
current_time = datetime.now()
|
||||
|
||||
for dimension_name in self.DIMENSIONS:
|
||||
# Handle response as dictionary
|
||||
dimensions_data = response.get("dimensions", {})
|
||||
dimension_data = dimensions_data.get(dimension_name)
|
||||
|
||||
if dimension_data:
|
||||
# Validate and create dimension score
|
||||
score = self._create_dimension_score(
|
||||
dimension_name=dimension_name,
|
||||
dimension_data=dimension_data
|
||||
)
|
||||
dimension_scores[dimension_name] = score
|
||||
else:
|
||||
# Create default score if missing
|
||||
logger.warning(f"Missing dimension data for {dimension_name}, using default")
|
||||
dimension_scores[dimension_name] = self._create_default_dimension_score(dimension_name)
|
||||
|
||||
# Create dimension portrait
|
||||
portrait = DimensionPortrait(
|
||||
user_id=user_id,
|
||||
creativity=dimension_scores["creativity"],
|
||||
aesthetic=dimension_scores["aesthetic"],
|
||||
technology=dimension_scores["technology"],
|
||||
literature=dimension_scores["literature"],
|
||||
analysis_timestamp=current_time,
|
||||
total_summaries_analyzed=len(user_summaries),
|
||||
historical_trends=self._calculate_historical_trends(existing_portrait) if existing_portrait else None
|
||||
)
|
||||
|
||||
logger.info(f"Created dimension portrait for user {user_id}")
|
||||
return portrait
|
||||
|
||||
except LLMClientException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Dimension analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Dimension analysis failed: {e}") from e
|
||||
|
||||
def _create_dimension_score(
|
||||
self,
|
||||
dimension_name: str,
|
||||
dimension_data: dict
|
||||
) -> DimensionScore:
|
||||
"""Create a dimension score from analysis data.
|
||||
|
||||
Args:
|
||||
dimension_name: Name of the dimension
|
||||
dimension_data: Analysis data dictionary for the dimension
|
||||
|
||||
Returns:
|
||||
DimensionScore object
|
||||
"""
|
||||
# Validate percentage - handle dict access
|
||||
percentage = dimension_data.get("percentage", 0.0)
|
||||
percentage = max(0.0, min(100.0, float(percentage)))
|
||||
|
||||
# Validate confidence level
|
||||
confidence_level = self._validate_confidence_level(dimension_data.get("confidence_level", 50))
|
||||
|
||||
# Ensure evidence is not empty
|
||||
evidence = dimension_data.get("evidence", [])
|
||||
if not evidence:
|
||||
evidence = ["No specific evidence found"]
|
||||
|
||||
# Ensure reasoning is not empty
|
||||
reasoning = dimension_data.get("reasoning", "")
|
||||
if not reasoning:
|
||||
reasoning = f"Analysis for {dimension_name} dimension"
|
||||
|
||||
return DimensionScore(
|
||||
dimension_name=dimension_name,
|
||||
percentage=percentage,
|
||||
evidence=evidence,
|
||||
reasoning=reasoning,
|
||||
confidence_level=confidence_level
|
||||
)
|
||||
|
||||
def _create_default_dimension_score(self, dimension_name: str) -> DimensionScore:
|
||||
"""Create a default dimension score when analysis fails.
|
||||
|
||||
Args:
|
||||
dimension_name: Name of the dimension
|
||||
|
||||
Returns:
|
||||
Default DimensionScore object
|
||||
"""
|
||||
return DimensionScore(
|
||||
dimension_name=dimension_name,
|
||||
percentage=0.0,
|
||||
evidence=["Insufficient data for analysis"],
|
||||
reasoning=f"No clear evidence found for {dimension_name} dimension",
|
||||
confidence_level=20 # Low confidence as numerical value
|
||||
)
|
||||
|
||||
def _validate_confidence_level(self, confidence_level) -> int:
|
||||
"""Return confidence level as integer, handling both string and numeric inputs.
|
||||
|
||||
Args:
|
||||
confidence_level: Confidence level (string or numeric)
|
||||
|
||||
Returns:
|
||||
Confidence level as integer (0-100)
|
||||
"""
|
||||
# If it's already a number, return it as int
|
||||
if isinstance(confidence_level, (int, float)):
|
||||
return int(confidence_level)
|
||||
|
||||
# If it's a string, convert common values to numbers
|
||||
if isinstance(confidence_level, str):
|
||||
confidence_str = confidence_level.lower().strip()
|
||||
if confidence_str in ["high", "높음"]:
|
||||
return 85
|
||||
elif confidence_str in ["medium", "중간"]:
|
||||
return 50
|
||||
elif confidence_str in ["low", "낮음"]:
|
||||
return 20
|
||||
else:
|
||||
# Try to parse as number
|
||||
try:
|
||||
return int(float(confidence_str))
|
||||
except ValueError:
|
||||
logger.warning(f"Unknown confidence level: {confidence_level}, defaulting to medium")
|
||||
return 50
|
||||
|
||||
# Default fallback
|
||||
return 50
|
||||
|
||||
def _create_empty_portrait(self, user_id: str) -> DimensionPortrait:
|
||||
"""Create an empty dimension portrait when no data is available.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
|
||||
Returns:
|
||||
Empty DimensionPortrait
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
|
||||
return DimensionPortrait(
|
||||
user_id=user_id,
|
||||
creativity=self._create_default_dimension_score("creativity"),
|
||||
aesthetic=self._create_default_dimension_score("aesthetic"),
|
||||
technology=self._create_default_dimension_score("technology"),
|
||||
literature=self._create_default_dimension_score("literature"),
|
||||
analysis_timestamp=current_time,
|
||||
total_summaries_analyzed=0,
|
||||
historical_trends=None
|
||||
)
|
||||
|
||||
def _calculate_historical_trends(
|
||||
self,
|
||||
existing_portrait: DimensionPortrait
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Calculate historical trends from existing portrait.
|
||||
|
||||
Args:
|
||||
existing_portrait: Previous dimension portrait
|
||||
|
||||
Returns:
|
||||
List of historical trend data
|
||||
"""
|
||||
if not existing_portrait:
|
||||
return []
|
||||
|
||||
# Create trend entry from existing portrait
|
||||
trend_entry = {
|
||||
"timestamp": existing_portrait.analysis_timestamp.isoformat(),
|
||||
"creativity": existing_portrait.creativity.percentage,
|
||||
"aesthetic": existing_portrait.aesthetic.percentage,
|
||||
"technology": existing_portrait.technology.percentage,
|
||||
"literature": existing_portrait.literature.percentage,
|
||||
"total_summaries": existing_portrait.total_summaries_analyzed
|
||||
}
|
||||
|
||||
# Combine with existing trends
|
||||
existing_trends = existing_portrait.historical_trends or []
|
||||
|
||||
# Keep only recent trends (last 10 entries)
|
||||
all_trends = existing_trends + [trend_entry]
|
||||
return all_trends[-10:]
|
||||
@@ -1,459 +0,0 @@
|
||||
"""Habit Analyzer for Implicit Memory System
|
||||
|
||||
This module implements LLM-based behavioral habit analysis from user memory summaries.
|
||||
It identifies recurring behavioral patterns, temporal patterns, and consolidates
|
||||
similar habits with confidence scoring.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
BehaviorHabit,
|
||||
FrequencyPattern,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HabitData(BaseModel):
|
||||
"""Individual habit analysis data."""
|
||||
habit_description: str
|
||||
frequency_pattern: str
|
||||
time_context: str
|
||||
confidence_level: int = 50 # Default to medium confidence
|
||||
supporting_summaries: List[str] = Field(default_factory=list)
|
||||
specific_examples: List[str] = Field(default_factory=list)
|
||||
is_current: bool = True
|
||||
|
||||
|
||||
class HabitAnalysisResponse(BaseModel):
|
||||
"""Response model for habit analysis."""
|
||||
habits: List[HabitData] = Field(default_factory=list)
|
||||
|
||||
|
||||
class HabitAnalyzer:
|
||||
"""Analyzes user memory summaries to extract behavioral habits."""
|
||||
|
||||
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
|
||||
"""Initialize the habit analyzer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
llm_model_id: Optional LLM model ID to use for analysis
|
||||
"""
|
||||
self.db = db
|
||||
self.llm_model_id = llm_model_id
|
||||
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
|
||||
|
||||
async def analyze_habits(
|
||||
self,
|
||||
user_id: str,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
existing_habits: Optional[List[BehaviorHabit]] = None
|
||||
) -> List[BehaviorHabit]:
|
||||
"""Analyze user summaries to extract behavioral habits.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_summaries: List of user-specific memory summaries
|
||||
existing_habits: Optional existing habits for consolidation
|
||||
|
||||
Returns:
|
||||
List of extracted behavioral habits
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for user {user_id}")
|
||||
return existing_habits or []
|
||||
|
||||
try:
|
||||
logger.info(f"Analyzing habits for user {user_id} with {len(user_summaries)} summaries")
|
||||
|
||||
# Use the LLM client wrapper for analysis
|
||||
response = await self._llm_client.analyze_habits(
|
||||
user_summaries=user_summaries,
|
||||
user_id=user_id,
|
||||
model_id=self.llm_model_id
|
||||
)
|
||||
|
||||
# Convert to BehaviorHabit objects
|
||||
behavior_habits = []
|
||||
|
||||
for habit_data in response.get("habits", []):
|
||||
try:
|
||||
# Handle habit_data as dictionary
|
||||
supporting_summaries = habit_data.get("supporting_summaries", [])
|
||||
specific_examples = habit_data.get("specific_examples", [])
|
||||
|
||||
# Determine observation dates from summaries
|
||||
first_observed, last_observed = self._determine_observation_dates(
|
||||
user_summaries, supporting_summaries
|
||||
)
|
||||
|
||||
behavior_habit = BehaviorHabit(
|
||||
habit_description=habit_data.get("habit_description", ""),
|
||||
frequency_pattern=self._validate_frequency_pattern(habit_data.get("frequency_pattern", "occasional")),
|
||||
time_context=habit_data.get("time_context", ""),
|
||||
confidence_level=self._validate_confidence_level(habit_data.get("confidence_level", 50)),
|
||||
specific_examples=specific_examples,
|
||||
first_observed=first_observed,
|
||||
last_observed=last_observed,
|
||||
is_current=habit_data.get("is_current", True)
|
||||
)
|
||||
|
||||
# Validate habit
|
||||
if self._is_valid_habit(behavior_habit):
|
||||
behavior_habits.append(behavior_habit)
|
||||
else:
|
||||
logger.warning(f"Invalid habit skipped: {behavior_habit.habit_description}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating behavior habit: {e}")
|
||||
continue
|
||||
|
||||
# Consolidate with existing habits if provided
|
||||
if existing_habits:
|
||||
behavior_habits = self._consolidate_habits(
|
||||
new_habits=behavior_habits,
|
||||
existing_habits=existing_habits
|
||||
)
|
||||
|
||||
# Sort habits by confidence and recency
|
||||
behavior_habits = self._sort_habits_by_priority(behavior_habits)
|
||||
|
||||
logger.info(f"Extracted {len(behavior_habits)} habits for user {user_id}")
|
||||
return behavior_habits
|
||||
|
||||
except LLMClientException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Habit analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Habit analysis failed: {e}") from e
|
||||
|
||||
def _validate_frequency_pattern(self, frequency_str: str) -> FrequencyPattern:
|
||||
"""Validate and convert frequency pattern string.
|
||||
|
||||
Args:
|
||||
frequency_str: Frequency pattern as string
|
||||
|
||||
Returns:
|
||||
FrequencyPattern enum value
|
||||
"""
|
||||
frequency_str = frequency_str.lower().strip()
|
||||
|
||||
frequency_mapping = {
|
||||
"daily": FrequencyPattern.DAILY,
|
||||
"weekly": FrequencyPattern.WEEKLY,
|
||||
"monthly": FrequencyPattern.MONTHLY,
|
||||
"seasonal": FrequencyPattern.SEASONAL,
|
||||
"occasional": FrequencyPattern.OCCASIONAL,
|
||||
"event_triggered": FrequencyPattern.EVENT_TRIGGERED,
|
||||
"event-triggered": FrequencyPattern.EVENT_TRIGGERED,
|
||||
}
|
||||
|
||||
return frequency_mapping.get(frequency_str, FrequencyPattern.OCCASIONAL)
|
||||
|
||||
def _validate_confidence_level(self, confidence_level) -> int:
|
||||
"""Return confidence level as integer, handling both string and numeric inputs.
|
||||
|
||||
Args:
|
||||
confidence_level: Confidence level (string or numeric)
|
||||
|
||||
Returns:
|
||||
Confidence level as integer (0-100)
|
||||
"""
|
||||
# If it's already a number, return it as int
|
||||
if isinstance(confidence_level, (int, float)):
|
||||
return int(confidence_level)
|
||||
|
||||
# If it's a string, convert common values to numbers
|
||||
if isinstance(confidence_level, str):
|
||||
confidence_str = confidence_level.lower().strip()
|
||||
if confidence_str in ["high", "높음"]:
|
||||
return 85
|
||||
elif confidence_str in ["medium", "중간"]:
|
||||
return 50
|
||||
elif confidence_str in ["low", "낮음"]:
|
||||
return 20
|
||||
else:
|
||||
# Try to parse as number
|
||||
try:
|
||||
return int(float(confidence_str))
|
||||
except ValueError:
|
||||
logger.warning(f"Unknown confidence level: {confidence_level}, defaulting to medium")
|
||||
return 50
|
||||
|
||||
# Default fallback
|
||||
return 50
|
||||
|
||||
def _determine_observation_dates(
|
||||
self,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
supporting_summary_ids: List[str]
|
||||
) -> tuple[datetime, datetime]:
|
||||
"""Determine first and last observation dates for a habit.
|
||||
|
||||
Args:
|
||||
user_summaries: List of user summaries
|
||||
supporting_summary_ids: IDs of summaries supporting the habit
|
||||
|
||||
Returns:
|
||||
Tuple of (first_observed, last_observed) dates
|
||||
"""
|
||||
from datetime import timezone
|
||||
|
||||
# Find summaries that support this habit
|
||||
supporting_summaries = [
|
||||
summary for summary in user_summaries
|
||||
if summary.summary_id in supporting_summary_ids
|
||||
]
|
||||
|
||||
if not supporting_summaries:
|
||||
# Use all summaries if no specific supporting summaries found
|
||||
supporting_summaries = user_summaries
|
||||
|
||||
if not supporting_summaries:
|
||||
current_time = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
return current_time, current_time
|
||||
|
||||
# Get date range from supporting summaries - normalize to naive datetimes
|
||||
timestamps = []
|
||||
for summary in supporting_summaries:
|
||||
ts = summary.timestamp
|
||||
# Convert to naive datetime if it's timezone-aware
|
||||
if ts.tzinfo is not None:
|
||||
ts = ts.replace(tzinfo=None)
|
||||
timestamps.append(ts)
|
||||
|
||||
first_observed = min(timestamps)
|
||||
last_observed = max(timestamps)
|
||||
|
||||
return first_observed, last_observed
|
||||
|
||||
def _is_valid_habit(self, habit: BehaviorHabit) -> bool:
|
||||
"""Validate a behavioral habit.
|
||||
|
||||
Args:
|
||||
habit: Behavioral habit to validate
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Check required fields
|
||||
if not habit.habit_description or not habit.habit_description.strip():
|
||||
return False
|
||||
|
||||
# Check time context
|
||||
if not habit.time_context or not habit.time_context.strip():
|
||||
return False
|
||||
|
||||
# Check supporting summaries
|
||||
if not habit.specific_examples or len(habit.specific_examples) == 0:
|
||||
return False
|
||||
|
||||
# Check specific examples
|
||||
if not habit.specific_examples or len(habit.specific_examples) == 0:
|
||||
return False
|
||||
|
||||
# Check observation dates
|
||||
if habit.first_observed > habit.last_observed:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating habit: {e}")
|
||||
return False
|
||||
|
||||
def _consolidate_habits(
|
||||
self,
|
||||
new_habits: List[BehaviorHabit],
|
||||
existing_habits: List[BehaviorHabit],
|
||||
similarity_threshold: float = 0.7
|
||||
) -> List[BehaviorHabit]:
|
||||
"""Consolidate new habits with existing ones.
|
||||
|
||||
Args:
|
||||
new_habits: Newly extracted habits
|
||||
existing_habits: Existing habits
|
||||
similarity_threshold: Threshold for considering habits similar
|
||||
|
||||
Returns:
|
||||
Consolidated list of habits
|
||||
"""
|
||||
consolidated = existing_habits.copy()
|
||||
current_time = datetime.now()
|
||||
|
||||
for new_habit in new_habits:
|
||||
# Find similar existing habit
|
||||
similar_habit = self._find_similar_habit(
|
||||
new_habit, existing_habits, similarity_threshold
|
||||
)
|
||||
|
||||
if similar_habit:
|
||||
# Update existing habit
|
||||
updated_habit = self._merge_habits(similar_habit, new_habit, current_time)
|
||||
# Replace in consolidated list
|
||||
for i, habit in enumerate(consolidated):
|
||||
if habit.habit_description == similar_habit.habit_description:
|
||||
consolidated[i] = updated_habit
|
||||
break
|
||||
else:
|
||||
# Add as new habit
|
||||
consolidated.append(new_habit)
|
||||
|
||||
return consolidated
|
||||
|
||||
def _find_similar_habit(
|
||||
self,
|
||||
target_habit: BehaviorHabit,
|
||||
existing_habits: List[BehaviorHabit],
|
||||
threshold: float
|
||||
) -> Optional[BehaviorHabit]:
|
||||
"""Find similar habit in existing list.
|
||||
|
||||
Args:
|
||||
target_habit: Habit to find similarity for
|
||||
existing_habits: List of existing habits
|
||||
threshold: Similarity threshold
|
||||
|
||||
Returns:
|
||||
Similar habit if found, None otherwise
|
||||
"""
|
||||
target_desc = target_habit.habit_description.lower().strip()
|
||||
|
||||
for existing_habit in existing_habits:
|
||||
existing_desc = existing_habit.habit_description.lower().strip()
|
||||
|
||||
# Check description similarity
|
||||
desc_similarity = self._calculate_text_similarity(target_desc, existing_desc)
|
||||
|
||||
# Check frequency pattern match
|
||||
frequency_match = (target_habit.frequency_pattern == existing_habit.frequency_pattern)
|
||||
|
||||
# Check time context similarity
|
||||
time_similarity = self._calculate_text_similarity(
|
||||
target_habit.time_context.lower(),
|
||||
existing_habit.time_context.lower()
|
||||
)
|
||||
|
||||
# Combined similarity score
|
||||
combined_similarity = (desc_similarity * 0.6 + time_similarity * 0.4)
|
||||
if frequency_match:
|
||||
combined_similarity += 0.1 # Bonus for frequency match
|
||||
|
||||
if combined_similarity >= threshold:
|
||||
return existing_habit
|
||||
|
||||
return None
|
||||
|
||||
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
|
||||
"""Calculate simple text similarity based on common words.
|
||||
|
||||
Args:
|
||||
text1: First text
|
||||
text2: Second text
|
||||
|
||||
Returns:
|
||||
Similarity score between 0.0 and 1.0
|
||||
"""
|
||||
if not text1 or not text2:
|
||||
return 0.0
|
||||
|
||||
# Simple word-based similarity
|
||||
words1 = set(text1.lower().split())
|
||||
words2 = set(text2.lower().split())
|
||||
|
||||
if not words1 or not words2:
|
||||
return 0.0
|
||||
|
||||
intersection = words1.intersection(words2)
|
||||
union = words1.union(words2)
|
||||
|
||||
return len(intersection) / len(union) if union else 0.0
|
||||
|
||||
def _merge_habits(
|
||||
self,
|
||||
existing_habit: BehaviorHabit,
|
||||
new_habit: BehaviorHabit,
|
||||
current_time: datetime
|
||||
) -> BehaviorHabit:
|
||||
"""Merge two similar habits.
|
||||
|
||||
Args:
|
||||
existing_habit: Existing habit
|
||||
new_habit: New habit to merge
|
||||
current_time: Current timestamp
|
||||
|
||||
Returns:
|
||||
Merged behavioral habit
|
||||
"""
|
||||
# Combine supporting summaries (using specific_examples instead)
|
||||
combined_examples = list(set(
|
||||
existing_habit.specific_examples + new_habit.specific_examples
|
||||
))
|
||||
|
||||
# Combine specific examples
|
||||
combined_examples = list(set(
|
||||
existing_habit.specific_examples + new_habit.specific_examples
|
||||
))
|
||||
|
||||
# Update confidence level (take higher confidence)
|
||||
new_confidence = max(existing_habit.confidence_level, new_habit.confidence_level)
|
||||
|
||||
# Update observation dates
|
||||
first_observed = min(existing_habit.first_observed, new_habit.first_observed)
|
||||
last_observed = max(existing_habit.last_observed, new_habit.last_observed)
|
||||
|
||||
# Determine if habit is current (observed within last 30 days)
|
||||
is_current = (current_time - last_observed).days <= 30
|
||||
|
||||
# Combine time context
|
||||
combined_time_context = existing_habit.time_context
|
||||
if new_habit.time_context and new_habit.time_context not in combined_time_context:
|
||||
combined_time_context += f"; {new_habit.time_context}"
|
||||
|
||||
return BehaviorHabit(
|
||||
habit_description=existing_habit.habit_description, # Keep original description
|
||||
frequency_pattern=existing_habit.frequency_pattern, # Keep original frequency
|
||||
time_context=combined_time_context,
|
||||
confidence_level=new_confidence,
|
||||
specific_examples=combined_examples,
|
||||
first_observed=first_observed,
|
||||
last_observed=last_observed,
|
||||
is_current=is_current
|
||||
)
|
||||
|
||||
def _sort_habits_by_priority(self, habits: List[BehaviorHabit]) -> List[BehaviorHabit]:
|
||||
"""Sort habits by confidence level and recency.
|
||||
|
||||
Args:
|
||||
habits: List of habits to sort
|
||||
|
||||
Returns:
|
||||
Sorted list of habits
|
||||
"""
|
||||
def priority_score(habit: BehaviorHabit) -> tuple:
|
||||
# Confidence level score (0-100 scale)
|
||||
confidence_score = habit.confidence_level
|
||||
|
||||
# Recency score (more recent = higher score)
|
||||
days_since_last = (datetime.now() - habit.last_observed).days
|
||||
recency_score = max(0, 365 - days_since_last) # Max 365 days
|
||||
|
||||
# Current habit bonus
|
||||
current_bonus = 100 if habit.is_current else 0
|
||||
|
||||
return (confidence_score, recency_score + current_bonus, habit.last_observed)
|
||||
|
||||
return sorted(habits, key=priority_score, reverse=True)
|
||||
@@ -1,277 +0,0 @@
|
||||
"""Interest Analyzer for Implicit Memory System
|
||||
|
||||
This module implements LLM-based interest area analysis from user memory summaries.
|
||||
It categorizes user interests into four areas: tech, lifestyle, music, and art,
|
||||
providing percentage distribution that totals 100%.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
InterestAreaDistribution,
|
||||
InterestCategory,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InterestData(BaseModel):
|
||||
"""Individual interest category analysis data."""
|
||||
percentage: float = Field(ge=0.0, le=100.0)
|
||||
evidence: List[str] = Field(default_factory=list)
|
||||
trending_direction: Optional[str] = None
|
||||
|
||||
|
||||
class InterestAnalysisResponse(BaseModel):
|
||||
"""Response model for interest analysis."""
|
||||
interest_distribution: Dict[str, InterestData] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class InterestAnalyzer:
|
||||
"""Analyzes user memory summaries to extract interest area distribution."""
|
||||
|
||||
# Define the four interest categories we analyze
|
||||
INTEREST_CATEGORIES = ["tech", "lifestyle", "music", "art"]
|
||||
|
||||
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
|
||||
"""Initialize the interest analyzer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
llm_model_id: Optional LLM model ID to use for analysis
|
||||
"""
|
||||
self.db = db
|
||||
self.llm_model_id = llm_model_id
|
||||
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
|
||||
|
||||
async def analyze_interests(
|
||||
self,
|
||||
user_id: str,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
existing_distribution: Optional[InterestAreaDistribution] = None
|
||||
) -> InterestAreaDistribution:
|
||||
"""Analyze user summaries to extract interest area distribution.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_summaries: List of user-specific memory summaries
|
||||
existing_distribution: Optional existing distribution for trend tracking
|
||||
|
||||
Returns:
|
||||
Interest area distribution across four categories
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for user {user_id}")
|
||||
return self._create_empty_distribution(user_id)
|
||||
|
||||
try:
|
||||
logger.info(f"Analyzing interests for user {user_id} with {len(user_summaries)} summaries")
|
||||
|
||||
# Use the LLM client wrapper for analysis
|
||||
response = await self._llm_client.analyze_interests(
|
||||
user_summaries=user_summaries,
|
||||
user_id=user_id,
|
||||
model_id=self.llm_model_id
|
||||
)
|
||||
|
||||
# Create interest categories
|
||||
interest_categories = {}
|
||||
current_time = datetime.now()
|
||||
|
||||
# Extract interest_distribution from response dict
|
||||
interest_distribution = response.get("interest_distribution", {})
|
||||
|
||||
# Extract and validate interest data
|
||||
raw_interests = {}
|
||||
for category_name in self.INTEREST_CATEGORIES:
|
||||
interest_data_dict = interest_distribution.get(category_name)
|
||||
if interest_data_dict:
|
||||
raw_interests[category_name] = InterestData(
|
||||
percentage=interest_data_dict.get("percentage", 0.0),
|
||||
evidence=interest_data_dict.get("evidence", []),
|
||||
trending_direction=interest_data_dict.get("trending_direction")
|
||||
)
|
||||
else:
|
||||
# Create default if missing
|
||||
logger.warning(f"Missing interest data for {category_name}, using default")
|
||||
raw_interests[category_name] = InterestData(
|
||||
percentage=0.0,
|
||||
evidence=["No specific evidence found"],
|
||||
trending_direction=None
|
||||
)
|
||||
|
||||
# Normalize percentages to ensure they sum to 100%
|
||||
normalized_interests = self._normalize_percentages(raw_interests)
|
||||
|
||||
# Create interest category objects
|
||||
for category_name in self.INTEREST_CATEGORIES:
|
||||
interest_data = normalized_interests[category_name]
|
||||
|
||||
# Calculate trending direction if we have existing data
|
||||
trending_direction = self._calculate_trending_direction(
|
||||
category_name=category_name,
|
||||
current_percentage=interest_data.percentage,
|
||||
existing_distribution=existing_distribution
|
||||
) if existing_distribution else interest_data.trending_direction
|
||||
|
||||
interest_categories[category_name] = InterestCategory(
|
||||
category_name=category_name,
|
||||
percentage=interest_data.percentage,
|
||||
evidence=interest_data.evidence if interest_data.evidence else ["No specific evidence found"],
|
||||
trending_direction=trending_direction
|
||||
)
|
||||
|
||||
# Create interest area distribution
|
||||
distribution = InterestAreaDistribution(
|
||||
user_id=user_id,
|
||||
tech=interest_categories["tech"],
|
||||
lifestyle=interest_categories["lifestyle"],
|
||||
music=interest_categories["music"],
|
||||
art=interest_categories["art"],
|
||||
analysis_timestamp=current_time,
|
||||
total_summaries_analyzed=len(user_summaries)
|
||||
)
|
||||
|
||||
# Validate that percentages sum to 100%
|
||||
total_percentage = distribution.total_percentage
|
||||
if not (99.9 <= total_percentage <= 100.1):
|
||||
logger.warning(f"Interest percentages sum to {total_percentage}, expected ~100%")
|
||||
|
||||
logger.info(f"Created interest distribution for user {user_id}")
|
||||
return distribution
|
||||
|
||||
except LLMClientException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Interest analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Interest analysis failed: {e}") from e
|
||||
|
||||
def _normalize_percentages(self, raw_interests: Dict[str, InterestData]) -> Dict[str, InterestData]:
|
||||
"""Normalize percentages to ensure they sum to 100%.
|
||||
|
||||
Args:
|
||||
raw_interests: Raw interest data with potentially unnormalized percentages
|
||||
|
||||
Returns:
|
||||
Normalized interest data
|
||||
"""
|
||||
# Calculate current total
|
||||
total = sum(interest.percentage for interest in raw_interests.values())
|
||||
|
||||
if total == 0:
|
||||
# If all percentages are 0, distribute equally
|
||||
equal_percentage = 100.0 / len(self.INTEREST_CATEGORIES)
|
||||
normalized = {}
|
||||
for category_name, interest_data in raw_interests.items():
|
||||
normalized[category_name] = InterestData(
|
||||
percentage=equal_percentage,
|
||||
evidence=interest_data.evidence,
|
||||
trending_direction=interest_data.trending_direction
|
||||
)
|
||||
return normalized
|
||||
|
||||
# Normalize to sum to 100%
|
||||
normalization_factor = 100.0 / total
|
||||
normalized = {}
|
||||
|
||||
for category_name, interest_data in raw_interests.items():
|
||||
normalized_percentage = interest_data.percentage * normalization_factor
|
||||
|
||||
normalized[category_name] = InterestData(
|
||||
percentage=round(normalized_percentage, 1),
|
||||
evidence=interest_data.evidence,
|
||||
trending_direction=interest_data.trending_direction
|
||||
)
|
||||
|
||||
# Handle rounding errors by adjusting the largest category
|
||||
current_total = sum(interest.percentage for interest in normalized.values())
|
||||
if abs(current_total - 100.0) > 0.1:
|
||||
# Find category with largest percentage and adjust
|
||||
largest_category = max(normalized.keys(), key=lambda k: normalized[k].percentage)
|
||||
adjustment = 100.0 - current_total
|
||||
|
||||
adjusted_percentage = normalized[largest_category].percentage + adjustment
|
||||
normalized[largest_category] = InterestData(
|
||||
percentage=round(max(0.0, adjusted_percentage), 1),
|
||||
evidence=normalized[largest_category].evidence,
|
||||
trending_direction=normalized[largest_category].trending_direction
|
||||
)
|
||||
|
||||
return normalized
|
||||
|
||||
def _calculate_trending_direction(
|
||||
self,
|
||||
category_name: str,
|
||||
current_percentage: float,
|
||||
existing_distribution: InterestAreaDistribution,
|
||||
threshold: float = 5.0
|
||||
) -> Optional[str]:
|
||||
"""Calculate trending direction for an interest category.
|
||||
|
||||
Args:
|
||||
category_name: Name of the interest category
|
||||
current_percentage: Current percentage for the category
|
||||
existing_distribution: Previous distribution for comparison
|
||||
threshold: Minimum percentage change to consider a trend
|
||||
|
||||
Returns:
|
||||
Trending direction: "increasing", "decreasing", "stable", or None
|
||||
"""
|
||||
try:
|
||||
# Get previous percentage
|
||||
previous_category = getattr(existing_distribution, category_name, None)
|
||||
if not previous_category:
|
||||
return None
|
||||
|
||||
previous_percentage = previous_category.percentage
|
||||
change = current_percentage - previous_percentage
|
||||
|
||||
if abs(change) < threshold:
|
||||
return "stable"
|
||||
elif change > 0:
|
||||
return "increasing"
|
||||
else:
|
||||
return "decreasing"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating trending direction for {category_name}: {e}")
|
||||
return None
|
||||
|
||||
def _create_empty_distribution(self, user_id: str) -> InterestAreaDistribution:
|
||||
"""Create an empty interest distribution when no data is available.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
|
||||
Returns:
|
||||
Empty InterestAreaDistribution with equal percentages
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
equal_percentage = 25.0 # 100% / 4 categories
|
||||
|
||||
default_category = lambda name: InterestCategory(
|
||||
category_name=name,
|
||||
percentage=equal_percentage,
|
||||
evidence=["Insufficient data for analysis"],
|
||||
trending_direction=None
|
||||
)
|
||||
|
||||
return InterestAreaDistribution(
|
||||
user_id=user_id,
|
||||
tech=default_category("tech"),
|
||||
lifestyle=default_category("lifestyle"),
|
||||
music=default_category("music"),
|
||||
art=default_category("art"),
|
||||
analysis_timestamp=current_time,
|
||||
total_summaries_analyzed=0
|
||||
)
|
||||
@@ -1,302 +0,0 @@
|
||||
"""Preference Analyzer for Implicit Memory System
|
||||
|
||||
This module implements LLM-based preference extraction from user memory summaries.
|
||||
It identifies implicit preferences, consolidates similar preferences, and calculates
|
||||
confidence scores based on evidence strength.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
PreferenceTag,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreferenceAnalysisResponse(BaseModel):
|
||||
"""Response model for preference analysis."""
|
||||
preferences: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PreferenceAnalyzer:
|
||||
"""Analyzes user memory summaries to extract implicit preferences."""
|
||||
|
||||
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
|
||||
"""Initialize the preference analyzer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
llm_model_id: Optional LLM model ID to use for analysis
|
||||
"""
|
||||
self.db = db
|
||||
self.llm_model_id = llm_model_id
|
||||
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
|
||||
|
||||
async def analyze_preferences(
|
||||
self,
|
||||
user_id: str,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
existing_preferences: Optional[List[PreferenceTag]] = None
|
||||
) -> List[PreferenceTag]:
|
||||
"""Analyze user summaries to extract preferences.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_summaries: List of user-specific memory summaries
|
||||
existing_preferences: Optional existing preferences for consolidation
|
||||
|
||||
Returns:
|
||||
List of extracted preference tags
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for user {user_id}")
|
||||
return []
|
||||
|
||||
try:
|
||||
logger.info(f"Analyzing preferences for user {user_id} with {len(user_summaries)} summaries")
|
||||
|
||||
# Use the LLM client wrapper for analysis
|
||||
response = await self._llm_client.analyze_preferences(
|
||||
user_summaries=user_summaries,
|
||||
user_id=user_id,
|
||||
model_id=self.llm_model_id
|
||||
)
|
||||
|
||||
# Convert to PreferenceTag objects
|
||||
preference_tags = []
|
||||
current_time = datetime.now()
|
||||
|
||||
for pref_data in response.get("preferences", []):
|
||||
try:
|
||||
# Extract conversation references from summaries
|
||||
conversation_refs = [s.summary_id for s in user_summaries]
|
||||
|
||||
preference_tag = PreferenceTag(
|
||||
tag_name=pref_data.get("tag_name", ""),
|
||||
confidence_score=float(pref_data.get("confidence_score", 0.0)),
|
||||
supporting_evidence=pref_data.get("supporting_evidence", []),
|
||||
context_details=pref_data.get("context_details", ""),
|
||||
category=pref_data.get("category"),
|
||||
conversation_references=conversation_refs,
|
||||
created_at=current_time,
|
||||
updated_at=current_time
|
||||
)
|
||||
|
||||
# Validate preference tag
|
||||
if self._is_valid_preference(preference_tag):
|
||||
preference_tags.append(preference_tag)
|
||||
else:
|
||||
logger.warning(f"Invalid preference tag skipped: {preference_tag.tag_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating preference tag: {e}")
|
||||
continue
|
||||
|
||||
# Consolidate with existing preferences if provided
|
||||
if existing_preferences:
|
||||
preference_tags = self._consolidate_preferences(
|
||||
new_preferences=preference_tags,
|
||||
existing_preferences=existing_preferences
|
||||
)
|
||||
|
||||
logger.info(f"Extracted {len(preference_tags)} preferences for user {user_id}")
|
||||
return preference_tags
|
||||
|
||||
except LLMClientException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Preference analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Preference analysis failed: {e}") from e
|
||||
|
||||
def _is_valid_preference(self, preference: PreferenceTag) -> bool:
|
||||
"""Validate a preference tag.
|
||||
|
||||
Args:
|
||||
preference: Preference tag to validate
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Check required fields
|
||||
if not preference.tag_name or not preference.tag_name.strip():
|
||||
return False
|
||||
|
||||
# Check confidence score range
|
||||
if not (0.0 <= preference.confidence_score <= 1.0):
|
||||
return False
|
||||
|
||||
# Check supporting evidence
|
||||
if not preference.supporting_evidence or len(preference.supporting_evidence) == 0:
|
||||
return False
|
||||
|
||||
# Check context details
|
||||
if not preference.context_details or not preference.context_details.strip():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating preference: {e}")
|
||||
return False
|
||||
|
||||
def _consolidate_preferences(
|
||||
self,
|
||||
new_preferences: List[PreferenceTag],
|
||||
existing_preferences: List[PreferenceTag],
|
||||
similarity_threshold: float = 0.8
|
||||
) -> List[PreferenceTag]:
|
||||
"""Consolidate new preferences with existing ones.
|
||||
|
||||
Args:
|
||||
new_preferences: Newly extracted preferences
|
||||
existing_preferences: Existing preferences
|
||||
similarity_threshold: Threshold for considering preferences similar
|
||||
|
||||
Returns:
|
||||
Consolidated list of preferences
|
||||
"""
|
||||
consolidated = existing_preferences.copy()
|
||||
current_time = datetime.now()
|
||||
|
||||
for new_pref in new_preferences:
|
||||
# Find similar existing preference
|
||||
similar_pref = self._find_similar_preference(
|
||||
new_pref, existing_preferences, similarity_threshold
|
||||
)
|
||||
|
||||
if similar_pref:
|
||||
# Update existing preference
|
||||
updated_pref = self._merge_preferences(similar_pref, new_pref, current_time)
|
||||
# Replace in consolidated list
|
||||
for i, pref in enumerate(consolidated):
|
||||
if pref.tag_name == similar_pref.tag_name:
|
||||
consolidated[i] = updated_pref
|
||||
break
|
||||
else:
|
||||
# Add as new preference
|
||||
consolidated.append(new_pref)
|
||||
|
||||
return consolidated
|
||||
|
||||
def _find_similar_preference(
|
||||
self,
|
||||
target_preference: PreferenceTag,
|
||||
existing_preferences: List[PreferenceTag],
|
||||
threshold: float
|
||||
) -> Optional[PreferenceTag]:
|
||||
"""Find similar preference in existing list.
|
||||
|
||||
Args:
|
||||
target_preference: Preference to find similarity for
|
||||
existing_preferences: List of existing preferences
|
||||
threshold: Similarity threshold
|
||||
|
||||
Returns:
|
||||
Similar preference if found, None otherwise
|
||||
"""
|
||||
target_name = target_preference.tag_name.lower().strip()
|
||||
|
||||
for existing_pref in existing_preferences:
|
||||
existing_name = existing_pref.tag_name.lower().strip()
|
||||
|
||||
# Simple similarity check based on common words
|
||||
similarity = self._calculate_text_similarity(target_name, existing_name)
|
||||
|
||||
if similarity >= threshold:
|
||||
return existing_pref
|
||||
|
||||
return None
|
||||
|
||||
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
|
||||
"""Calculate simple text similarity based on common words.
|
||||
|
||||
Args:
|
||||
text1: First text
|
||||
text2: Second text
|
||||
|
||||
Returns:
|
||||
Similarity score between 0.0 and 1.0
|
||||
"""
|
||||
if not text1 or not text2:
|
||||
return 0.0
|
||||
|
||||
# Simple word-based similarity
|
||||
words1 = set(text1.lower().split())
|
||||
words2 = set(text2.lower().split())
|
||||
|
||||
if not words1 or not words2:
|
||||
return 0.0
|
||||
|
||||
intersection = words1.intersection(words2)
|
||||
union = words1.union(words2)
|
||||
|
||||
return len(intersection) / len(union) if union else 0.0
|
||||
|
||||
def _merge_preferences(
|
||||
self,
|
||||
existing_pref: PreferenceTag,
|
||||
new_pref: PreferenceTag,
|
||||
current_time: datetime
|
||||
) -> PreferenceTag:
|
||||
"""Merge two similar preferences.
|
||||
|
||||
Args:
|
||||
existing_pref: Existing preference
|
||||
new_pref: New preference to merge
|
||||
current_time: Current timestamp
|
||||
|
||||
Returns:
|
||||
Merged preference tag
|
||||
"""
|
||||
# Combine supporting evidence
|
||||
combined_evidence = list(set(
|
||||
existing_pref.supporting_evidence + new_pref.supporting_evidence
|
||||
))
|
||||
|
||||
# Combine conversation references
|
||||
combined_refs = list(set(
|
||||
existing_pref.conversation_references + new_pref.conversation_references
|
||||
))
|
||||
|
||||
# Calculate new confidence score (weighted average)
|
||||
evidence_weight = len(new_pref.supporting_evidence)
|
||||
total_weight = len(existing_pref.supporting_evidence) + evidence_weight
|
||||
|
||||
if total_weight > 0:
|
||||
new_confidence = (
|
||||
(existing_pref.confidence_score * len(existing_pref.supporting_evidence) +
|
||||
new_pref.confidence_score * evidence_weight) / total_weight
|
||||
)
|
||||
else:
|
||||
new_confidence = max(existing_pref.confidence_score, new_pref.confidence_score)
|
||||
|
||||
# Ensure confidence doesn't exceed 1.0
|
||||
new_confidence = min(new_confidence, 1.0)
|
||||
|
||||
# Combine context details
|
||||
combined_context = existing_pref.context_details
|
||||
if new_pref.context_details and new_pref.context_details not in combined_context:
|
||||
combined_context += f"; {new_pref.context_details}"
|
||||
|
||||
return PreferenceTag(
|
||||
tag_name=existing_pref.tag_name, # Keep original name
|
||||
confidence_score=new_confidence,
|
||||
supporting_evidence=combined_evidence,
|
||||
context_details=combined_context,
|
||||
category=existing_pref.category or new_pref.category,
|
||||
conversation_references=combined_refs,
|
||||
created_at=existing_pref.created_at,
|
||||
updated_at=current_time
|
||||
)
|
||||
@@ -1,97 +0,0 @@
|
||||
"""
|
||||
Memory Data Source
|
||||
|
||||
Handles retrieval and processing of memory data from Neo4j using direct Cypher queries.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.memory_summary_repository import MemorySummaryRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.implicit_memory_schema import TimeRange, UserMemorySummary
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryDataSource:
|
||||
"""Retrieves processed memory data from Neo4j using direct Cypher queries."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
neo4j_connector: Optional[Neo4jConnector] = None
|
||||
):
|
||||
self.db = db
|
||||
self.neo4j_connector = neo4j_connector or Neo4jConnector()
|
||||
self.memory_summary_repo = MemorySummaryRepository(self.neo4j_connector)
|
||||
|
||||
def _parse_timestamp(self, timestamp: Any) -> datetime:
|
||||
"""Parse timestamp from various formats."""
|
||||
if isinstance(timestamp, str):
|
||||
return datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
|
||||
elif timestamp is None:
|
||||
return datetime.now()
|
||||
return timestamp
|
||||
|
||||
def _dict_to_user_summary(self, summary_dict: Dict, user_id: str) -> Optional[UserMemorySummary]:
|
||||
"""Convert a Neo4j dict directly to UserMemorySummary."""
|
||||
try:
|
||||
content = summary_dict.get("content", summary_dict.get("summary", ""))
|
||||
if not content or not content.strip():
|
||||
return None
|
||||
|
||||
return UserMemorySummary(
|
||||
summary_id=summary_dict.get("id", summary_dict.get("uuid", "")),
|
||||
user_id=user_id,
|
||||
user_content=content,
|
||||
timestamp=self._parse_timestamp(summary_dict.get("created_at")),
|
||||
confidence_score=1.0,
|
||||
summary_type="memory_summary"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse summary {summary_dict.get('id', 'unknown')}: {e}")
|
||||
return None
|
||||
|
||||
async def get_user_summaries(
|
||||
self,
|
||||
user_id: str,
|
||||
time_range: Optional[TimeRange] = None,
|
||||
limit: int = 1000
|
||||
) -> List[UserMemorySummary]:
|
||||
"""Retrieve user memory summaries from Neo4j.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
time_range: Optional time range filter
|
||||
limit: Maximum number of summaries
|
||||
|
||||
Returns:
|
||||
List of user memory summaries
|
||||
"""
|
||||
try:
|
||||
start_date = time_range.start_date if time_range else None
|
||||
end_date = time_range.end_date if time_range else None
|
||||
|
||||
summary_dicts = await self.memory_summary_repo.find_by_group_id(
|
||||
group_id=user_id,
|
||||
limit=limit,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
summaries = []
|
||||
for summary_dict in summary_dicts:
|
||||
summary = self._dict_to_user_summary(summary_dict, user_id)
|
||||
if summary:
|
||||
summaries.append(summary)
|
||||
|
||||
logger.info(f"Retrieved {len(summaries)} summaries for user {user_id}")
|
||||
return summaries
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve summaries for user {user_id}: {e}")
|
||||
raise
|
||||
|
||||
@@ -1,226 +0,0 @@
|
||||
"""Habit Detector for Implicit Memory System
|
||||
|
||||
This module implements the HabitDetector class that specializes in identifying
|
||||
and ranking behavioral habits from user memory summaries. It provides advanced
|
||||
habit analysis with confidence scoring, recency weighting, and current vs past
|
||||
habit distinction.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.analyzers.habit_analyzer import (
|
||||
HabitAnalyzer,
|
||||
)
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
BehaviorHabit,
|
||||
FrequencyPattern,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HabitDetector:
|
||||
"""Detects and ranks behavioral habits from user memory summaries."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
llm_model_id: Optional[str] = None
|
||||
):
|
||||
"""Initialize the habit detector.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
llm_model_id: Optional LLM model ID to use for analysis
|
||||
"""
|
||||
self.db = db
|
||||
self.llm_model_id = llm_model_id
|
||||
self.habit_analyzer = HabitAnalyzer(db, llm_model_id)
|
||||
|
||||
async def detect_habits(
|
||||
self,
|
||||
user_id: str,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
existing_habits: Optional[List[BehaviorHabit]] = None
|
||||
) -> List[BehaviorHabit]:
|
||||
"""Detect behavioral habits from user summaries.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_summaries: List of user-specific memory summaries
|
||||
existing_habits: Optional existing habits for consolidation
|
||||
|
||||
Returns:
|
||||
List of detected and ranked behavioral habits
|
||||
|
||||
Raises:
|
||||
LLMClientException: If habit analysis fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for user {user_id}")
|
||||
return existing_habits or []
|
||||
|
||||
logger.info(f"Detecting habits for user {user_id} with {len(user_summaries)} summaries")
|
||||
|
||||
try:
|
||||
# Use the habit analyzer to extract habits
|
||||
detected_habits = await self.habit_analyzer.analyze_habits(
|
||||
user_id=user_id,
|
||||
user_summaries=user_summaries,
|
||||
existing_habits=existing_habits
|
||||
)
|
||||
|
||||
# Apply advanced ranking and filtering
|
||||
ranked_habits = self.rank_habits_by_confidence_and_recency(detected_habits)
|
||||
|
||||
# Distinguish current vs past habits
|
||||
categorized_habits = self.distinguish_current_vs_past_habits(ranked_habits)
|
||||
|
||||
logger.info(f"Detected {len(categorized_habits)} habits for user {user_id}")
|
||||
return categorized_habits
|
||||
|
||||
except LLMClientException:
|
||||
logger.error(f"Habit detection failed for user {user_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Habit detection failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Habit detection failed: {e}") from e
|
||||
|
||||
def rank_habits_by_confidence_and_recency(
|
||||
self,
|
||||
habits: List[BehaviorHabit],
|
||||
confidence_weight: float = 0.6,
|
||||
recency_weight: float = 0.4
|
||||
) -> List[BehaviorHabit]:
|
||||
"""Rank habits by confidence level and recency.
|
||||
|
||||
Args:
|
||||
habits: List of habits to rank
|
||||
confidence_weight: Weight for confidence score (0.0-1.0)
|
||||
recency_weight: Weight for recency score (0.0-1.0)
|
||||
|
||||
Returns:
|
||||
List of habits ranked by combined score
|
||||
"""
|
||||
if not habits:
|
||||
return []
|
||||
|
||||
logger.info(f"Ranking {len(habits)} habits by confidence and recency")
|
||||
|
||||
def calculate_ranking_score(habit: BehaviorHabit) -> float:
|
||||
"""Calculate combined ranking score for a habit."""
|
||||
|
||||
# Confidence score (0.0-1.0) - convert from 0-100 scale
|
||||
confidence_score = habit.confidence_level / 100.0
|
||||
|
||||
# Recency score (0.0-1.0)
|
||||
current_time = datetime.now()
|
||||
days_since_last = (current_time - habit.last_observed).days
|
||||
|
||||
# Exponential decay for recency (habits lose relevance over time)
|
||||
if days_since_last <= 7:
|
||||
recency_score = 1.0 # Very recent
|
||||
elif days_since_last <= 30:
|
||||
recency_score = 0.8 # Recent
|
||||
elif days_since_last <= 90:
|
||||
recency_score = 0.5 # Somewhat recent
|
||||
elif days_since_last <= 180:
|
||||
recency_score = 0.3 # Old
|
||||
else:
|
||||
recency_score = 0.1 # Very old
|
||||
|
||||
# Frequency pattern bonus
|
||||
frequency_bonuses = {
|
||||
FrequencyPattern.DAILY: 0.2,
|
||||
FrequencyPattern.WEEKLY: 0.15,
|
||||
FrequencyPattern.MONTHLY: 0.1,
|
||||
FrequencyPattern.SEASONAL: 0.05,
|
||||
FrequencyPattern.OCCASIONAL: 0.0,
|
||||
FrequencyPattern.EVENT_TRIGGERED: 0.05
|
||||
}
|
||||
frequency_bonus = frequency_bonuses.get(habit.frequency_pattern, 0.0)
|
||||
|
||||
# Evidence quality bonus
|
||||
evidence_bonus = min(len(habit.specific_examples) / 10.0, 0.1) # Max 0.1 bonus
|
||||
|
||||
# Current habit bonus
|
||||
current_bonus = 0.1 if habit.is_current else 0.0
|
||||
|
||||
# Calculate final score
|
||||
base_score = (confidence_score * confidence_weight +
|
||||
recency_score * recency_weight)
|
||||
|
||||
final_score = base_score + frequency_bonus + evidence_bonus + current_bonus
|
||||
|
||||
return min(final_score, 1.0) # Cap at 1.0
|
||||
|
||||
# Sort habits by ranking score (descending)
|
||||
ranked_habits = sorted(habits, key=calculate_ranking_score, reverse=True)
|
||||
|
||||
logger.info(f"Ranked habits with scores: {[calculate_ranking_score(h) for h in ranked_habits[:5]]}")
|
||||
|
||||
return ranked_habits
|
||||
|
||||
def distinguish_current_vs_past_habits(
|
||||
self,
|
||||
habits: List[BehaviorHabit],
|
||||
current_threshold_days: int = 30
|
||||
) -> List[BehaviorHabit]:
|
||||
"""Distinguish between current and past habits based on recency.
|
||||
|
||||
Args:
|
||||
habits: List of habits to categorize
|
||||
current_threshold_days: Days threshold for considering a habit current
|
||||
|
||||
Returns:
|
||||
List of habits with updated is_current status
|
||||
"""
|
||||
if not habits:
|
||||
return []
|
||||
|
||||
current_time = datetime.now()
|
||||
cutoff_date = current_time - timedelta(days=current_threshold_days)
|
||||
|
||||
current_habits = []
|
||||
past_habits = []
|
||||
|
||||
for habit in habits:
|
||||
# Update is_current status based on last observation
|
||||
if habit.last_observed >= cutoff_date:
|
||||
# Create updated habit with is_current = True
|
||||
updated_habit = BehaviorHabit(
|
||||
habit_description=habit.habit_description,
|
||||
frequency_pattern=habit.frequency_pattern,
|
||||
time_context=habit.time_context,
|
||||
confidence_level=habit.confidence_level,
|
||||
specific_examples=habit.specific_examples,
|
||||
first_observed=habit.first_observed,
|
||||
last_observed=habit.last_observed,
|
||||
is_current=True
|
||||
)
|
||||
current_habits.append(updated_habit)
|
||||
else:
|
||||
# Create updated habit with is_current = False
|
||||
updated_habit = BehaviorHabit(
|
||||
habit_description=habit.habit_description,
|
||||
frequency_pattern=habit.frequency_pattern,
|
||||
time_context=habit.time_context,
|
||||
confidence_level=habit.confidence_level,
|
||||
specific_examples=habit.specific_examples,
|
||||
first_observed=habit.first_observed,
|
||||
last_observed=habit.last_observed,
|
||||
is_current=False
|
||||
)
|
||||
past_habits.append(updated_habit)
|
||||
|
||||
# Return current habits first, then past habits
|
||||
categorized_habits = current_habits + past_habits
|
||||
|
||||
logger.info(f"Categorized habits: {len(current_habits)} current, {len(past_habits)} past")
|
||||
|
||||
return categorized_habits
|
||||
@@ -1,321 +0,0 @@
|
||||
"""LLM Client Wrapper for Implicit Memory Analysis
|
||||
|
||||
This module provides a specialized LLM client wrapper that integrates with the
|
||||
MemoryClientFactory to perform implicit memory analysis tasks including preference
|
||||
extraction, personality dimension analysis, interest categorization, and habit detection.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.prompts import (
|
||||
get_dimension_analysis_prompt,
|
||||
get_habit_analysis_prompt,
|
||||
get_interest_analysis_prompt,
|
||||
get_preference_analysis_prompt,
|
||||
)
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.schemas.implicit_memory_schema import UserMemorySummary
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Response Models for LLM Analysis
|
||||
|
||||
class PreferenceAnalysisResponse(BaseModel):
|
||||
"""Response model for preference analysis."""
|
||||
preferences: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DimensionAnalysisResponse(BaseModel):
|
||||
"""Response model for dimension analysis."""
|
||||
dimensions: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class InterestAnalysisResponse(BaseModel):
|
||||
"""Response model for interest analysis."""
|
||||
interest_distribution: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class HabitAnalysisResponse(BaseModel):
|
||||
"""Response model for habit analysis."""
|
||||
habits: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ImplicitMemoryLLMClient:
|
||||
"""LLM client wrapper for implicit memory analysis.
|
||||
|
||||
This class provides a high-level interface for performing LLM-based analysis
|
||||
of user memory summaries to extract preferences, personality dimensions,
|
||||
interests, and behavioral habits.
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session, default_model_id: Optional[str] = None):
|
||||
"""Initialize the LLM client wrapper.
|
||||
|
||||
Args:
|
||||
db: Database session for accessing model configurations
|
||||
default_model_id: Default LLM model ID to use if none specified
|
||||
"""
|
||||
self.db = db
|
||||
self.default_model_id = default_model_id
|
||||
self._client_factory = MemoryClientFactory(db)
|
||||
|
||||
logger.info("ImplicitMemoryLLMClient initialized")
|
||||
|
||||
def _get_llm_client(self, model_id: Optional[str] = None):
|
||||
"""Get LLM client instance.
|
||||
|
||||
Args:
|
||||
model_id: LLM model ID to use, defaults to default_model_id
|
||||
|
||||
Returns:
|
||||
LLM client instance
|
||||
|
||||
Raises:
|
||||
ValueError: If no model ID is provided and no default is set
|
||||
LLMClientException: If client creation fails
|
||||
"""
|
||||
effective_model_id = model_id or self.default_model_id
|
||||
if not effective_model_id:
|
||||
raise ValueError("No LLM model ID provided and no default model ID set")
|
||||
|
||||
try:
|
||||
client = self._client_factory.get_llm_client(effective_model_id)
|
||||
logger.debug(f"Created LLM client for model: {effective_model_id}")
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create LLM client for model {effective_model_id}: {e}")
|
||||
raise LLMClientException(f"Failed to create LLM client: {e}") from e
|
||||
|
||||
def _prepare_summaries_for_analysis(self, user_summaries: List[UserMemorySummary]) -> List[Dict[str, Any]]:
|
||||
"""Prepare user memory summaries for LLM analysis.
|
||||
|
||||
Args:
|
||||
user_summaries: List of user memory summaries
|
||||
|
||||
Returns:
|
||||
List of formatted summary dictionaries
|
||||
"""
|
||||
formatted_summaries = []
|
||||
for summary in user_summaries:
|
||||
formatted_summary = {
|
||||
'summary_id': summary.summary_id,
|
||||
'user_content': summary.user_content,
|
||||
'timestamp': summary.timestamp.isoformat(),
|
||||
'summary_type': summary.summary_type,
|
||||
'confidence_score': summary.confidence_score
|
||||
}
|
||||
formatted_summaries.append(formatted_summary)
|
||||
|
||||
logger.debug(f"Prepared {len(formatted_summaries)} summaries for analysis")
|
||||
return formatted_summaries
|
||||
|
||||
async def analyze_preferences(
|
||||
self,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
user_id: str,
|
||||
model_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze user preferences from memory summaries.
|
||||
|
||||
Args:
|
||||
user_summaries: List of user memory summaries to analyze
|
||||
user_id: Target user ID for analysis
|
||||
model_id: Optional LLM model ID to use
|
||||
|
||||
Returns:
|
||||
Dictionary containing extracted preferences
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
ValueError: If input validation fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for preference analysis of user {user_id}")
|
||||
return {"preferences": []}
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("User ID is required for preference analysis")
|
||||
|
||||
try:
|
||||
# Prepare summaries and get prompt
|
||||
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
|
||||
prompt = get_preference_analysis_prompt(formatted_summaries, user_id)
|
||||
|
||||
# Get LLM client and perform analysis
|
||||
llm_client = self._get_llm_client(model_id)
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Use structured output for reliable parsing
|
||||
response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=PreferenceAnalysisResponse
|
||||
)
|
||||
|
||||
result = response.model_dump()
|
||||
logger.info(f"Analyzed preferences for user {user_id}: found {len(result.get('preferences', []))} preferences")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Preference analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Preference analysis failed: {e}") from e
|
||||
|
||||
async def analyze_dimensions(
|
||||
self,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
user_id: str,
|
||||
model_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze user personality dimensions from memory summaries.
|
||||
|
||||
Args:
|
||||
user_summaries: List of user memory summaries to analyze
|
||||
user_id: Target user ID for analysis
|
||||
model_id: Optional LLM model ID to use
|
||||
|
||||
Returns:
|
||||
Dictionary containing dimension scores and analysis
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
ValueError: If input validation fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for dimension analysis of user {user_id}")
|
||||
return {"dimensions": {}}
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("User ID is required for dimension analysis")
|
||||
|
||||
try:
|
||||
# Prepare summaries and get prompt
|
||||
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
|
||||
prompt = get_dimension_analysis_prompt(formatted_summaries, user_id)
|
||||
|
||||
# Get LLM client and perform analysis
|
||||
llm_client = self._get_llm_client(model_id)
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Use structured output for reliable parsing
|
||||
response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=DimensionAnalysisResponse
|
||||
)
|
||||
|
||||
result = response.model_dump()
|
||||
dimensions = result.get('dimensions', {})
|
||||
logger.info(f"Analyzed dimensions for user {user_id}: {list(dimensions.keys())}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Dimension analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Dimension analysis failed: {e}") from e
|
||||
|
||||
async def analyze_interests(
|
||||
self,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
user_id: str,
|
||||
model_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze user interest distribution from memory summaries.
|
||||
|
||||
Args:
|
||||
user_summaries: List of user memory summaries to analyze
|
||||
user_id: Target user ID for analysis
|
||||
model_id: Optional LLM model ID to use
|
||||
|
||||
Returns:
|
||||
Dictionary containing interest area distribution
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
ValueError: If input validation fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for interest analysis of user {user_id}")
|
||||
return {"interest_distribution": {}}
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("User ID is required for interest analysis")
|
||||
|
||||
try:
|
||||
# Prepare summaries and get prompt
|
||||
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
|
||||
prompt = get_interest_analysis_prompt(formatted_summaries, user_id)
|
||||
|
||||
# Get LLM client and perform analysis
|
||||
llm_client = self._get_llm_client(model_id)
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Use structured output for reliable parsing
|
||||
response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=InterestAnalysisResponse
|
||||
)
|
||||
|
||||
result = response.model_dump()
|
||||
interest_dist = result.get('interest_distribution', {})
|
||||
logger.info(f"Analyzed interests for user {user_id}: {list(interest_dist.keys())}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Interest analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Interest analysis failed: {e}") from e
|
||||
|
||||
async def analyze_habits(
|
||||
self,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
user_id: str,
|
||||
model_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze user behavioral habits from memory summaries.
|
||||
|
||||
Args:
|
||||
user_summaries: List of user memory summaries to analyze
|
||||
user_id: Target user ID for analysis
|
||||
model_id: Optional LLM model ID to use
|
||||
|
||||
Returns:
|
||||
Dictionary containing identified behavioral habits
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
ValueError: If input validation fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for habit analysis of user {user_id}")
|
||||
return {"habits": []}
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("User ID is required for habit analysis")
|
||||
|
||||
try:
|
||||
# Prepare summaries and get prompt
|
||||
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
|
||||
prompt = get_habit_analysis_prompt(formatted_summaries, user_id)
|
||||
|
||||
# Get LLM client and perform analysis
|
||||
llm_client = self._get_llm_client(model_id)
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Use structured output for reliable parsing
|
||||
response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=HabitAnalysisResponse
|
||||
)
|
||||
|
||||
result = response.model_dump()
|
||||
logger.info(f"Analyzed habits for user {user_id}: found {len(result.get('habits', []))} habits")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Habit analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Habit analysis failed: {e}") from e
|
||||
@@ -1,69 +0,0 @@
|
||||
"""LLM Prompt Templates for Implicit Memory Analysis
|
||||
|
||||
This module contains prompt rendering functions for analyzing user memory summaries
|
||||
to extract preferences, personality dimensions, interests, and behavioral habits.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
# Setup Jinja2 environment
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
prompt_dir = os.path.join(current_dir, "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
|
||||
def _render_template(template_name: str, **kwargs) -> str:
|
||||
"""Helper function to render Jinja2 templates."""
|
||||
template = prompt_env.get_template(template_name)
|
||||
return template.render(**kwargs)
|
||||
|
||||
|
||||
def get_preference_analysis_prompt(
|
||||
memory_summaries: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""Get formatted preference analysis prompt using Jinja2 template."""
|
||||
return _render_template(
|
||||
"preference_analysis.jinja2",
|
||||
memory_summaries=memory_summaries,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
def get_dimension_analysis_prompt(
|
||||
memory_summaries: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""Get formatted dimension analysis prompt using Jinja2 template."""
|
||||
return _render_template(
|
||||
"dimension_analysis.jinja2",
|
||||
memory_summaries=memory_summaries,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
def get_interest_analysis_prompt(
|
||||
memory_summaries: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""Get formatted interest analysis prompt using Jinja2 template."""
|
||||
return _render_template(
|
||||
"interest_analysis.jinja2",
|
||||
memory_summaries=memory_summaries,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
def get_habit_analysis_prompt(
|
||||
memory_summaries: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""Get formatted habit analysis prompt using Jinja2 template."""
|
||||
return _render_template(
|
||||
"habit_analysis.jinja2",
|
||||
memory_summaries=memory_summaries,
|
||||
user_id=user_id
|
||||
)
|
||||
@@ -1,41 +0,0 @@
|
||||
You are an expert personality analyst. Analyze memory summaries to assess the user's personality across four dimensions.
|
||||
|
||||
## Memory Summaries
|
||||
{% for summary in memory_summaries %}
|
||||
Summary {{ loop.index }}:
|
||||
{{ summary.content or summary.user_content or '' }}
|
||||
---
|
||||
{% endfor %}
|
||||
|
||||
## Target User ID
|
||||
{{ user_id }}
|
||||
|
||||
## Dimensions to Analyze
|
||||
1. **Creativity** (0-100%): Creative thinking, artistic interests, innovative ideas
|
||||
2. **Aesthetic** (0-100%): Design preferences, visual interests, artistic appreciation
|
||||
3. **Technology** (0-100%): Technical discussions, tool usage, programming interests
|
||||
4. **Literature** (0-100%): Reading habits, writing style, literary references
|
||||
|
||||
## Instructions
|
||||
1. Analyze the user's content for each dimension
|
||||
2. Calculate percentage scores (0-100%)
|
||||
|
||||
## Output Format
|
||||
{
|
||||
"dimensions": {
|
||||
"creativity": {"percentage": 0-100},
|
||||
"aesthetic": {"percentage": 0-100},
|
||||
"technology": {"percentage": 0-100},
|
||||
"literature": {"percentage": 0-100}
|
||||
}
|
||||
}
|
||||
|
||||
## Example
|
||||
{
|
||||
"dimensions": {
|
||||
"creativity": {"percentage": 75},
|
||||
"aesthetic": {"percentage": 45},
|
||||
"technology": {"percentage": 60},
|
||||
"literature": {"percentage": 30}
|
||||
}
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
You are an expert at identifying behavioral patterns and habits from memory summaries.
|
||||
|
||||
## Memory Summaries
|
||||
{% for summary in memory_summaries %}
|
||||
Summary {{ loop.index }}:
|
||||
{{ summary.content or summary.user_content or '' }}
|
||||
---
|
||||
{% endfor %}
|
||||
|
||||
## Target User ID
|
||||
{{ user_id }}
|
||||
|
||||
## Instructions
|
||||
1. Identify recurring behavioral patterns mentioned by the SPECIFIED USER
|
||||
2. Focus on specific, concrete habits with temporal patterns
|
||||
3. For each habit, provide:
|
||||
- habit_description: Clear, specific description
|
||||
- frequency_pattern: "daily", "weekly", "monthly", "seasonal", "occasional", "event_triggered"
|
||||
- time_context: When it typically happens
|
||||
- confidence_level: "high", "medium", "low"
|
||||
- supporting_summaries: References to evidence
|
||||
- specific_examples: Concrete examples from summaries
|
||||
- is_current: true if current habit, false if past habit
|
||||
4. Only include habits with medium or high confidence
|
||||
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
|
||||
|
||||
## Output Format
|
||||
{
|
||||
"habits": [
|
||||
{
|
||||
"habit_description": "string",
|
||||
"frequency_pattern": "daily|weekly|monthly|seasonal|occasional|event_triggered",
|
||||
"time_context": "string",
|
||||
"confidence_level": "high|medium|low",
|
||||
"supporting_summaries": ["id1", "id2"],
|
||||
"specific_examples": ["example1", "example2"],
|
||||
"is_current": true|false
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
## Example (English input → English output)
|
||||
{
|
||||
"habits": [
|
||||
{
|
||||
"habit_description": "drinks coffee every morning",
|
||||
"frequency_pattern": "daily",
|
||||
"time_context": "morning routine",
|
||||
"confidence_level": "high",
|
||||
"supporting_summaries": ["s1", "s2"],
|
||||
"specific_examples": ["needs coffee to start the day"],
|
||||
"is_current": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
## Example (Chinese input → Chinese output)
|
||||
{
|
||||
"habits": [
|
||||
{
|
||||
"habit_description": "每天早上喝咖啡",
|
||||
"frequency_pattern": "daily",
|
||||
"time_context": "早晨日常",
|
||||
"confidence_level": "high",
|
||||
"supporting_summaries": ["s1", "s2"],
|
||||
"specific_examples": ["需要咖啡来开始一天"],
|
||||
"is_current": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
You are an expert at analyzing user interests from memory summaries.
|
||||
|
||||
## Memory Summaries
|
||||
{% for summary in memory_summaries %}
|
||||
Summary {{ loop.index }}:
|
||||
{{ summary.content or summary.user_content or '' }}
|
||||
---
|
||||
{% endfor %}
|
||||
|
||||
## Target User ID
|
||||
{{ user_id }}
|
||||
|
||||
## Interest Categories
|
||||
1. **Tech**: Programming, technology, software tools, hardware
|
||||
2. **Lifestyle**: Daily routines, health, hobbies, social activities
|
||||
3. **Music**: Music preferences, instruments, concerts
|
||||
4. **Art**: Visual arts, creative projects, design, aesthetics
|
||||
|
||||
## Instructions
|
||||
1. Categorize the user's interests into the four areas
|
||||
2. Calculate percentage distribution (must total 100%)
|
||||
3. Provide specific evidence for each interest area
|
||||
4. Use "increasing", "decreasing", or "stable" for trending direction
|
||||
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
|
||||
|
||||
## Output Format
|
||||
{
|
||||
"interest_distribution": {
|
||||
"tech": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
|
||||
"lifestyle": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
|
||||
"music": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
|
||||
"art": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"}
|
||||
}
|
||||
}
|
||||
|
||||
## Example (English input → English output)
|
||||
{
|
||||
"interest_distribution": {
|
||||
"tech": {"percentage": 40, "evidence": ["discusses programming frequently"], "trending_direction": "increasing"},
|
||||
"lifestyle": {"percentage": 35, "evidence": ["talks about fitness routine"], "trending_direction": "stable"},
|
||||
"music": {"percentage": 15, "evidence": ["mentioned favorite bands"], "trending_direction": "stable"},
|
||||
"art": {"percentage": 10, "evidence": ["visited art museum"], "trending_direction": "stable"}
|
||||
}
|
||||
}
|
||||
|
||||
## Example (Chinese input → Chinese output)
|
||||
{
|
||||
"interest_distribution": {
|
||||
"tech": {"percentage": 40, "evidence": ["经常讨论编程"], "trending_direction": "increasing"},
|
||||
"lifestyle": {"percentage": 35, "evidence": ["谈论健身日常"], "trending_direction": "stable"},
|
||||
"music": {"percentage": 15, "evidence": ["提到喜欢的乐队"], "trending_direction": "stable"},
|
||||
"art": {"percentage": 10, "evidence": ["参观了艺术博物馆"], "trending_direction": "stable"}
|
||||
}
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
You are an expert at analyzing user memory summaries to identify implicit preferences.
|
||||
|
||||
## Memory Summaries
|
||||
{% for summary in memory_summaries %}
|
||||
Summary {{ loop.index }}:
|
||||
{{ summary.content or summary.user_content or '' }}
|
||||
---
|
||||
{% endfor %}
|
||||
|
||||
## Target User ID
|
||||
{{ user_id }}
|
||||
|
||||
## Instructions
|
||||
1. Focus ONLY on the specified user's preferences
|
||||
2. Extract SHORT preference tags (1-3 words max), like: "音乐", "咖啡", "科幻", "设计", "古典", "吉他"
|
||||
3. DO NOT use long phrases - use short nouns or noun phrases
|
||||
4. Only include preferences with confidence_score >= 0.3
|
||||
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
|
||||
|
||||
## Output Format
|
||||
{
|
||||
"preferences": [
|
||||
{
|
||||
"tag_name": "short tag",
|
||||
"confidence_score": 0.0-1.0,
|
||||
"supporting_evidence": ["evidence1", "evidence2"],
|
||||
"context_details": "brief context",
|
||||
"category": "category or null"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
## Example (Chinese input → Chinese output)
|
||||
{
|
||||
"preferences": [
|
||||
{"tag_name": "咖啡", "confidence_score": 0.8, "supporting_evidence": ["每天早上喝咖啡"], "context_details": "日常习惯", "category": "lifestyle"},
|
||||
{"tag_name": "古典音乐", "confidence_score": 0.7, "supporting_evidence": ["喜欢听古典"], "context_details": "音乐偏好", "category": "music"}
|
||||
]
|
||||
}
|
||||
|
||||
## Example (English input → English output)
|
||||
{
|
||||
"preferences": [
|
||||
{"tag_name": "coffee", "confidence_score": 0.8, "supporting_evidence": ["drinks coffee every morning"], "context_details": "daily routine", "category": "lifestyle"},
|
||||
{"tag_name": "classical music", "confidence_score": 0.7, "supporting_evidence": ["enjoys classical"], "context_details": "music preference", "category": "music"}
|
||||
]
|
||||
}
|
||||
391
api/app/core/memory/analytics/memory_insight.py
Normal file
391
api/app/core/memory/analytics/memory_insight.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
This module provides the MemoryInsight class for analyzing user memory data.
|
||||
|
||||
This script can be executed directly to generate a memory insight report for a test user.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
|
||||
# To run this script directly, we need to add the src directory to the Python path
|
||||
# to resolve the inconsistent imports in other modules.
|
||||
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
# Default values (previously from definitions.py)
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
# 定义用于LLM结构化输出的Pydantic模型
|
||||
class TagClassification(BaseModel):
|
||||
"""
|
||||
Represents the classification of a tag into a specific domain.
|
||||
"""
|
||||
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="The domain the tag belongs to, chosen from the predefined list.",
|
||||
examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"],
|
||||
)
|
||||
|
||||
class InsightReport(BaseModel):
|
||||
"""
|
||||
Represents the final insight report generated by the LLM.
|
||||
"""
|
||||
|
||||
report: str = Field(
|
||||
...,
|
||||
description="A comprehensive insight report in Chinese, summarizing the user's memory patterns.",
|
||||
)
|
||||
|
||||
|
||||
class MemoryInsight:
|
||||
"""
|
||||
Provides insights into user memories by analyzing various aspects of their data.
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
"""关闭数据库连接。"""
|
||||
await self.neo4j_connector.close()
|
||||
|
||||
async def get_domain_distribution(self) -> dict[str, float]:
|
||||
"""
|
||||
Calculates the distribution of memory domains based on hot tags.
|
||||
"""
|
||||
hot_tags = await get_hot_memory_tags(self.user_id)
|
||||
if not hot_tags:
|
||||
return {}
|
||||
|
||||
domain_counts = Counter()
|
||||
for tag, _ in hot_tags:
|
||||
prompt = f"""请将以下标签归类到最合适的领域中。
|
||||
|
||||
可选领域及其关键词:
|
||||
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
|
||||
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
|
||||
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
|
||||
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
|
||||
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
|
||||
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
|
||||
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
|
||||
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
|
||||
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
|
||||
- 其他:确实无法归入以上任何类别的内容
|
||||
|
||||
标签: {tag}
|
||||
|
||||
分析步骤:
|
||||
1. 仔细理解标签的核心含义和使用场景
|
||||
2. 对比各个领域的关键词,找到最匹配的领域
|
||||
3. 特别注意:
|
||||
- 历史、科学、文化等知识性内容应归类为"学习"
|
||||
- 学校、课程、考试等正式教育场景应归类为"教育"
|
||||
- 只有在标签完全不属于上述9个具体领域时,才选择"其他"
|
||||
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
|
||||
|
||||
请直接返回最合适的领域名称。"""
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景,优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
# 直接调用并等待结果
|
||||
classification = await self.llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=TagClassification,
|
||||
)
|
||||
if classification and hasattr(classification, 'domain') and classification.domain:
|
||||
domain_counts[classification.domain] += 1
|
||||
|
||||
total_tags = sum(domain_counts.values())
|
||||
if total_tags == 0:
|
||||
return {}
|
||||
|
||||
domain_distribution = {
|
||||
domain: count / total_tags for domain, count in domain_counts.items()
|
||||
}
|
||||
return dict(
|
||||
sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True)
|
||||
)
|
||||
|
||||
async def get_active_periods(self) -> list[int]:
|
||||
"""
|
||||
Identifies the top 2 most active months for the user.
|
||||
Only returns months if there is valid and diverse time data.
|
||||
|
||||
This method checks if the time data represents real user memory timestamps
|
||||
rather than auto-generated system timestamps by verifying:
|
||||
1. Time data exists and is parseable
|
||||
2. Time data is distributed across multiple months (not concentrated in 1-2 months)
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.group_id = '{self.user_id}' AND d.created_at IS NOT NULL AND d.created_at <> ''
|
||||
RETURN d.created_at AS creation_time
|
||||
"""
|
||||
records = await self.neo4j_connector.execute_query(query)
|
||||
|
||||
if not records:
|
||||
return []
|
||||
|
||||
month_counts = Counter()
|
||||
valid_dates_count = 0
|
||||
for record in records:
|
||||
creation_time_str = record.get("creation_time")
|
||||
if not creation_time_str:
|
||||
continue
|
||||
try:
|
||||
# 尝试解析时间字符串
|
||||
dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
|
||||
month_counts[dt_object.month] += 1
|
||||
valid_dates_count += 1
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
# 如果解析失败,跳过这条记录
|
||||
continue
|
||||
|
||||
# 如果没有有效的时间数据,返回空列表
|
||||
if not month_counts or valid_dates_count == 0:
|
||||
return []
|
||||
|
||||
# 检查时间分布是否过于集中(可能是批量导入的数据)
|
||||
# 如果超过80%的数据集中在1-2个月,认为这是系统时间戳而非真实时间
|
||||
unique_months = len(month_counts)
|
||||
if unique_months <= 2:
|
||||
# 只有1-2个月有数据,很可能是批量导入
|
||||
most_common_count = month_counts.most_common(1)[0][1]
|
||||
if most_common_count / valid_dates_count > 0.8:
|
||||
# 超过80%集中在一个月,认为是系统时间戳
|
||||
return []
|
||||
|
||||
# 如果时间分布较为分散(3个月以上),认为是真实时间数据
|
||||
if unique_months >= 3:
|
||||
most_common_months = month_counts.most_common(2)
|
||||
return [month for month, _ in most_common_months]
|
||||
|
||||
# 2个月的情况,检查是否分布均匀
|
||||
if unique_months == 2:
|
||||
counts = list(month_counts.values())
|
||||
# 如果两个月的数据量相差不大(比例在0.3-3之间),认为是真实数据
|
||||
ratio = min(counts) / max(counts)
|
||||
if ratio > 0.3:
|
||||
most_common_months = month_counts.most_common(2)
|
||||
return [month for month, _ in most_common_months]
|
||||
|
||||
# 其他情况返回空列表
|
||||
return []
|
||||
|
||||
async def get_social_connections(self) -> dict | None:
|
||||
"""
|
||||
Finds the user with whom the most memories are shared.
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (d1:Dialogue {{group_id: '{self.user_id}'}})<-[:MENTIONS]-(s:Statement)-[:MENTIONS]->(d2:Dialogue)
|
||||
WHERE d1 <> d2
|
||||
RETURN d2.group_id AS other_user_id, COUNT(s) AS common_statements
|
||||
ORDER BY common_statements DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
records = await self.neo4j_connector.execute_query(query)
|
||||
if not records:
|
||||
return None
|
||||
|
||||
most_connected_user = records[0]["other_user_id"]
|
||||
common_memories_count = records[0]["common_statements"]
|
||||
|
||||
time_range_query = f"""
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.group_id IN ['{self.user_id}', '{most_connected_user}']
|
||||
RETURN min(d.created_at) AS start_time, max(d.created_at) AS end_time
|
||||
"""
|
||||
time_records = await self.neo4j_connector.execute_query(time_range_query)
|
||||
start_year, end_year = "N/A", "N/A"
|
||||
if time_records and time_records[0]["start_time"]:
|
||||
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
|
||||
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
|
||||
|
||||
return {
|
||||
"user_id": most_connected_user,
|
||||
"common_memories_count": common_memories_count,
|
||||
"time_range": f"{start_year}-{end_year}",
|
||||
}
|
||||
|
||||
async def generate_insight_report(self) -> str:
|
||||
"""
|
||||
Generates the final insight report in natural language.
|
||||
"""
|
||||
domain_dist, active_periods, social_conn = await asyncio.gather(
|
||||
self.get_domain_distribution(),
|
||||
self.get_active_periods(),
|
||||
self.get_social_connections(),
|
||||
)
|
||||
|
||||
prompt_parts = []
|
||||
|
||||
if domain_dist:
|
||||
top_domains = ", ".join([f"{k}({v:.0%})" for k, v in list(domain_dist.items())[:3]])
|
||||
prompt_parts.append(f"- 核心领域: 用户的记忆主要集中在 {top_domains}。")
|
||||
|
||||
if active_periods:
|
||||
months_str = " 和 ".join(map(str, active_periods))
|
||||
prompt_parts.append(f"- 活跃时段: 用户在每年的 {months_str} 月最为活跃。")
|
||||
|
||||
if social_conn:
|
||||
prompt_parts.append(
|
||||
f"- 社交关联: 与用户\"{social_conn['user_id']}\"拥有最多共同记忆({social_conn['common_memories_count']}条),时间范围主要在 {social_conn['time_range']}。"
|
||||
)
|
||||
|
||||
if not prompt_parts:
|
||||
return "暂无足够数据生成洞察报告。"
|
||||
|
||||
system_prompt = '''你是一位资深的个人记忆分析师。你的任务是根据我提供的要点,为用户生成一段简洁、自然、个性化的记忆洞察报告。
|
||||
|
||||
重要规则:
|
||||
1. 报告需要将所有要点流畅地串联成一个段落
|
||||
2. 语言风格要亲切、易于理解,就像和朋友聊天一样
|
||||
3. 不要添加任何额外的解释或标题,直接输出报告内容
|
||||
4. 只使用我提供的要点,不要编造或推测任何信息
|
||||
5. 如果某个维度没有数据(如没有活跃时段信息),就不要在报告中提及该维度
|
||||
|
||||
例如,如果输入是:
|
||||
- 核心领域: 用户的记忆主要集中在 旅行(38%), 工作(24%), 家庭(21%)。
|
||||
- 活跃时段: 用户在每年的 4 和 10 月最为活跃。
|
||||
- 社交关联: 与用户"张明"拥有最多共同记忆(47条),时间范围主要在 2017-2020。
|
||||
|
||||
你的输出应该是:
|
||||
"您的记忆集中在旅行(38%)、工作(24%)和家庭(21%)三大领域。每年4月和10月是您最活跃的记录期,可能与春秋季旅行计划相关。您与'张明'共同拥有最多记忆(47条),主要集中在2017-2020年间。"
|
||||
|
||||
如果输入只有:
|
||||
- 核心领域: 用户的记忆主要集中在 教育(65%), 学习(25%)。
|
||||
|
||||
你的输出应该是:
|
||||
"您的记忆主要集中在教育(65%)和学习(25%)两大领域,显示出您对知识和成长的持续关注。"'''
|
||||
|
||||
user_prompt = "\n".join(prompt_parts)
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
response = await self.llm_client.chat(messages=messages)
|
||||
|
||||
# 确保返回字符串类型
|
||||
content = response.content
|
||||
if isinstance(content, list):
|
||||
# 如果是列表格式(如 [{'type': 'text', 'text': '...'}]),提取文本
|
||||
if len(content) > 0:
|
||||
if isinstance(content[0], dict):
|
||||
# 尝试提取 'text' 字段
|
||||
text = content[0].get('text', content[0].get('content', str(content[0])))
|
||||
return str(text)
|
||||
else:
|
||||
return str(content[0])
|
||||
return ""
|
||||
elif isinstance(content, dict):
|
||||
# 如果是字典格式,提取 text 字段
|
||||
return str(content.get('text', content.get('content', str(content))))
|
||||
else:
|
||||
# 已经是字符串或其他类型,转为字符串
|
||||
return str(content) if content is not None else ""
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Closes the database connection.
|
||||
"""
|
||||
await self.neo4j_connector.close()
|
||||
|
||||
|
||||
async def main():
|
||||
"""
|
||||
Initializes and runs the memory insight analysis for a test user.
|
||||
"""
|
||||
# 默认从环境变量读取
|
||||
test_user_id = DEFAULT_GROUP_ID
|
||||
print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n")
|
||||
|
||||
insight = None
|
||||
try:
|
||||
insight = MemoryInsight(user_id=test_user_id)
|
||||
report = await insight.generate_insight_report()
|
||||
print("--- 记忆洞察报告 ---")
|
||||
print(report)
|
||||
print("---------------------")
|
||||
|
||||
# 将结果写入统一的 User-Dashboard.json,使用全局配置路径
|
||||
try:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
output_dir = settings.MEMORY_OUTPUT_DIR
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
|
||||
existing = {}
|
||||
if os.path.exists(dashboard_path):
|
||||
with open(dashboard_path, "r", encoding="utf-8") as rf:
|
||||
existing = json.load(rf)
|
||||
existing["memory_insight"] = {
|
||||
"group_id": test_user_id,
|
||||
"report": report
|
||||
}
|
||||
with open(dashboard_path, "w", encoding="utf-8") as wf:
|
||||
json.dump(existing, wf, ensure_ascii=False, indent=2)
|
||||
print(f"已写入 {dashboard_path} -> memory_insight")
|
||||
except Exception as e:
|
||||
print(f"写入 User-Dashboard.json 失败: {e}")
|
||||
except Exception as e:
|
||||
print(f"生成报告时出错: {e}")
|
||||
finally:
|
||||
if insight:
|
||||
await insight.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# This setup allows running the async main function
|
||||
if sys.platform.startswith('win') and sys.version_info >= (3, 8):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
asyncio.run(main())
|
||||
202
api/app/core/memory/analytics/user_summary.py
Normal file
202
api/app/core/memory/analytics/user_summary.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Generate a concise "关于我" style user summary using data from Neo4j
|
||||
and the existing LLM configuration (mirrors hot_memory_tags.py setup).
|
||||
|
||||
Usage:
|
||||
python -m analytics.user_summary --user_id <group_id>
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
# Ensure absolute imports work whether executed directly or via module
|
||||
try:
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
src_path = os.path.join(project_root, 'src')
|
||||
if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
# Default values (previously from definitions.py)
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatementRecord:
|
||||
statement: str
|
||||
created_at: str | None
|
||||
|
||||
|
||||
class UserSummary:
|
||||
"""Builds a textual user summary for a given user/group id."""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.connector = Neo4jConnector()
|
||||
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
await self.connector.close()
|
||||
|
||||
async def _get_recent_statements(self, limit: int = 80) -> List[StatementRecord]:
|
||||
"""Fetch recent statements authored by the user/group for context."""
|
||||
query = (
|
||||
"MATCH (s:Statement) "
|
||||
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
|
||||
"RETURN s.statement AS statement, s.created_at AS created_at "
|
||||
"ORDER BY created_at DESC LIMIT $limit"
|
||||
)
|
||||
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
|
||||
records: List[StatementRecord] = []
|
||||
for r in rows:
|
||||
try:
|
||||
records.append(StatementRecord(statement=r.get("statement", ""), created_at=r.get("created_at")))
|
||||
except Exception:
|
||||
continue
|
||||
return records
|
||||
|
||||
async def _get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
|
||||
"""Reuse hot tag logic to get meaningful entities and their frequencies."""
|
||||
# get_hot_memory_tags internally filters out non-meaningful nouns with LLM
|
||||
return await get_hot_memory_tags(self.user_id, limit=limit)
|
||||
|
||||
async def generate(self) -> str:
|
||||
"""Generate a Chinese '关于我' style summary using the LLM."""
|
||||
# 1) Collect context
|
||||
entities = await self._get_top_entities(limit=40)
|
||||
statements = await self._get_recent_statements(limit=100)
|
||||
|
||||
entity_lines = [f"{name} ({freq})" for name, freq in entities][:20]
|
||||
statement_samples = [s.statement.strip() for s in statements if (s.statement or '').strip()][:20]
|
||||
|
||||
# 2) Compose prompt
|
||||
system_prompt = (
|
||||
"你是一位中文信息压缩助手。请基于提供的实体与语句,"
|
||||
"生成非常简洁的用户摘要,禁止臆测或虚构。要求:\n"
|
||||
"- 3–4 句,总字数不超过 120;\n"
|
||||
"- 先交代身份/城市,其次长期兴趣或习惯,最后给一两项代表性经历;\n"
|
||||
"- 避免形容词堆砌与空话,不用项目符号,不分段;\n"
|
||||
"- 使用客观的第三人称描述,语气克制、中立。"
|
||||
)
|
||||
|
||||
user_content_parts = [
|
||||
f"用户ID: {self.user_id}",
|
||||
"核心实体与频次: " + (", ".join(entity_lines) if entity_lines else "(空)"),
|
||||
"代表性语句样本: " + (" | ".join(statement_samples) if statement_samples else "(空)"),
|
||||
]
|
||||
user_prompt = "\n".join(user_content_parts)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
# 3) Call LLM
|
||||
response = await self.llm.chat(messages=messages)
|
||||
|
||||
# 确保返回字符串类型
|
||||
content = response.content
|
||||
if isinstance(content, list):
|
||||
# 如果是列表格式(如 [{'type': 'text', 'text': '...'}]),提取文本
|
||||
if len(content) > 0:
|
||||
if isinstance(content[0], dict):
|
||||
# 尝试提取 'text' 字段
|
||||
text = content[0].get('text', content[0].get('content', str(content[0])))
|
||||
return str(text)
|
||||
else:
|
||||
return str(content[0])
|
||||
return ""
|
||||
elif isinstance(content, dict):
|
||||
# 如果是字典格式,提取 text 字段
|
||||
return str(content.get('text', content.get('content', str(content))))
|
||||
else:
|
||||
# 已经是字符串或其他类型,转为字符串
|
||||
return str(content) if content is not None else ""
|
||||
|
||||
|
||||
async def generate_user_summary(user_id: str | None = None) -> str:
|
||||
# 默认从环境变量读取
|
||||
effective_group_id = user_id or DEFAULT_GROUP_ID
|
||||
svc = UserSummary(effective_group_id)
|
||||
try:
|
||||
return await svc.generate()
|
||||
finally:
|
||||
await svc.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("开始生成用户摘要…")
|
||||
try:
|
||||
# 直接使用 runtime.json 中的 group_id
|
||||
summary = asyncio.run(generate_user_summary())
|
||||
print("\n— 用户摘要 —\n")
|
||||
print(summary)
|
||||
|
||||
# 将结果写入统一的 User-Dashboard.json
|
||||
try:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
output_dir = settings.MEMORY_OUTPUT_DIR
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
|
||||
existing = {}
|
||||
if os.path.exists(dashboard_path):
|
||||
with open(dashboard_path, "r", encoding="utf-8") as rf:
|
||||
existing = json.load(rf)
|
||||
existing["user_summary"] = {
|
||||
"group_id": DEFAULT_GROUP_ID,
|
||||
"summary": summary
|
||||
}
|
||||
with open(dashboard_path, "w", encoding="utf-8") as wf:
|
||||
json.dump(existing, wf, ensure_ascii=False, indent=2)
|
||||
print(f"已写入 {dashboard_path} -> user_summary")
|
||||
except Exception as e:
|
||||
print(f"写入 User-Dashboard.json 失败: {e}")
|
||||
except Exception as e:
|
||||
print(f"生成摘要失败: {e}")
|
||||
print("请检查: 1) Neo4j 是否可用;2) config.json 与 .env 的 LLM/Neo4j 配置是否正确;3) 数据是否包含该用户的内容。")
|
||||
@@ -37,20 +37,12 @@ def parse_historical_datetime(v):
|
||||
此函数手动解析 ISO 8601 格式的日期字符串,支持1-4位年份
|
||||
|
||||
Args:
|
||||
v: 日期值(可以是 None、datetime 对象、Neo4j DateTime 对象或字符串)
|
||||
v: 日期值(可以是 None、datetime 对象或字符串)
|
||||
|
||||
Returns:
|
||||
datetime 对象或 None
|
||||
"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# 处理 Neo4j DateTime 对象
|
||||
if hasattr(v, 'to_native'):
|
||||
return v.to_native()
|
||||
|
||||
# 处理 Python datetime 对象
|
||||
if isinstance(v, datetime):
|
||||
if v is None or isinstance(v, datetime):
|
||||
return v
|
||||
|
||||
if isinstance(v, str):
|
||||
@@ -236,13 +228,6 @@ class StatementNode(Node):
|
||||
chunk_embedding: Optional embedding vector for the parent chunk
|
||||
connect_strength: Classification of connection strength ('Strong' or 'Weak')
|
||||
config_id: Configuration ID used to process this statement
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: Importance score for memory activation (0.0-1.0), default 0.5
|
||||
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0)
|
||||
access_history: List of ISO timestamp strings recording each access
|
||||
last_access_time: ISO timestamp of the most recent access
|
||||
access_count: Total number of times this node has been accessed
|
||||
"""
|
||||
# Core fields (ordered as requested)
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk")
|
||||
@@ -284,33 +269,6 @@ class StatementNode(Node):
|
||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Importance score for memory activation (0.0-1.0), default 0.5"
|
||||
)
|
||||
activation_value: Optional[float] = Field(
|
||||
None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Current activation value calculated by ACT-R engine (0.0-1.0)"
|
||||
)
|
||||
access_history: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of ISO timestamp strings recording each access"
|
||||
)
|
||||
last_access_time: Optional[str] = Field(
|
||||
None,
|
||||
description="ISO timestamp of the most recent access"
|
||||
)
|
||||
access_count: int = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
@@ -393,22 +351,11 @@ class ExtractedEntityNode(Node):
|
||||
fact_summary: Summary of facts about this entity
|
||||
connect_strength: Classification of connection strength ('Strong', 'Weak', or 'Both')
|
||||
config_id: Configuration ID used to process this entity (integer or string)
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: Importance score for memory activation (0.0-1.0), default 0.5
|
||||
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0)
|
||||
access_history: List of ISO timestamp strings recording each access
|
||||
last_access_time: ISO timestamp of the most recent access
|
||||
access_count: Total number of times this node has been accessed
|
||||
"""
|
||||
entity_idx: int = Field(..., description="Unique identifier for the entity")
|
||||
statement_id: str = Field(..., description="Statement this entity was extracted from")
|
||||
entity_type: str = Field(..., description="Type of the entity")
|
||||
description: str = Field(..., description="Entity description")
|
||||
example: str = Field(
|
||||
default="",
|
||||
description="A concise example (around 20 characters) to help understand the entity"
|
||||
)
|
||||
aliases: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
@@ -418,39 +365,6 @@ class ExtractedEntityNode(Node):
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Importance score for memory activation (0.0-1.0), default 0.5"
|
||||
)
|
||||
activation_value: Optional[float] = Field(
|
||||
None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Current activation value calculated by ACT-R engine (0.0-1.0)"
|
||||
)
|
||||
access_history: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of ISO timestamp strings recording each access"
|
||||
)
|
||||
last_access_time: Optional[str] = Field(
|
||||
None,
|
||||
description="ISO timestamp of the most recent access"
|
||||
)
|
||||
access_count: int = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
# Explicit Memory Classification
|
||||
is_explicit_memory: bool = Field(
|
||||
default=False,
|
||||
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
|
||||
)
|
||||
|
||||
@field_validator('aliases', mode='before')
|
||||
@classmethod
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
@@ -484,68 +398,14 @@ class MemorySummaryNode(Node):
|
||||
dialog_id: ID of the parent dialog
|
||||
chunk_ids: List of chunk IDs used to generate this summary
|
||||
content: Summary text content
|
||||
name: Title/name of the memory summary (generated by LLM, used as title in API)
|
||||
memory_type: Type/category of the episodic memory (e.g., Conversation, Project/Work, Learning, Decision, Important Event)
|
||||
summary_embedding: Optional embedding vector for the summary
|
||||
metadata: Additional metadata for the summary
|
||||
config_id: Configuration ID used to process this summary
|
||||
original_statement_id: ID of the original statement that was merged (for ACT-R forgetting)
|
||||
original_entity_id: ID of the original entity that was merged (for ACT-R forgetting)
|
||||
merged_at: Timestamp when the nodes were merged
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: Importance score for memory activation (0.0-1.0), inherited from merged nodes
|
||||
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0), inherited from merged nodes
|
||||
access_history: List of ISO timestamp strings recording each access (reset on creation)
|
||||
last_access_time: ISO timestamp of the most recent access (set to creation time)
|
||||
access_count: Total number of times this node has been accessed (reset to 1 on creation)
|
||||
"""
|
||||
summary_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the summary")
|
||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||
chunk_ids: List[str] = Field(default_factory=list, description="List of chunk IDs used in the summary")
|
||||
content: str = Field(..., description="Summary text content")
|
||||
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
||||
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
|
||||
|
||||
# ACT-R Forgetting Engine Properties
|
||||
original_statement_id: Optional[str] = Field(
|
||||
None,
|
||||
description="ID of the original statement that was merged (for traceability)"
|
||||
)
|
||||
original_entity_id: Optional[str] = Field(
|
||||
None,
|
||||
description="ID of the original entity that was merged (for traceability)"
|
||||
)
|
||||
merged_at: Optional[datetime] = Field(
|
||||
None,
|
||||
description="Timestamp when the nodes were merged"
|
||||
)
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Importance score for memory activation (0.0-1.0), inherited from merged nodes"
|
||||
)
|
||||
activation_value: Optional[float] = Field(
|
||||
None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Current activation value calculated by ACT-R engine (0.0-1.0), inherited from merged nodes"
|
||||
)
|
||||
access_history: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of ISO timestamp strings recording each access (reset on creation)"
|
||||
)
|
||||
last_access_time: Optional[str] = Field(
|
||||
None,
|
||||
description="ISO timestamp of the most recent access (set to creation time)"
|
||||
)
|
||||
access_count: int = Field(
|
||||
default=1,
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||
)
|
||||
|
||||
@@ -38,20 +38,10 @@ class Entity(BaseModel):
|
||||
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
|
||||
type: str = Field(..., description="Type/category of the entity")
|
||||
description: str = Field(..., description="Description of the entity")
|
||||
example: str = Field(
|
||||
default="",
|
||||
description="A concise example (around 20 characters) to help understand the entity"
|
||||
)
|
||||
aliases: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Alternative names for this entity (abbreviations, full names, translations, etc.)"
|
||||
)
|
||||
|
||||
# Explicit Memory Classification
|
||||
is_explicit_memory: bool = Field(
|
||||
default=False,
|
||||
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
|
||||
)
|
||||
|
||||
|
||||
class Triplet(BaseModel):
|
||||
|
||||
@@ -69,12 +69,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
score = item.get(score_field)
|
||||
# 对于 activation_value,None 值保持为 None,不使用回退值
|
||||
# 这样可以区分有激活值和无激活值的节点
|
||||
if score_field == "activation_value" and score is None:
|
||||
scores.append(None) # 保持 None,稍后特殊处理
|
||||
continue
|
||||
|
||||
if score is not None and isinstance(score, (int, float)):
|
||||
scores.append(float(score))
|
||||
else:
|
||||
@@ -82,433 +76,205 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
|
||||
if not scores:
|
||||
return results
|
||||
|
||||
# 过滤掉 None 值,只对有效分数进行归一化
|
||||
valid_scores = [s for s in scores if s is not None]
|
||||
|
||||
if not valid_scores:
|
||||
# 所有分数都是 None,不进行归一化
|
||||
|
||||
if len(scores) == 1:
|
||||
# Single score, set to 1.0
|
||||
for item in results:
|
||||
if score_field in item or score_field == "activation_value":
|
||||
item[f"normalized_{score_field}"] = None
|
||||
if score_field in item:
|
||||
item[f"normalized_{score_field}"] = 1.0
|
||||
return results
|
||||
|
||||
if len(valid_scores) == 1: # Single valid score, set to 1.0
|
||||
for item, score in zip(results, scores):
|
||||
if score_field in item or score_field == "activation_value":
|
||||
if score is None:
|
||||
item[f"normalized_{score_field}"] = None
|
||||
else:
|
||||
item[f"normalized_{score_field}"] = 1.0
|
||||
return results
|
||||
|
||||
# Calculate mean and standard deviation (only for valid scores)
|
||||
mean_score = sum(valid_scores) / len(valid_scores)
|
||||
variance = sum((score - mean_score) ** 2 for score in valid_scores) / len(valid_scores)
|
||||
# Calculate mean and standard deviation
|
||||
mean_score = sum(scores) / len(scores)
|
||||
variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
|
||||
std_dev = math.sqrt(variance)
|
||||
|
||||
if std_dev == 0:
|
||||
# All valid scores are the same, set them to 1.0
|
||||
for item, score in zip(results, scores):
|
||||
if score_field in item or score_field == "activation_value":
|
||||
if score is None:
|
||||
item[f"normalized_{score_field}"] = None
|
||||
else:
|
||||
item[f"normalized_{score_field}"] = 1.0
|
||||
# All scores are the same, set them to 1.0
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
item[f"normalized_{score_field}"] = 1.0
|
||||
else:
|
||||
for item, score in zip(results, scores):
|
||||
if score_field in item or score_field == "activation_value":
|
||||
if score is None:
|
||||
# 保持 None,不进行归一化
|
||||
item[f"normalized_{score_field}"] = None
|
||||
else:
|
||||
# Calculate z-score
|
||||
z_score = (score - mean_score) / std_dev
|
||||
# Transform to positive range using sigmoid function
|
||||
normalized = 1 / (1 + math.exp(-z_score))
|
||||
item[f"normalized_{score_field}"] = normalized
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
score = item[score_field]
|
||||
# Handle None or non-numeric scores
|
||||
if score is None or not isinstance(score, (int, float)):
|
||||
score = 0.0
|
||||
# Calculate z-score
|
||||
z_score = (score - mean_score) / std_dev
|
||||
# Transform to positive range using sigmoid function
|
||||
normalized = 1 / (1 + math.exp(-z_score))
|
||||
item[f"normalized_{score_field}"] = normalized
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 以下函数已被 rerank_with_activation 替代,暂时保留以供参考
|
||||
# ============================================================================
|
||||
def rerank_hybrid_results(
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Rerank hybrid search results by combining BM25 and embedding scores.
|
||||
|
||||
# def rerank_hybrid_results(
|
||||
# keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
# embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
# alpha: float = 0.6,
|
||||
# limit: int = 10
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Rerank hybrid search results by combining BM25 and embedding scores.
|
||||
#
|
||||
# 已废弃:此函数功能已被 rerank_with_activation 完全替代
|
||||
#
|
||||
# Args:
|
||||
# keyword_results: Results from keyword/BM25 search
|
||||
# embedding_results: Results from embedding search
|
||||
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
# limit: Maximum number of results to return per category
|
||||
#
|
||||
# Returns:
|
||||
# Reranked results with combined scores
|
||||
# """
|
||||
# reranked = {}
|
||||
#
|
||||
# for category in ["statements", "chunks", "entities","summaries"]:
|
||||
# keyword_items = keyword_results.get(category, [])
|
||||
# embedding_items = embedding_results.get(category, [])
|
||||
#
|
||||
# # Normalize scores within each search type
|
||||
# keyword_items = normalize_scores(keyword_items, "score")
|
||||
# embedding_items = normalize_scores(embedding_items, "score")
|
||||
#
|
||||
# # Create a combined pool of unique items
|
||||
# combined_items = {}
|
||||
#
|
||||
# # Add keyword results with BM25 scores
|
||||
# for item in keyword_items:
|
||||
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
# if item_id:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
# combined_items[item_id]["embedding_score"] = 0 # Default
|
||||
#
|
||||
# # Add or update with embedding results
|
||||
# for item in embedding_items:
|
||||
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
# if item_id:
|
||||
# if item_id in combined_items:
|
||||
# # Update existing item with embedding score
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# else:
|
||||
# # New item from embedding search only
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = 0 # Default
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
#
|
||||
# # Calculate combined scores and rank
|
||||
# for item_id, item in combined_items.items():
|
||||
# bm25_score = item.get("bm25_score", 0)
|
||||
# embedding_score = item.get("embedding_score", 0)
|
||||
#
|
||||
# # Combined score: weighted average of normalized scores
|
||||
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
# item["combined_score"] = combined_score
|
||||
#
|
||||
# # Keep original score for reference
|
||||
# if "score" not in item and bm25_score > 0:
|
||||
# item["score"] = bm25_score
|
||||
# elif "score" not in item and embedding_score > 0:
|
||||
# item["score"] = embedding_score
|
||||
#
|
||||
# # Sort by combined score and limit results
|
||||
# sorted_items = sorted(
|
||||
# combined_items.values(),
|
||||
# key=lambda x: x.get("combined_score", 0),
|
||||
# reverse=True
|
||||
# )[:limit]
|
||||
#
|
||||
# reranked[category] = sorted_items
|
||||
#
|
||||
# return reranked
|
||||
Args:
|
||||
keyword_results: Results from keyword/BM25 search
|
||||
embedding_results: Results from embedding search
|
||||
alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
limit: Maximum number of results to return per category
|
||||
|
||||
# def rerank_with_forgetting_curve(
|
||||
# keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
# embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
# alpha: float = 0.6,
|
||||
# limit: int = 10,
|
||||
# forgetting_config: ForgettingEngineConfig | None = None,
|
||||
# now: datetime | None = None,
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Rerank hybrid results with a forgetting curve applied to combined scores.
|
||||
#
|
||||
# 已废弃:此函数功能已被 rerank_with_activation 完全替代
|
||||
# rerank_with_activation 提供了更完整的遗忘曲线支持(结合激活度)
|
||||
#
|
||||
# The forgetting curve reduces scores for older memories or weaker connections.
|
||||
#
|
||||
# Args:
|
||||
# keyword_results: Results from keyword/BM25 search
|
||||
# embedding_results: Results from embedding search
|
||||
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
# limit: Maximum number of results to return per category
|
||||
# forgetting_config: Configuration for the forgetting engine
|
||||
# now: Optional current time override for testing
|
||||
#
|
||||
# Returns:
|
||||
# Reranked results with combined and final scores (after forgetting)
|
||||
# """
|
||||
# engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
|
||||
# now_dt = now or datetime.now()
|
||||
#
|
||||
# reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
#
|
||||
# for category in ["statements", "chunks", "entities","summaries"]:
|
||||
# keyword_items = keyword_results.get(category, [])
|
||||
# embedding_items = embedding_results.get(category, [])
|
||||
#
|
||||
# # Normalize scores within each search type
|
||||
# keyword_items = normalize_scores(keyword_items, "score")
|
||||
# embedding_items = normalize_scores(embedding_items, "score")
|
||||
#
|
||||
# combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
#
|
||||
# # Combine two result sets by ID
|
||||
# for src_items, is_embedding in (
|
||||
# (keyword_items, False), (embedding_items, True)
|
||||
# ):
|
||||
# for item in src_items:
|
||||
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
# if not item_id:
|
||||
# continue
|
||||
# existing = combined_items.get(item_id)
|
||||
# if not existing:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = 0
|
||||
# combined_items[item_id]["embedding_score"] = 0
|
||||
# # Update normalized score from the right source
|
||||
# if is_embedding:
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# else:
|
||||
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
#
|
||||
# # Calculate scores and apply forgetting weights
|
||||
# for item_id, item in combined_items.items():
|
||||
# bm25_score = float(item.get("bm25_score", 0) or 0)
|
||||
# embedding_score = float(item.get("embedding_score", 0) or 0)
|
||||
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
#
|
||||
# # Estimate time elapsed in days
|
||||
# dt = _parse_datetime(item.get("created_at"))
|
||||
# if dt is None:
|
||||
# time_elapsed_days = 0.0
|
||||
# else:
|
||||
# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
#
|
||||
# # Memory strength (currently set to default value)
|
||||
# memory_strength = 1.0
|
||||
# forgetting_weight = engine.calculate_weight(
|
||||
# time_elapsed=time_elapsed_days, memory_strength=memory_strength
|
||||
# )
|
||||
# final_score = combined_score * forgetting_weight
|
||||
# item["combined_score"] = final_score
|
||||
#
|
||||
# sorted_items = sorted(
|
||||
# combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
|
||||
# )[:limit]
|
||||
#
|
||||
# reranked[category] = sorted_items
|
||||
#
|
||||
# return reranked
|
||||
Returns:
|
||||
Reranked results with combined scores
|
||||
"""
|
||||
reranked = {}
|
||||
|
||||
for category in ["statements", "chunks", "entities","summaries"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
def rerank_with_activation(
|
||||
# Normalize scores within each search type
|
||||
keyword_items = normalize_scores(keyword_items, "score")
|
||||
embedding_items = normalize_scores(embedding_items, "score")
|
||||
|
||||
# Create a combined pool of unique items
|
||||
combined_items = {}
|
||||
|
||||
# Add keyword results with BM25 scores
|
||||
for item in keyword_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id:
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
combined_items[item_id]["embedding_score"] = 0 # Default
|
||||
|
||||
# Add or update with embedding results
|
||||
for item in embedding_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id:
|
||||
if item_id in combined_items:
|
||||
# Update existing item with embedding score
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
# New item from embedding search only
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0 # Default
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
# Calculate combined scores and rank
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_score = item.get("bm25_score", 0)
|
||||
embedding_score = item.get("embedding_score", 0)
|
||||
|
||||
# Combined score: weighted average of normalized scores
|
||||
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
item["combined_score"] = combined_score
|
||||
|
||||
# Keep original score for reference
|
||||
if "score" not in item and bm25_score > 0:
|
||||
item["score"] = bm25_score
|
||||
elif "score" not in item and embedding_score > 0:
|
||||
item["score"] = embedding_score
|
||||
|
||||
# Sort by combined score and limit results
|
||||
sorted_items = sorted(
|
||||
combined_items.values(),
|
||||
key=lambda x: x.get("combined_score", 0),
|
||||
reverse=True
|
||||
)[:limit]
|
||||
|
||||
reranked[category] = sorted_items
|
||||
|
||||
return reranked
|
||||
|
||||
def rerank_with_forgetting_curve(
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10,
|
||||
forgetting_config: ForgettingEngineConfig | None = None,
|
||||
activation_boost_factor: float = 0.8,
|
||||
now: datetime | None = None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
两阶段排序:先按内容相关性筛选,再按激活值排序。
|
||||
|
||||
阶段1: content_score = alpha*BM25 + (1-alpha)*Embedding,取 Top-(limit*3)
|
||||
阶段2: 在候选中按 activation_score 排序,取 Top-limit
|
||||
无激活值的节点用于补充不足
|
||||
|
||||
返回结果中的评分字段说明:
|
||||
- bm25_score: BM25 归一化分数
|
||||
- embedding_score: Embedding 归一化分数
|
||||
- content_score: 内容相关性 = alpha*bm25 + (1-alpha)*embedding
|
||||
- activation_score: ACTR 激活值归一化分数
|
||||
- base_score: 第一阶段基础分数(等于 content_score)
|
||||
- final_score: 最终排序依据
|
||||
* 有激活值的节点:final_score = activation_score
|
||||
* 无激活值的节点:final_score = base_score
|
||||
|
||||
参数:
|
||||
keyword_results: BM25 检索结果
|
||||
embedding_results: 向量嵌入检索结果
|
||||
alpha: BM25 权重 (默认: 0.6)
|
||||
limit: 每类最大结果数
|
||||
forgetting_config: 遗忘引擎配置(当前未使用)
|
||||
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
|
||||
now: 当前时间(用于遗忘计算)
|
||||
|
||||
返回:
|
||||
带评分元数据的重排序结果,按 final_score 排序
|
||||
Rerank hybrid results with a forgetting curve applied to combined scores.
|
||||
|
||||
The forgetting curve reduces scores for older memories or weaker connections.
|
||||
|
||||
Args:
|
||||
keyword_results: Results from keyword/BM25 search
|
||||
embedding_results: Results from embedding search
|
||||
alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
limit: Maximum number of results to return per category
|
||||
forgetting_config: Configuration for the forgetting engine
|
||||
now: Optional current time override for testing
|
||||
|
||||
Returns:
|
||||
Reranked results with combined and final scores (after forgetting)
|
||||
"""
|
||||
# 验证权重范围
|
||||
if not (0 <= alpha <= 1):
|
||||
raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")
|
||||
|
||||
# 初始化遗忘引擎(如果需要)
|
||||
engine = None
|
||||
if forgetting_config:
|
||||
engine = ForgettingEngine(forgetting_config)
|
||||
engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
|
||||
now_dt = now or datetime.now()
|
||||
|
||||
|
||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
|
||||
for category in ["statements", "chunks", "entities","summaries"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
# 步骤 1: 归一化分数
|
||||
|
||||
# Normalize scores within each search type
|
||||
keyword_items = normalize_scores(keyword_items, "score")
|
||||
embedding_items = normalize_scores(embedding_items, "score")
|
||||
|
||||
# 步骤 2: 按 ID 合并结果
|
||||
|
||||
combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# 添加关键词结果
|
||||
for item in keyword_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if not item_id:
|
||||
continue
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
combined_items[item_id]["embedding_score"] = 0 # 默认值
|
||||
|
||||
# 添加或更新向量嵌入结果
|
||||
for item in embedding_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if not item_id:
|
||||
continue
|
||||
if item_id in combined_items:
|
||||
# 更新现有项的嵌入分数
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
# 仅来自嵌入搜索的新项
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0 # 默认值
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
# 步骤 3: 归一化激活度分数
|
||||
# 为所有项准备激活度值列表
|
||||
items_list = list(combined_items.values())
|
||||
items_list = normalize_scores(items_list, "activation_value")
|
||||
|
||||
# 更新 combined_items 中的归一化激活度分数
|
||||
for item in items_list:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id and item_id in combined_items:
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0)
|
||||
|
||||
# 步骤 4: 计算基础分数和最终分数
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
||||
emb_norm = float(item.get("embedding_score", 0) or 0)
|
||||
act_norm = float(item.get("normalized_activation_value", 0) or 0)
|
||||
|
||||
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
||||
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
||||
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
||||
base_score = content_score # 第一阶段用内容分数
|
||||
|
||||
# 存储激活度分数供第二阶段使用
|
||||
item["activation_score"] = act_norm
|
||||
item["content_score"] = content_score
|
||||
item["base_score"] = base_score
|
||||
|
||||
# 步骤 5: 应用遗忘曲线(可选)
|
||||
if engine:
|
||||
# 计算受激活度影响的记忆强度
|
||||
importance = float(item.get("importance_score", 0.5) or 0.5)
|
||||
|
||||
# 获取 activation_value
|
||||
activation_val = item.get("activation_value")
|
||||
|
||||
# 只对有激活值的节点应用遗忘曲线
|
||||
if activation_val is not None and isinstance(activation_val, (int, float)):
|
||||
activation_val = float(activation_val)
|
||||
|
||||
# 计算记忆强度:importance_score × (1 + activation_value × boost_factor)
|
||||
memory_strength = importance * (1 + activation_val * activation_boost_factor)
|
||||
|
||||
# 计算经过的时间(天数)
|
||||
dt = _parse_datetime(item.get("created_at"))
|
||||
if dt is None:
|
||||
time_elapsed_days = 0.0
|
||||
else:
|
||||
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
|
||||
# 获取遗忘权重
|
||||
forgetting_weight = engine.calculate_weight(
|
||||
time_elapsed=time_elapsed_days,
|
||||
memory_strength=memory_strength
|
||||
)
|
||||
|
||||
# 应用到基础分数
|
||||
item["forgetting_weight"] = forgetting_weight
|
||||
item["final_score"] = base_score * forgetting_weight
|
||||
|
||||
# Combine two result sets by ID
|
||||
for src_items, is_embedding in (
|
||||
(keyword_items, False), (embedding_items, True)
|
||||
):
|
||||
for item in src_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if not item_id:
|
||||
continue
|
||||
existing = combined_items.get(item_id)
|
||||
if not existing:
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0
|
||||
combined_items[item_id]["embedding_score"] = 0
|
||||
# Update normalized score from the right source
|
||||
if is_embedding:
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
# 无激活值的节点不应用遗忘曲线,保持原始分数
|
||||
item["final_score"] = base_score
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
|
||||
# Calculate scores and apply forgetting weights
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_score = float(item.get("bm25_score", 0) or 0)
|
||||
embedding_score = float(item.get("embedding_score", 0) or 0)
|
||||
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
|
||||
# Estimate time elapsed in days
|
||||
dt = _parse_datetime(item.get("created_at"))
|
||||
if dt is None:
|
||||
time_elapsed_days = 0.0
|
||||
else:
|
||||
# 不使用遗忘曲线
|
||||
item["final_score"] = base_score
|
||||
|
||||
# 步骤 6: 两阶段排序和限制
|
||||
# 第一阶段:按内容相关性(base_score)排序,取 Top-K
|
||||
first_stage_limit = limit * 3 # 可配置,取3倍候选
|
||||
first_stage_sorted = sorted(
|
||||
combined_items.values(),
|
||||
key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序
|
||||
reverse=True
|
||||
)[:first_stage_limit]
|
||||
|
||||
# 第二阶段:分离有激活值和无激活值的节点
|
||||
items_with_activation = []
|
||||
items_without_activation = []
|
||||
|
||||
for item in first_stage_sorted:
|
||||
activation_score = item.get("activation_score")
|
||||
# 检查是否有有效的激活值(不是 None)
|
||||
if activation_score is not None and isinstance(activation_score, (int, float)):
|
||||
items_with_activation.append(item)
|
||||
else:
|
||||
items_without_activation.append(item)
|
||||
|
||||
# 优先按激活值排序有激活值的节点
|
||||
sorted_with_activation = sorted(
|
||||
items_with_activation,
|
||||
key=lambda x: float(x.get("activation_score", 0) or 0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# 如果有激活值的节点不足 limit,用无激活值的节点补充
|
||||
if len(sorted_with_activation) < limit:
|
||||
needed = limit - len(sorted_with_activation)
|
||||
# 无激活值的节点保持第一阶段的内容相关性排序
|
||||
sorted_items = sorted_with_activation + items_without_activation[:needed]
|
||||
else:
|
||||
sorted_items = sorted_with_activation[:limit]
|
||||
|
||||
# 两阶段排序完成,更新 final_score 以反映实际排序依据
|
||||
# Stage 1: 按 content_score 筛选候选(已完成)
|
||||
# Stage 2: 按 activation_score 排序(已完成)
|
||||
#
|
||||
# final_score 语义:反映节点在最终结果中的排序依据
|
||||
# - 有激活值的节点:final_score = activation_score(第二阶段排序依据)
|
||||
# - 无激活值的节点:final_score = base_score(保持内容相关性分数)
|
||||
for item in sorted_items:
|
||||
activation_score = item.get("activation_score")
|
||||
if activation_score is not None and isinstance(activation_score, (int, float)):
|
||||
# 有激活值:使用激活度作为最终分数
|
||||
item["final_score"] = activation_score
|
||||
else:
|
||||
# 无激活值:使用内容相关性分数
|
||||
item["final_score"] = item.get("base_score", 0)
|
||||
|
||||
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
|
||||
# Memory strength (currently set to default value)
|
||||
memory_strength = 1.0
|
||||
forgetting_weight = engine.calculate_weight(
|
||||
time_elapsed=time_elapsed_days, memory_strength=memory_strength
|
||||
)
|
||||
# print(f"Forgetting weight for {item_id}: {forgetting_weight}")
|
||||
# print(f"Time elapsed days for {item_id}: {time_elapsed_days}")
|
||||
final_score = combined_score * forgetting_weight
|
||||
item["combined_score"] = final_score
|
||||
|
||||
sorted_items = sorted(
|
||||
combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
|
||||
)[:limit]
|
||||
|
||||
reranked[category] = sorted_items
|
||||
|
||||
|
||||
return reranked
|
||||
|
||||
|
||||
@@ -794,7 +560,6 @@ async def run_hybrid_search(
|
||||
output_path: str | None,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
):
|
||||
@@ -842,7 +607,7 @@ async def run_hybrid_search(
|
||||
|
||||
if search_type in ["keyword", "hybrid"]:
|
||||
# Keyword-based search
|
||||
logger.info("[PERF] Starting keyword search...")
|
||||
logger.info("Starting keyword search...")
|
||||
keyword_start = time.time()
|
||||
keyword_task = asyncio.create_task(
|
||||
search_graph(
|
||||
@@ -856,7 +621,7 @@ async def run_hybrid_search(
|
||||
|
||||
if search_type in ["embedding", "hybrid"]:
|
||||
# Embedding-based search
|
||||
logger.info("[PERF] Starting embedding search...")
|
||||
logger.info("Starting embedding search...")
|
||||
embedding_start = time.time()
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
@@ -872,13 +637,13 @@ async def run_hybrid_search(
|
||||
type="llm"
|
||||
)
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
|
||||
logger.info(f"Config loading took {config_load_time:.4f}s")
|
||||
|
||||
# Init embedder
|
||||
embedder_init_start = time.time()
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
|
||||
logger.info(f"Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
@@ -895,7 +660,7 @@ async def run_hybrid_search(
|
||||
keyword_results = await keyword_task
|
||||
keyword_latency = time.time() - keyword_start
|
||||
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
|
||||
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
|
||||
logger.info(f"Keyword search completed in {keyword_latency:.4f}s")
|
||||
if search_type == "keyword":
|
||||
results = keyword_results
|
||||
else:
|
||||
@@ -905,7 +670,7 @@ async def run_hybrid_search(
|
||||
embedding_results = await embedding_task
|
||||
embedding_latency = time.time() - embedding_start
|
||||
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
|
||||
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
|
||||
logger.info(f"Embedding search completed in {embedding_latency:.4f}s")
|
||||
if search_type == "embedding":
|
||||
results = embedding_results
|
||||
else:
|
||||
@@ -920,37 +685,33 @@ async def run_hybrid_search(
|
||||
"search_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Apply two-stage reranking with ACTR activation calculation
|
||||
# Apply reranking (optionally with forgetting curve)
|
||||
rerank_start = time.time()
|
||||
logger.info("[PERF] Using two-stage reranking with ACTR activation")
|
||||
|
||||
# 加载遗忘引擎配置
|
||||
config_start = time.time()
|
||||
try:
|
||||
pc = get_pipeline_config(memory_config)
|
||||
forgetting_cfg = pc.forgetting_engine
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
|
||||
forgetting_cfg = ForgettingEngineConfig()
|
||||
config_time = time.time() - config_start
|
||||
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
|
||||
|
||||
# 统一使用激活度重排序(两阶段:检索 + ACTR计算)
|
||||
rerank_compute_start = time.time()
|
||||
reranked_results = rerank_with_activation(
|
||||
keyword_results=keyword_results,
|
||||
embedding_results=embedding_results,
|
||||
alpha=rerank_alpha,
|
||||
limit=limit,
|
||||
forgetting_config=forgetting_cfg,
|
||||
activation_boost_factor=activation_boost_factor,
|
||||
)
|
||||
rerank_compute_time = time.time() - rerank_compute_start
|
||||
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
|
||||
|
||||
if use_forgetting_rerank:
|
||||
# Load forgetting parameters from pipeline config
|
||||
try:
|
||||
pc = get_pipeline_config(memory_config)
|
||||
forgetting_cfg = pc.forgetting_engine
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
|
||||
forgetting_cfg = ForgettingEngineConfig()
|
||||
reranked_results = rerank_with_forgetting_curve(
|
||||
keyword_results=keyword_results,
|
||||
embedding_results=embedding_results,
|
||||
alpha=rerank_alpha,
|
||||
limit=limit,
|
||||
forgetting_config=forgetting_cfg,
|
||||
)
|
||||
else:
|
||||
reranked_results = rerank_hybrid_results(
|
||||
keyword_results=keyword_results,
|
||||
embedding_results=embedding_results,
|
||||
alpha=rerank_alpha, # Configurable weight for BM25 vs embedding
|
||||
limit=limit
|
||||
)
|
||||
rerank_latency = time.time() - rerank_start
|
||||
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
|
||||
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
|
||||
logger.info(f"Reranking completed in {rerank_latency:.4f}s")
|
||||
|
||||
# Optional: apply reranker placeholder if enabled via config
|
||||
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
|
||||
@@ -976,7 +737,6 @@ async def run_hybrid_search(
|
||||
"search_query": query_text,
|
||||
"search_timestamp": datetime.now().isoformat(),
|
||||
"reranking_alpha": rerank_alpha,
|
||||
"activation_boost_factor": activation_boost_factor,
|
||||
"forgetting_rerank": use_forgetting_rerank,
|
||||
"llm_rerank": llm_rerank_applied,
|
||||
}
|
||||
@@ -991,10 +751,8 @@ async def run_hybrid_search(
|
||||
else:
|
||||
results["latency_metrics"] = latency_metrics
|
||||
|
||||
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
|
||||
logger.info(f"[PERF] =========================================")
|
||||
logger.info(f"Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"Latency breakdown: {latency_metrics}")
|
||||
|
||||
# Sanitize results: drop large/unused fields
|
||||
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs
|
||||
|
||||
@@ -42,6 +42,7 @@ from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||
embedding_generation,
|
||||
embedding_generation_all,
|
||||
generate_entity_embeddings_from_triplets,
|
||||
)
|
||||
|
||||
@@ -178,7 +179,7 @@ class ExtractionOrchestrator:
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
all_statements_list.extend(chunk.statements)
|
||||
len(all_statements_list)
|
||||
total_statements = len(all_statements_list)
|
||||
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
|
||||
@@ -200,9 +201,9 @@ class ExtractionOrchestrator:
|
||||
all_entities_list.extend(triplet_info.entities)
|
||||
all_triplets_list.extend(triplet_info.triplets)
|
||||
|
||||
len(all_entities_list)
|
||||
len(all_triplets_list)
|
||||
sum(len(temporal_map) for temporal_map in temporal_maps)
|
||||
total_entities = len(all_entities_list)
|
||||
total_triplets = len(all_triplets_list)
|
||||
total_temporal = sum(len(temporal_map) for temporal_map in temporal_maps)
|
||||
|
||||
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
||||
logger.info("步骤 3/6: 生成实体嵌入")
|
||||
@@ -384,7 +385,7 @@ class ExtractionOrchestrator:
|
||||
|
||||
# 用于跟踪已完成的陈述句数量
|
||||
completed_statements = 0
|
||||
len(all_statements)
|
||||
total_statements = len(all_statements)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
async def extract_for_statement(stmt_data, stmt_index):
|
||||
@@ -496,7 +497,7 @@ class ExtractionOrchestrator:
|
||||
|
||||
# 用于跟踪已完成的时间提取数量
|
||||
completed_temporal = 0
|
||||
len(all_statements)
|
||||
total_temporal_statements = len(all_statements)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
async def extract_for_statement(stmt_data, stmt_index):
|
||||
@@ -1081,12 +1082,10 @@ class ExtractionOrchestrator:
|
||||
statement_id=statement.id, # 添加必需的 statement_id 字段
|
||||
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||
name_embedding=getattr(entity, 'name_embedding', None),
|
||||
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
@@ -29,118 +28,6 @@ class MemorySummaryResponse(RobustLLMResponse):
|
||||
)
|
||||
|
||||
|
||||
async def generate_title_and_type_for_summary(
|
||||
content: str,
|
||||
llm_client
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
为MemorySummary生成标题和类型
|
||||
|
||||
此方法应该在创建MemorySummary节点时调用,生成title和type
|
||||
|
||||
Args:
|
||||
content: Summary的内容文本
|
||||
llm_client: LLM客户端实例
|
||||
|
||||
Returns:
|
||||
(标题, 类型)元组
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
|
||||
|
||||
# 定义有效的类型集合
|
||||
VALID_TYPES = {
|
||||
"conversation", # 对话
|
||||
"project_work", # 项目/工作
|
||||
"learning", # 学习
|
||||
"decision", # 决策
|
||||
"important_event" # 重要事件
|
||||
}
|
||||
DEFAULT_TYPE = "conversation" # 默认类型
|
||||
|
||||
try:
|
||||
if not content:
|
||||
logger.warning("content为空,无法生成标题和类型")
|
||||
return ("空内容", DEFAULT_TYPE)
|
||||
|
||||
# 1. 渲染Jinja2提示词模板
|
||||
prompt = await render_episodic_title_and_type_prompt(content)
|
||||
|
||||
# 2. 调用LLM生成标题和类型
|
||||
messages = [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = await llm_client.chat(messages=messages)
|
||||
|
||||
# 3. 解析LLM响应
|
||||
content_response = response.content
|
||||
if isinstance(content_response, list):
|
||||
if len(content_response) > 0:
|
||||
if isinstance(content_response[0], dict):
|
||||
text = content_response[0].get('text', content_response[0].get('content', str(content_response[0])))
|
||||
full_response = str(text)
|
||||
else:
|
||||
full_response = str(content_response[0])
|
||||
else:
|
||||
full_response = ""
|
||||
elif isinstance(content_response, dict):
|
||||
full_response = str(content_response.get('text', content_response.get('content', str(content_response))))
|
||||
else:
|
||||
full_response = str(content_response) if content_response is not None else ""
|
||||
|
||||
# 4. 解析JSON响应
|
||||
try:
|
||||
# 尝试从响应中提取JSON
|
||||
# 移除可能的markdown代码块标记
|
||||
json_str = full_response.strip()
|
||||
if json_str.startswith("```json"):
|
||||
json_str = json_str[7:]
|
||||
if json_str.startswith("```"):
|
||||
json_str = json_str[3:]
|
||||
if json_str.endswith("```"):
|
||||
json_str = json_str[:-3]
|
||||
json_str = json_str.strip()
|
||||
|
||||
result_data = json.loads(json_str)
|
||||
title = result_data.get("title", "未知标题")
|
||||
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
|
||||
|
||||
# 5. 校验和归一化类型
|
||||
# 将类型转换为小写并去除空格
|
||||
episodic_type_normalized = str(episodic_type_raw).lower().strip()
|
||||
|
||||
# 检查是否在有效类型集合中
|
||||
if episodic_type_normalized in VALID_TYPES:
|
||||
episodic_type = episodic_type_normalized
|
||||
else:
|
||||
# 尝试映射常见的中文类型到英文
|
||||
type_mapping = {
|
||||
"对话": "conversation",
|
||||
"项目": "project_work",
|
||||
"工作": "project_work",
|
||||
"项目/工作": "project_work",
|
||||
"学习": "learning",
|
||||
"决策": "decision",
|
||||
"重要事件": "important_event",
|
||||
"事件": "important_event"
|
||||
}
|
||||
episodic_type = type_mapping.get(episodic_type_raw, DEFAULT_TYPE)
|
||||
logger.warning(
|
||||
f"LLM返回的类型 '{episodic_type_raw}' 不在有效集合中,"
|
||||
f"已归一化为 '{episodic_type}'"
|
||||
)
|
||||
|
||||
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
|
||||
return (title, episodic_type)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无法解析LLM响应为JSON: {full_response}")
|
||||
return ("解析失败", DEFAULT_TYPE)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
|
||||
return ("错误", DEFAULT_TYPE)
|
||||
|
||||
async def _process_chunk_summary(
|
||||
dialog: DialogData,
|
||||
chunk,
|
||||
@@ -172,27 +59,13 @@ async def _process_chunk_summary(
|
||||
)
|
||||
summary_text = structured.summary.strip()
|
||||
|
||||
# Generate title and type for the summary
|
||||
title = None
|
||||
episodic_type = None
|
||||
try:
|
||||
title, episodic_type = await generate_title_and_type_for_summary(
|
||||
content=summary_text,
|
||||
llm_client=llm_client
|
||||
)
|
||||
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
|
||||
# Continue without title and type
|
||||
|
||||
# Embed the summary
|
||||
embedding = (await embedder.response([summary_text]))[0]
|
||||
|
||||
# Build node per chunk
|
||||
# Note: title is stored in the 'name' field, type is stored in 'memory_type' field
|
||||
node = MemorySummaryNode(
|
||||
id=uuid4().hex,
|
||||
name=title if title else f"MemorySummaryChunk_{chunk.id}",
|
||||
name=f"MemorySummaryChunk_{chunk.id}",
|
||||
group_id=dialog.group_id,
|
||||
user_id=dialog.user_id,
|
||||
apply_id=dialog.apply_id,
|
||||
@@ -202,7 +75,6 @@ async def _process_chunk_summary(
|
||||
dialog_id=dialog.id,
|
||||
chunk_ids=[chunk.id],
|
||||
content=summary_text,
|
||||
memory_type=episodic_type,
|
||||
summary_embedding=embedding,
|
||||
metadata={"ref_id": dialog.ref_id},
|
||||
config_id=dialog.config_id, # 添加 config_id
|
||||
|
||||
@@ -1,40 +1,8 @@
|
||||
"""遗忘引擎模块
|
||||
|
||||
该模块实现记忆的遗忘机制,基于改进的艾宾浩斯遗忘曲线和 ACT-R 认知架构理论。
|
||||
该模块实现记忆的遗忘机制,基于改进的艾宾浩斯遗忘曲线。
|
||||
"""
|
||||
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
|
||||
ACTRCalculator,
|
||||
calculate_activation,
|
||||
generate_forgetting_curve
|
||||
)
|
||||
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
|
||||
AccessHistoryManager,
|
||||
ConsistencyCheckResult
|
||||
)
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import (
|
||||
ForgettingStrategy
|
||||
)
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_scheduler import (
|
||||
ForgettingScheduler
|
||||
)
|
||||
from app.core.memory.storage_services.forgetting_engine.config_utils import (
|
||||
calculate_forgetting_rate,
|
||||
load_actr_config_from_db,
|
||||
create_actr_calculator_from_config
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ForgettingEngine",
|
||||
"ACTRCalculator",
|
||||
"calculate_activation",
|
||||
"generate_forgetting_curve",
|
||||
"AccessHistoryManager",
|
||||
"ConsistencyCheckResult",
|
||||
"ForgettingStrategy",
|
||||
"ForgettingScheduler",
|
||||
"calculate_forgetting_rate",
|
||||
"load_actr_config_from_db",
|
||||
"create_actr_calculator_from_config"
|
||||
]
|
||||
__all__ = ["ForgettingEngine"]
|
||||
|
||||
@@ -1,732 +0,0 @@
|
||||
"""
|
||||
访问历史管理器模块
|
||||
|
||||
本模块实现访问历史的追踪、更新和一致性保证。
|
||||
负责在知识节点被访问时原子性地更新激活值相关的所有字段。
|
||||
|
||||
Classes:
|
||||
AccessHistoryManager: 访问历史管理器,提供并发安全的访问记录和一致性检查
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
|
||||
ACTRCalculator,
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConsistencyCheckResult(Enum):
|
||||
"""一致性检查结果枚举"""
|
||||
CONSISTENT = "consistent" # 数据一致
|
||||
INCONSISTENT_HISTORY_TIME = "inconsistent_history_time" # access_history[-1] != last_access_time
|
||||
INCONSISTENT_HISTORY_COUNT = "inconsistent_history_count" # len(access_history) != access_count
|
||||
MISSING_ACTIVATION = "missing_activation" # 有访问历史但无激活值
|
||||
INVALID_ACTIVATION_RANGE = "invalid_activation_range" # 激活值超出有效范围
|
||||
|
||||
|
||||
class AccessHistoryManager:
|
||||
"""
|
||||
访问历史管理器
|
||||
|
||||
负责追踪知识节点的访问历史,并在访问时原子性地更新所有相关字段:
|
||||
- activation_value: 激活值
|
||||
- access_history: 访问历史时间戳数组
|
||||
- last_access_time: 最后访问时间
|
||||
- access_count: 访问次数
|
||||
|
||||
特性:
|
||||
- 原子性更新:使用Neo4j事务确保所有字段同时更新或回滚
|
||||
- 并发安全:使用乐观锁机制防止并发冲突
|
||||
- 一致性保证:提供一致性检查和自动修复功能
|
||||
- 智能修剪:自动修剪过长的访问历史
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j连接器实例
|
||||
actr_calculator: ACT-R激活值计算器实例
|
||||
max_retries: 并发冲突时的最大重试次数
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
actr_calculator: ACTRCalculator,
|
||||
max_retries: int = 3
|
||||
):
|
||||
"""
|
||||
初始化访问历史管理器
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
actr_calculator: ACT-R激活值计算器实例
|
||||
max_retries: 并发冲突时的最大重试次数(默认3次)
|
||||
"""
|
||||
self.connector = connector
|
||||
self.actr_calculator = actr_calculator
|
||||
self.max_retries = max_retries
|
||||
|
||||
async def record_access(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
记录节点访问并原子性更新所有相关字段
|
||||
|
||||
这是核心方法,实现了:
|
||||
1. 首次访问:初始化access_history,计算初始激活值
|
||||
2. 后续访问:追加访问历史,重新计算激活值
|
||||
3. 历史修剪:当历史过长时自动修剪
|
||||
4. 原子性:所有字段在单个事务中更新
|
||||
5. 并发安全:使用乐观锁重试机制
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||
group_id: 组ID(可选,用于过滤)
|
||||
current_time: 当前时间(可选,默认使用系统时间)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新后的节点数据,包含:
|
||||
- id: 节点ID
|
||||
- activation_value: 更新后的激活值
|
||||
- access_history: 更新后的访问历史
|
||||
- last_access_time: 最后访问时间
|
||||
- access_count: 访问次数
|
||||
- importance_score: 重要性分数
|
||||
|
||||
Raises:
|
||||
ValueError: 如果节点不存在或节点标签无效
|
||||
RuntimeError: 如果重试次数耗尽仍然失败
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
current_time_iso = current_time.isoformat()
|
||||
|
||||
# 验证节点标签
|
||||
valid_labels = ["Statement", "ExtractedEntity", "MemorySummary"]
|
||||
if node_label not in valid_labels:
|
||||
raise ValueError(
|
||||
f"Invalid node_label: {node_label}. Must be one of {valid_labels}"
|
||||
)
|
||||
|
||||
# 使用乐观锁重试机制处理并发冲突
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# 步骤1:读取当前节点状态
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
|
||||
if not node_data:
|
||||
raise ValueError(
|
||||
f"Node not found: {node_label} with id={node_id}"
|
||||
)
|
||||
|
||||
# 步骤2:计算新的访问历史和激活值
|
||||
update_data = await self._calculate_update(
|
||||
node_data=node_data,
|
||||
current_time=current_time,
|
||||
current_time_iso=current_time_iso
|
||||
)
|
||||
|
||||
# 步骤3:原子性更新节点(使用事务)
|
||||
updated_node = await self._atomic_update(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"成功记录访问: {node_label}[{node_id}], "
|
||||
f"activation={update_data['activation_value']:.4f}, "
|
||||
f"access_count={update_data['access_count']}"
|
||||
)
|
||||
|
||||
return updated_node
|
||||
|
||||
except Exception as e:
|
||||
if attempt < self.max_retries - 1:
|
||||
logger.warning(
|
||||
f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], "
|
||||
f"错误: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to record access after {self.max_retries} attempts: {str(e)}"
|
||||
)
|
||||
|
||||
async def record_batch_access(
|
||||
self,
|
||||
node_ids: List[str],
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量记录多个节点的访问
|
||||
|
||||
为提高性能,批量更新多个节点的访问历史。
|
||||
每个节点独立更新,失败的节点不影响其他节点。
|
||||
|
||||
Args:
|
||||
node_ids: 节点ID列表
|
||||
node_label: 节点标签(所有节点必须是同一类型)
|
||||
group_id: 组ID(可选)
|
||||
current_time: 当前时间(可选)
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 成功更新的节点列表
|
||||
"""
|
||||
import time
|
||||
batch_start = time.time()
|
||||
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
# PERFORMANCE FIX: Process all nodes in parallel instead of sequentially
|
||||
tasks = []
|
||||
for node_id in node_ids:
|
||||
task = self.record_access(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id,
|
||||
current_time=current_time
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Execute all tasks in parallel
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Collect successful results and count failures
|
||||
results = []
|
||||
failed_count = 0
|
||||
|
||||
for node_id, result in zip(node_ids, task_results):
|
||||
if isinstance(result, Exception):
|
||||
failed_count += 1
|
||||
logger.warning(
|
||||
f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(result)}"
|
||||
)
|
||||
else:
|
||||
results.append(result)
|
||||
|
||||
batch_duration = time.time() - batch_start
|
||||
logger.info(
|
||||
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
|
||||
f"失败 {failed_count}, 耗时 {batch_duration:.4f}s"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def check_consistency(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
"""
|
||||
检查节点数据的一致性
|
||||
|
||||
验证以下一致性规则:
|
||||
1. access_history[-1] == last_access_time
|
||||
2. len(access_history) == access_count
|
||||
3. 如果有访问历史,必须有激活值
|
||||
4. 激活值必须在有效范围内 [offset, 1.0]
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
- 一致性检查结果枚举
|
||||
- 错误描述(如果不一致)
|
||||
"""
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
|
||||
if not node_data:
|
||||
return ConsistencyCheckResult.CONSISTENT, None
|
||||
|
||||
access_history = node_data.get('access_history') or []
|
||||
last_access_time = node_data.get('last_access_time')
|
||||
access_count = node_data.get('access_count', 0)
|
||||
activation_value = node_data.get('activation_value')
|
||||
|
||||
# 检查1:access_history[-1] == last_access_time
|
||||
if access_history and last_access_time:
|
||||
if access_history[-1] != last_access_time:
|
||||
return (
|
||||
ConsistencyCheckResult.INCONSISTENT_HISTORY_TIME,
|
||||
f"access_history[-1]={access_history[-1]} != "
|
||||
f"last_access_time={last_access_time}"
|
||||
)
|
||||
|
||||
# 检查2:len(access_history) == access_count
|
||||
if len(access_history) != access_count:
|
||||
return (
|
||||
ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT,
|
||||
f"len(access_history)={len(access_history)} != "
|
||||
f"access_count={access_count}"
|
||||
)
|
||||
|
||||
# 检查3:有访问历史必须有激活值
|
||||
if access_history and activation_value is None:
|
||||
return (
|
||||
ConsistencyCheckResult.MISSING_ACTIVATION,
|
||||
"Node has access_history but activation_value is None"
|
||||
)
|
||||
|
||||
# 检查4:激活值范围
|
||||
if activation_value is not None:
|
||||
offset = self.actr_calculator.offset
|
||||
if not (offset <= activation_value <= 1.0):
|
||||
return (
|
||||
ConsistencyCheckResult.INVALID_ACTIVATION_RANGE,
|
||||
f"activation_value={activation_value} out of range "
|
||||
f"[{offset}, 1.0]"
|
||||
)
|
||||
|
||||
return ConsistencyCheckResult.CONSISTENT, None
|
||||
|
||||
async def check_batch_consistency(
|
||||
self,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
limit: int = 1000
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
批量检查多个节点的一致性
|
||||
|
||||
Args:
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
limit: 检查的最大节点数
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 一致性检查报告,包含:
|
||||
- total_checked: 检查的节点总数
|
||||
- consistent_count: 一致的节点数
|
||||
- inconsistent_count: 不一致的节点数
|
||||
- inconsistencies: 不一致节点的详细信息列表
|
||||
- consistency_rate: 一致性率(0-1)
|
||||
"""
|
||||
# 查询所有相关节点
|
||||
query = f"""
|
||||
MATCH (n:{node_label})
|
||||
WHERE n.access_history IS NOT NULL
|
||||
"""
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
query += """
|
||||
RETURN n.id as id
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
params = {"limit": limit}
|
||||
if group_id:
|
||||
params["group_id"] = group_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
node_ids = [r['id'] for r in results]
|
||||
|
||||
# 检查每个节点
|
||||
inconsistencies = []
|
||||
consistent_count = 0
|
||||
|
||||
for node_id in node_ids:
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
if result == ConsistencyCheckResult.CONSISTENT:
|
||||
consistent_count += 1
|
||||
else:
|
||||
inconsistencies.append({
|
||||
'node_id': node_id,
|
||||
'result': result.value,
|
||||
'message': message
|
||||
})
|
||||
|
||||
total_checked = len(node_ids)
|
||||
inconsistent_count = len(inconsistencies)
|
||||
consistency_rate = consistent_count / total_checked if total_checked > 0 else 1.0
|
||||
|
||||
report = {
|
||||
'total_checked': total_checked,
|
||||
'consistent_count': consistent_count,
|
||||
'inconsistent_count': inconsistent_count,
|
||||
'inconsistencies': inconsistencies,
|
||||
'consistency_rate': consistency_rate
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"一致性检查完成: {node_label}, "
|
||||
f"一致率={consistency_rate:.2%}, "
|
||||
f"不一致节点={inconsistent_count}/{total_checked}"
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
async def repair_inconsistency(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
自动修复节点的数据不一致问题
|
||||
|
||||
修复策略:
|
||||
1. 如果access_history[-1] != last_access_time:使用access_history[-1]
|
||||
2. 如果len(access_history) != access_count:使用len(access_history)
|
||||
3. 如果有历史但无激活值:重新计算激活值
|
||||
4. 如果激活值超出范围:重新计算激活值
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
bool: 修复成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
# 检查一致性
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
if result == ConsistencyCheckResult.CONSISTENT:
|
||||
logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]")
|
||||
return True
|
||||
|
||||
# 获取节点数据
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
if not node_data:
|
||||
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
||||
return False
|
||||
|
||||
access_history = node_data.get('access_history') or []
|
||||
importance_score = node_data.get('importance_score', 0.5)
|
||||
|
||||
# 准备修复数据
|
||||
repair_data = {}
|
||||
|
||||
# 修复last_access_time
|
||||
if access_history:
|
||||
repair_data['last_access_time'] = access_history[-1]
|
||||
|
||||
# 修复access_count
|
||||
repair_data['access_count'] = len(access_history)
|
||||
|
||||
# 修复activation_value
|
||||
if access_history:
|
||||
current_time = datetime.now()
|
||||
last_access_dt = datetime.fromisoformat(access_history[-1])
|
||||
access_history_dt = [
|
||||
datetime.fromisoformat(ts) for ts in access_history
|
||||
]
|
||||
|
||||
activation_value = self.actr_calculator.calculate_memory_activation(
|
||||
access_history=access_history_dt,
|
||||
current_time=current_time,
|
||||
last_access_time=last_access_dt,
|
||||
importance_score=importance_score
|
||||
)
|
||||
repair_data['activation_value'] = activation_value
|
||||
|
||||
# 执行修复
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
query += " WHERE n.group_id = $group_id"
|
||||
query += """
|
||||
SET n += $repair_data
|
||||
RETURN n
|
||||
"""
|
||||
|
||||
params = {
|
||||
'node_id': node_id,
|
||||
'repair_data': repair_data
|
||||
}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
|
||||
await self.connector.execute_query(query, **params)
|
||||
|
||||
logger.info(
|
||||
f"成功修复节点不一致: {node_label}[{node_id}], "
|
||||
f"问题类型={result.value}"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"修复节点失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
)
|
||||
return False
|
||||
|
||||
# ==================== 私有辅助方法 ====================
|
||||
|
||||
async def _fetch_node(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取节点数据
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 节点数据,如果不存在返回None
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
query += " WHERE n.group_id = $group_id"
|
||||
query += """
|
||||
RETURN n.id as id,
|
||||
n.importance_score as importance_score,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count
|
||||
"""
|
||||
|
||||
params = {'node_id': node_id}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
if results:
|
||||
return results[0]
|
||||
return None
|
||||
|
||||
async def _calculate_update(
|
||||
self,
|
||||
node_data: Dict[str, Any],
|
||||
current_time: datetime,
|
||||
current_time_iso: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算更新数据
|
||||
|
||||
Args:
|
||||
node_data: 当前节点数据
|
||||
current_time: 当前时间(datetime对象)
|
||||
current_time_iso: 当前时间(ISO格式字符串)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新数据,包含所有需要更新的字段
|
||||
"""
|
||||
access_history = node_data.get('access_history') or []
|
||||
# Handle None importance_score - default to 0.5
|
||||
importance_score = node_data.get('importance_score')
|
||||
if importance_score is None:
|
||||
importance_score = 0.5
|
||||
|
||||
# 追加新的访问时间
|
||||
new_access_history = access_history + [current_time_iso]
|
||||
|
||||
# 修剪访问历史(如果过长)
|
||||
access_history_dt = [
|
||||
datetime.fromisoformat(ts) for ts in new_access_history
|
||||
]
|
||||
trimmed_history_dt = self.actr_calculator.trim_access_history(
|
||||
access_history=access_history_dt,
|
||||
current_time=current_time
|
||||
)
|
||||
trimmed_history = [ts.isoformat() for ts in trimmed_history_dt]
|
||||
|
||||
# 计算新的激活值
|
||||
activation_value = self.actr_calculator.calculate_memory_activation(
|
||||
access_history=trimmed_history_dt,
|
||||
current_time=current_time,
|
||||
last_access_time=current_time, # 最后访问时间就是当前时间
|
||||
importance_score=importance_score
|
||||
)
|
||||
|
||||
# 返回所有需要更新的字段
|
||||
return {
|
||||
'activation_value': activation_value,
|
||||
'access_history': trimmed_history,
|
||||
'last_access_time': current_time_iso,
|
||||
'access_count': len(trimmed_history)
|
||||
}
|
||||
|
||||
async def _atomic_update(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
update_data: Dict[str, Any],
|
||||
group_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
原子性更新节点(使用乐观锁)
|
||||
|
||||
使用Neo4j事务和版本号确保所有字段同时更新或回滚。
|
||||
实现乐观锁机制防止并发冲突。
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
update_data: 更新数据
|
||||
group_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新后的节点数据
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果更新失败或发生版本冲突
|
||||
"""
|
||||
# 定义事务函数
|
||||
async def update_transaction(tx, node_id, node_label, update_data, group_id):
|
||||
# 步骤1:读取当前节点并获取版本号
|
||||
read_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
read_query += " WHERE n.group_id = $group_id"
|
||||
read_query += """
|
||||
RETURN n.id as id,
|
||||
n.version as version,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score
|
||||
"""
|
||||
|
||||
read_params = {'node_id': node_id}
|
||||
if group_id:
|
||||
read_params['group_id'] = group_id
|
||||
|
||||
read_result = await tx.run(read_query, **read_params)
|
||||
current_node = await read_result.single()
|
||||
|
||||
if not current_node:
|
||||
raise RuntimeError(f"Node not found: {node_label}[{node_id}]")
|
||||
|
||||
# 获取当前版本号(如果不存在则为0)
|
||||
current_version = current_node.get('version', 0) or 0
|
||||
new_version = current_version + 1
|
||||
|
||||
# 步骤2:使用乐观锁更新节点
|
||||
# 根据节点类型构建完整的查询语句
|
||||
content_field_map = {
|
||||
'Statement': 'n.statement as statement',
|
||||
'MemorySummary': 'n.content as content',
|
||||
'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤
|
||||
}
|
||||
|
||||
# 显式检查节点类型,不支持的类型抛出错误
|
||||
if node_label not in content_field_map:
|
||||
raise ValueError(
|
||||
f"Unsupported node_label: {node_label}. "
|
||||
f"Supported labels are: {list(content_field_map.keys())}"
|
||||
)
|
||||
|
||||
content_field = content_field_map[node_label]
|
||||
|
||||
# 构建 WHERE 子句
|
||||
where_conditions = []
|
||||
if group_id:
|
||||
where_conditions.append("n.group_id = $group_id")
|
||||
|
||||
# 添加版本检查
|
||||
if current_version > 0:
|
||||
where_conditions.append("n.version = $current_version")
|
||||
else:
|
||||
where_conditions.append("(n.version IS NULL OR n.version = 0)")
|
||||
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "true"
|
||||
|
||||
# 构建完整的更新查询
|
||||
update_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
WHERE {where_clause}
|
||||
SET n.activation_value = $activation_value,
|
||||
n.access_history = $access_history,
|
||||
n.last_access_time = $last_access_time,
|
||||
n.access_count = $access_count,
|
||||
n.version = $new_version
|
||||
RETURN n.id as id,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score,
|
||||
n.version as version,
|
||||
{content_field}
|
||||
"""
|
||||
|
||||
update_params = {
|
||||
'node_id': node_id,
|
||||
'current_version': current_version,
|
||||
'new_version': new_version,
|
||||
'activation_value': update_data['activation_value'],
|
||||
'access_history': update_data['access_history'],
|
||||
'last_access_time': update_data['last_access_time'],
|
||||
'access_count': update_data['access_count']
|
||||
}
|
||||
if group_id:
|
||||
update_params['group_id'] = group_id
|
||||
|
||||
update_result = await tx.run(update_query, **update_params)
|
||||
updated_node = await update_result.single()
|
||||
|
||||
if not updated_node:
|
||||
raise RuntimeError(
|
||||
f"Version conflict detected for {node_label}[{node_id}]. "
|
||||
f"Expected version {current_version}, but node was modified by another transaction."
|
||||
)
|
||||
|
||||
# 转换为字典并移除占位符字段
|
||||
result_dict = dict(updated_node)
|
||||
result_dict.pop('content_placeholder', None)
|
||||
|
||||
return result_dict
|
||||
|
||||
# 执行事务
|
||||
try:
|
||||
result = await self.connector.execute_write_transaction(
|
||||
update_transaction,
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
group_id=group_id
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"原子性更新失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to atomically update node: {str(e)}"
|
||||
) from e
|
||||
@@ -1,359 +0,0 @@
|
||||
"""
|
||||
ACT-R Memory Activation Calculator
|
||||
|
||||
This module implements the unified Memory Activation model based on ACT-R
|
||||
(Adaptive Control of Thought-Rational) cognitive architecture theory.
|
||||
|
||||
The calculator integrates BLA (Base-Level Activation) computation into the
|
||||
Memory Activation formula, providing a single coherent model for memory strength
|
||||
calculation that reflects both recency and frequency of access.
|
||||
|
||||
Formula: R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d)))
|
||||
|
||||
Where:
|
||||
- R(i): Memory activation value (0 to 1)
|
||||
- offset: Minimum retention rate (prevents complete forgetting)
|
||||
- λ: Forgetting rate (lambda_time / lambda_mem)
|
||||
- t: Time since last access
|
||||
- I: Importance score (0 to 1)
|
||||
- t_k: Time since k-th access
|
||||
- d: Decay constant (typically 0.5)
|
||||
|
||||
Reference: Anderson, J. R. (2007). How Can the Human Mind Occur in the Physical Universe?
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class ACTRCalculator:
|
||||
"""
|
||||
Unified ACT-R Memory Activation Calculator.
|
||||
|
||||
This calculator implements the Memory Activation model that combines
|
||||
recency and frequency effects into a single activation value computation.
|
||||
It replaces the separate BLA calculation with an integrated approach.
|
||||
|
||||
Attributes:
|
||||
decay_constant: Decay parameter d (typically 0.5)
|
||||
forgetting_rate: Lambda parameter λ controlling forgetting speed
|
||||
offset: Minimum retention rate (baseline memory strength)
|
||||
max_history_length: Maximum number of access records to keep
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decay_constant: float = 0.5,
|
||||
forgetting_rate: float = 0.3,
|
||||
offset: float = 0.1,
|
||||
max_history_length: int = 100
|
||||
):
|
||||
"""
|
||||
Initialize the ACT-R calculator.
|
||||
|
||||
Args:
|
||||
decay_constant: Decay parameter d (default 0.5)
|
||||
forgetting_rate: Forgetting rate λ (default 0.3)
|
||||
offset: Minimum retention rate (default 0.1)
|
||||
max_history_length: Maximum access history length (default 100)
|
||||
"""
|
||||
self.decay_constant = decay_constant
|
||||
self.forgetting_rate = forgetting_rate
|
||||
self.offset = offset
|
||||
self.max_history_length = max_history_length
|
||||
|
||||
def calculate_memory_activation(
|
||||
self,
|
||||
access_history: List[datetime],
|
||||
current_time: datetime,
|
||||
last_access_time: datetime,
|
||||
importance_score: float = 0.5
|
||||
) -> float:
|
||||
"""
|
||||
Calculate memory activation value using the unified Memory Activation formula.
|
||||
|
||||
This method computes R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d)))
|
||||
|
||||
The formula integrates:
|
||||
- Recency effect: Recent accesses contribute more (via t)
|
||||
- Frequency effect: Multiple accesses strengthen memory (via Σ)
|
||||
- Importance weighting: Important memories decay slower (via I)
|
||||
|
||||
Args:
|
||||
access_history: List of access timestamps (ISO format or datetime objects)
|
||||
current_time: Current time for calculation
|
||||
last_access_time: Time of most recent access
|
||||
importance_score: Importance weight (0 to 1, default 0.5)
|
||||
|
||||
Returns:
|
||||
float: Memory activation value between offset and 1.0
|
||||
|
||||
Raises:
|
||||
ValueError: If access_history is empty or contains invalid data
|
||||
"""
|
||||
if not access_history:
|
||||
raise ValueError("access_history cannot be empty")
|
||||
|
||||
if not (0.0 <= importance_score <= 1.0):
|
||||
raise ValueError(f"importance_score must be between 0 and 1, got {importance_score}")
|
||||
|
||||
# Calculate time since last access (in days)
|
||||
time_since_last = (current_time - last_access_time).total_seconds() / 86400.0
|
||||
time_since_last = max(time_since_last, 0.0001) # Avoid division by zero
|
||||
|
||||
# Calculate BLA component: Σ(I·t_k^(-d))
|
||||
bla_sum = 0.0
|
||||
for access_time in access_history:
|
||||
# Calculate time since this access (in days)
|
||||
time_diff = (current_time - access_time).total_seconds() / 86400.0
|
||||
time_diff = max(time_diff, 0.0001) # Avoid division by zero
|
||||
|
||||
# Add weighted power-law term: I * t_k^(-d)
|
||||
bla_sum += importance_score * (time_diff ** (-self.decay_constant))
|
||||
|
||||
# Avoid division by zero in case of numerical issues
|
||||
if bla_sum <= 0:
|
||||
bla_sum = 0.0001
|
||||
|
||||
# Calculate Memory Activation: R(i) = offset + (1-offset) * exp(-λ*t / BLA)
|
||||
exponent = -self.forgetting_rate * time_since_last / bla_sum
|
||||
|
||||
# Clamp exponent to avoid numerical overflow/underflow
|
||||
exponent = max(min(exponent, 100), -100)
|
||||
|
||||
activation = self.offset + (1 - self.offset) * math.exp(exponent)
|
||||
|
||||
# Ensure activation is within valid range [offset, 1.0]
|
||||
return max(self.offset, min(1.0, activation))
|
||||
|
||||
def trim_access_history(
|
||||
self,
|
||||
access_history: List[datetime],
|
||||
current_time: datetime
|
||||
) -> List[datetime]:
|
||||
"""
|
||||
Intelligently trim access history to prevent unbounded growth.
|
||||
|
||||
Strategy:
|
||||
- Keep all records if under max_history_length
|
||||
- If over limit, keep most recent 50% and sample from older records
|
||||
- Preserves both recent accesses (high importance) and historical pattern
|
||||
|
||||
Args:
|
||||
access_history: List of access timestamps (sorted or unsorted)
|
||||
current_time: Current time for calculation
|
||||
|
||||
Returns:
|
||||
List[datetime]: Trimmed access history
|
||||
"""
|
||||
if len(access_history) <= self.max_history_length:
|
||||
return access_history
|
||||
|
||||
# Sort by time (most recent first)
|
||||
sorted_history = sorted(access_history, reverse=True)
|
||||
|
||||
# Calculate split point (keep most recent 50%)
|
||||
keep_recent_count = self.max_history_length // 2
|
||||
|
||||
# Keep most recent 50%
|
||||
recent_records = sorted_history[:keep_recent_count]
|
||||
|
||||
# Sample from older records
|
||||
older_records = sorted_history[keep_recent_count:]
|
||||
sample_count = self.max_history_length - keep_recent_count
|
||||
|
||||
if len(older_records) <= sample_count:
|
||||
# If older records fit, keep them all
|
||||
sampled_older = older_records
|
||||
else:
|
||||
# Sample evenly from older records
|
||||
step = len(older_records) / sample_count
|
||||
sampled_older = [
|
||||
older_records[int(i * step)]
|
||||
for i in range(sample_count)
|
||||
]
|
||||
|
||||
# Combine and return
|
||||
trimmed_history = recent_records + sampled_older
|
||||
return sorted(trimmed_history, reverse=True)
|
||||
|
||||
def get_forgetting_curve( # 预测激活值,决定复习;测试不同配置效果,选择合适的d
|
||||
self,
|
||||
initial_time: datetime,
|
||||
importance_score: float = 0.5,
|
||||
days: int = 60
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generate forgetting curve data for visualization.
|
||||
|
||||
This method simulates how memory activation decays over time
|
||||
for a single initial access, useful for understanding and
|
||||
visualizing the forgetting behavior.
|
||||
|
||||
Args:
|
||||
initial_time: Time of initial memory creation/access
|
||||
importance_score: Importance weight (0 to 1, default 0.5)
|
||||
days: Number of days to simulate (default 60)
|
||||
|
||||
Returns:
|
||||
List of dictionaries with keys:
|
||||
- 'day': Day number (0 to days)
|
||||
- 'activation': Memory activation value
|
||||
- 'retention_rate': Same as activation (for compatibility)
|
||||
"""
|
||||
curve_data = []
|
||||
access_history = [initial_time]
|
||||
|
||||
for day in range(days + 1):
|
||||
current_time = initial_time + timedelta(days=day)
|
||||
|
||||
try:
|
||||
activation = self.calculate_memory_activation(
|
||||
access_history=access_history,
|
||||
current_time=current_time,
|
||||
last_access_time=initial_time,
|
||||
importance_score=importance_score
|
||||
)
|
||||
except ValueError:
|
||||
# Handle edge cases
|
||||
activation = self.offset
|
||||
|
||||
curve_data.append({
|
||||
'day': day,
|
||||
'activation': activation,
|
||||
'retention_rate': activation # Alias for compatibility
|
||||
})
|
||||
|
||||
return curve_data
|
||||
|
||||
def calculate_forgetting_score(
|
||||
self,
|
||||
access_history: List[datetime],
|
||||
current_time: datetime,
|
||||
last_access_time: datetime,
|
||||
importance_score: float = 0.5
|
||||
) -> float:
|
||||
"""
|
||||
Calculate forgetting score (inverse of activation).
|
||||
|
||||
Forgetting score = 1 - activation value
|
||||
Higher score means more likely to be forgotten.
|
||||
|
||||
Args:
|
||||
access_history: List of access timestamps
|
||||
current_time: Current time for calculation
|
||||
last_access_time: Time of most recent access
|
||||
importance_score: Importance weight (0 to 1, default 0.5)
|
||||
|
||||
Returns:
|
||||
float: Forgetting score between 0 and (1 - offset)
|
||||
"""
|
||||
activation = self.calculate_memory_activation(
|
||||
access_history=access_history,
|
||||
current_time=current_time,
|
||||
last_access_time=last_access_time,
|
||||
importance_score=importance_score
|
||||
)
|
||||
return 1.0 - activation
|
||||
|
||||
def should_forget(
|
||||
self,
|
||||
access_history: List[datetime],
|
||||
current_time: datetime,
|
||||
last_access_time: datetime,
|
||||
importance_score: float = 0.5,
|
||||
threshold: float = 0.3
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if a memory should be forgotten based on activation threshold.
|
||||
|
||||
Args:
|
||||
access_history: List of access timestamps
|
||||
current_time: Current time for calculation
|
||||
last_access_time: Time of most recent access
|
||||
importance_score: Importance weight (0 to 1, default 0.5)
|
||||
threshold: Activation threshold below which memory should be forgotten
|
||||
|
||||
Returns:
|
||||
bool: True if activation < threshold (should forget), False otherwise
|
||||
"""
|
||||
activation = self.calculate_memory_activation(
|
||||
access_history=access_history,
|
||||
current_time=current_time,
|
||||
last_access_time=last_access_time,
|
||||
importance_score=importance_score
|
||||
)
|
||||
return activation < threshold
|
||||
|
||||
|
||||
# Convenience functions for quick calculations
|
||||
def calculate_activation(
|
||||
access_history: List[datetime],
|
||||
current_time: datetime,
|
||||
last_access_time: datetime,
|
||||
importance_score: float = 0.5,
|
||||
decay_constant: float = 0.5,
|
||||
forgetting_rate: float = 0.3,
|
||||
offset: float = 0.1
|
||||
) -> float:
|
||||
"""
|
||||
Quick function to calculate activation without creating a calculator instance.
|
||||
|
||||
Args:
|
||||
access_history: List of access timestamps
|
||||
current_time: Current time for calculation
|
||||
last_access_time: Time of most recent access
|
||||
importance_score: Importance weight (0 to 1, default 0.5)
|
||||
decay_constant: Decay parameter d (default 0.5)
|
||||
forgetting_rate: Forgetting rate λ (default 0.3)
|
||||
offset: Minimum retention rate (default 0.1)
|
||||
|
||||
Returns:
|
||||
float: Memory activation value between offset and 1.0
|
||||
"""
|
||||
calculator = ACTRCalculator(
|
||||
decay_constant=decay_constant,
|
||||
forgetting_rate=forgetting_rate,
|
||||
offset=offset
|
||||
)
|
||||
return calculator.calculate_memory_activation(
|
||||
access_history=access_history,
|
||||
current_time=current_time,
|
||||
last_access_time=last_access_time,
|
||||
importance_score=importance_score
|
||||
)
|
||||
|
||||
|
||||
def generate_forgetting_curve(
|
||||
initial_time: datetime,
|
||||
importance_score: float = 0.5,
|
||||
days: int = 60,
|
||||
decay_constant: float = 0.5,
|
||||
forgetting_rate: float = 0.3,
|
||||
offset: float = 0.1
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Quick function to generate forgetting curve data.
|
||||
|
||||
Args:
|
||||
initial_time: Time of initial memory creation/access
|
||||
importance_score: Importance weight (0 to 1, default 0.5)
|
||||
days: Number of days to simulate (default 60)
|
||||
decay_constant: Decay parameter d (default 0.5)
|
||||
forgetting_rate: Forgetting rate λ (default 0.3)
|
||||
offset: Minimum retention rate (default 0.1)
|
||||
|
||||
Returns:
|
||||
List of dictionaries with forgetting curve data
|
||||
"""
|
||||
calculator = ACTRCalculator(
|
||||
decay_constant=decay_constant,
|
||||
forgetting_rate=forgetting_rate,
|
||||
offset=offset
|
||||
)
|
||||
return calculator.get_forgetting_curve(
|
||||
initial_time=initial_time,
|
||||
importance_score=importance_score,
|
||||
days=days
|
||||
)
|
||||
@@ -1,195 +0,0 @@
|
||||
"""
|
||||
遗忘引擎配置工具模块
|
||||
|
||||
本模块提供从数据库加载配置并创建遗忘引擎组件的辅助函数。
|
||||
|
||||
Functions:
|
||||
calculate_forgetting_rate: 计算遗忘速率(lambda_time / lambda_mem)
|
||||
load_actr_config_from_db: 从数据库加载 ACT-R 配置参数
|
||||
create_actr_calculator_from_config: 从配置创建 ACTRCalculator 实例
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
|
||||
"""
|
||||
计算遗忘速率
|
||||
|
||||
公式:forgetting_rate = lambda_time / lambda_mem
|
||||
|
||||
这个计算将两个独立的 lambda 参数组合成一个统一的遗忘速率参数,
|
||||
用于 ACT-R 激活值计算。
|
||||
|
||||
Args:
|
||||
lambda_time: 时间衰减参数(0-1)
|
||||
lambda_mem: 记忆衰减参数(0-1)
|
||||
|
||||
Returns:
|
||||
float: 遗忘速率
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 lambda_mem 为 0
|
||||
|
||||
Examples:
|
||||
>>> calculate_forgetting_rate(0.5, 0.5)
|
||||
1.0
|
||||
>>> calculate_forgetting_rate(0.3, 0.5)
|
||||
0.6
|
||||
"""
|
||||
if lambda_mem == 0:
|
||||
raise ValueError("lambda_mem 不能为 0")
|
||||
|
||||
forgetting_rate = lambda_time / lambda_mem
|
||||
|
||||
logger.debug(
|
||||
f"计算遗忘速率: lambda_time={lambda_time}, "
|
||||
f"lambda_mem={lambda_mem}, "
|
||||
f"forgetting_rate={forgetting_rate:.4f}"
|
||||
)
|
||||
|
||||
return forgetting_rate
|
||||
|
||||
|
||||
def load_actr_config_from_db(
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从数据库加载 ACT-R 配置参数
|
||||
|
||||
从 PostgreSQL 的 data_config 表读取配置参数,
|
||||
并计算派生参数(如 forgetting_rate)。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置 ID(可选,如果为 None 则使用默认值)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 配置参数字典,包含:
|
||||
- decay_constant: 衰减常数 d
|
||||
- lambda_time: 时间衰减参数
|
||||
- lambda_mem: 记忆衰减参数
|
||||
- forgetting_rate: 遗忘速率(根据 lambda_time / lambda_mem 计算得出)
|
||||
- offset: 偏移量
|
||||
- max_history_length: 访问历史最大长度
|
||||
- forgetting_threshold: 遗忘阈值
|
||||
- min_days_since_access: 最小未访问天数
|
||||
- enable_llm_summary: 是否使用 LLM 生成摘要
|
||||
- max_merge_batch_size: 单次最大融合节点对数
|
||||
- forgetting_interval_hours: 遗忘周期间隔
|
||||
|
||||
注意:llm_id 不包含在返回的配置中,需要时由 forgetting_strategy 直接从数据库读取
|
||||
|
||||
Raises:
|
||||
ValueError: 如果指定的 config_id 不存在
|
||||
"""
|
||||
# 必须指定 config_id
|
||||
if config_id is None:
|
||||
logger.error("未指定 config_id,无法加载配置")
|
||||
raise ValueError("config_id 不能为空,必须指定一个有效的配置 ID")
|
||||
|
||||
# 从数据库加载配置
|
||||
try:
|
||||
repository = DataConfigRepository()
|
||||
db_config = repository.get_by_id(db, config_id)
|
||||
|
||||
if db_config is None:
|
||||
logger.error(f"配置不存在: config_id={config_id}")
|
||||
raise ValueError(f"配置不存在: config_id={config_id}")
|
||||
|
||||
# 读取配置参数(信任数据库默认值)
|
||||
lambda_time = db_config.lambda_time
|
||||
lambda_mem = db_config.lambda_mem
|
||||
decay_constant = db_config.decay_constant
|
||||
offset = db_config.offset
|
||||
max_history_length = db_config.max_history_length
|
||||
forgetting_threshold = db_config.forgetting_threshold
|
||||
min_days_since_access = db_config.min_days_since_access
|
||||
enable_llm_summary = db_config.enable_llm_summary
|
||||
max_merge_batch_size = db_config.max_merge_batch_size
|
||||
forgetting_interval_hours = db_config.forgetting_interval_hours
|
||||
|
||||
# 计算 forgetting_rate
|
||||
forgetting_rate = calculate_forgetting_rate(lambda_time, lambda_mem)
|
||||
|
||||
config = {
|
||||
'decay_constant': decay_constant,
|
||||
'lambda_time': lambda_time,
|
||||
'lambda_mem': lambda_mem,
|
||||
'forgetting_rate': forgetting_rate,
|
||||
'offset': offset,
|
||||
'max_history_length': max_history_length,
|
||||
'forgetting_threshold': forgetting_threshold,
|
||||
'min_days_since_access': min_days_since_access,
|
||||
'enable_llm_summary': enable_llm_summary,
|
||||
'max_merge_batch_size': max_merge_batch_size,
|
||||
'forgetting_interval_hours': forgetting_interval_hours
|
||||
# 注意:llm_id 不包含在配置响应中,仅在内部使用
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"成功加载 ACT-R 配置: config_id={config_id}, "
|
||||
f"forgetting_rate={forgetting_rate:.4f}"
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载 ACT-R 配置失败: config_id={config_id}, 错误: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_actr_calculator_from_config(
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
) -> ACTRCalculator:
|
||||
"""
|
||||
从数据库配置创建 ACTRCalculator 实例
|
||||
|
||||
这是创建 ACTRCalculator 的推荐方式,确保使用数据库中的配置参数。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置 ID(可选,如果为 None 则使用默认值)
|
||||
|
||||
Returns:
|
||||
ACTRCalculator: 配置好的 ACT-R 计算器实例
|
||||
|
||||
Raises:
|
||||
ValueError: 如果指定的 config_id 不存在
|
||||
|
||||
Examples:
|
||||
>>> from sqlalchemy.orm import Session
|
||||
>>> db = Session()
|
||||
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
|
||||
>>> # 使用计算器
|
||||
>>> activation = calculator.calculate_memory_activation(...)
|
||||
"""
|
||||
# 加载配置
|
||||
config = load_actr_config_from_db(db, config_id)
|
||||
|
||||
# 创建计算器
|
||||
calculator = ACTRCalculator(
|
||||
decay_constant=config['decay_constant'],
|
||||
forgetting_rate=config['forgetting_rate'],
|
||||
offset=config['offset'],
|
||||
max_history_length=config['max_history_length']
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"创建 ACTRCalculator: config_id={config_id}, "
|
||||
f"decay_constant={config['decay_constant']}, "
|
||||
f"forgetting_rate={config['forgetting_rate']:.4f}, "
|
||||
f"offset={config['offset']}"
|
||||
)
|
||||
|
||||
return calculator
|
||||
@@ -1,351 +0,0 @@
|
||||
"""
|
||||
遗忘调度器模块
|
||||
|
||||
本模块实现遗忘周期的调度和管理,负责:
|
||||
1. 手动触发遗忘周期
|
||||
2. 批量处理可遗忘节点(限制批量大小)
|
||||
3. 按激活值优先级排序(激活值最低的优先)
|
||||
4. 进度跟踪和日志记录
|
||||
5. 生成遗忘报告
|
||||
|
||||
注意:定期调度功能已迁移到 Celery Beat,见 app/tasks.py 中的 run_forgetting_cycle_task
|
||||
|
||||
Classes:
|
||||
ForgettingScheduler: 遗忘调度器,提供遗忘周期管理功能
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ForgettingScheduler:
|
||||
"""
|
||||
遗忘调度器
|
||||
|
||||
管理遗忘周期的执行,实现批量处理、优先级排序和进度跟踪功能。
|
||||
|
||||
核心功能:
|
||||
1. 运行遗忘周期:识别可遗忘节点并批量融合
|
||||
2. 优先级排序:优先处理激活值最低的节点对
|
||||
3. 批量限制:限制单次处理的节点对数量
|
||||
4. 进度跟踪:每完成 10% 记录一次日志
|
||||
5. 遗忘报告:生成详细的执行报告
|
||||
|
||||
注意:定期调度功能已迁移到 Celery Beat 定时任务
|
||||
|
||||
Attributes:
|
||||
forgetting_strategy: 遗忘策略执行器实例
|
||||
connector: Neo4j 连接器实例
|
||||
is_running: 是否正在运行遗忘周期
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
forgetting_strategy: ForgettingStrategy,
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""
|
||||
初始化遗忘调度器
|
||||
|
||||
Args:
|
||||
forgetting_strategy: 遗忘策略执行器实例
|
||||
connector: Neo4j 连接器实例
|
||||
"""
|
||||
self.forgetting_strategy = forgetting_strategy
|
||||
self.connector = connector
|
||||
self.is_running = False
|
||||
|
||||
logger.info("初始化遗忘调度器")
|
||||
|
||||
async def run_forgetting_cycle(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
max_merge_batch_size: int = 100,
|
||||
min_days_since_access: int = 30,
|
||||
config_id: Optional[int] = None,
|
||||
db = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
运行一次完整的遗忘周期
|
||||
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
max_merge_batch_size: 单次最大融合节点对数(默认 100)
|
||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||
config_id: 配置ID(可选,用于获取 llm_id)
|
||||
db: 数据库会话(可选,用于获取 llm_id)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 遗忘报告,包含:
|
||||
- merged_count: 融合的节点对数量
|
||||
- nodes_before: 遗忘前的节点总数
|
||||
- nodes_after: 遗忘后的节点总数
|
||||
- reduction_rate: 节点减少率(0-1)
|
||||
- duration_seconds: 执行耗时(秒)
|
||||
- start_time: 开始时间(ISO 格式)
|
||||
- end_time: 结束时间(ISO 格式)
|
||||
- failed_count: 失败的融合数量
|
||||
- success_rate: 成功率(0-1)
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果已有遗忘周期正在运行
|
||||
"""
|
||||
# 检查是否已有遗忘周期在运行
|
||||
if self.is_running:
|
||||
raise RuntimeError("遗忘周期已在运行中,请等待当前周期完成")
|
||||
|
||||
self.is_running = True
|
||||
start_time = datetime.now()
|
||||
start_time_iso = start_time.isoformat()
|
||||
|
||||
logger.info(
|
||||
f"开始遗忘周期: group_id={group_id}, "
|
||||
f"max_batch={max_merge_batch_size}, "
|
||||
f"min_days={min_days_since_access}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 步骤1:统计遗忘前的节点数量
|
||||
nodes_before = await self._count_knowledge_nodes(group_id)
|
||||
logger.info(f"遗忘前节点总数: {nodes_before}")
|
||||
|
||||
# 步骤2:识别可遗忘的节点对
|
||||
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
|
||||
group_id=group_id,
|
||||
min_days_since_access=min_days_since_access
|
||||
)
|
||||
|
||||
total_forgettable = len(forgettable_pairs)
|
||||
logger.info(f"识别到 {total_forgettable} 个可遗忘节点对")
|
||||
|
||||
if total_forgettable == 0:
|
||||
logger.info("没有可遗忘的节点对,遗忘周期结束")
|
||||
end_time = datetime.now()
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
|
||||
report = {
|
||||
'merged_count': 0,
|
||||
'nodes_before': nodes_before,
|
||||
'nodes_after': nodes_before,
|
||||
'reduction_rate': 0.0,
|
||||
'duration_seconds': duration,
|
||||
'start_time': start_time_iso,
|
||||
'end_time': end_time.isoformat(),
|
||||
'failed_count': 0,
|
||||
'success_rate': 1.0
|
||||
}
|
||||
|
||||
logger.info("没有可遗忘的节点对,遗忘周期结束")
|
||||
|
||||
return report
|
||||
|
||||
# 步骤3:按激活值排序(激活值最低的优先)
|
||||
# avg_activation 已经在 find_forgettable_nodes 中计算并排序
|
||||
# 这里只需要确认排序是正确的(升序)
|
||||
sorted_pairs = sorted(
|
||||
forgettable_pairs,
|
||||
key=lambda x: x['avg_activation']
|
||||
)
|
||||
|
||||
# 步骤4:限制批量大小
|
||||
pairs_to_process = sorted_pairs[:max_merge_batch_size]
|
||||
actual_batch_size = len(pairs_to_process)
|
||||
|
||||
logger.info(
|
||||
f"将处理 {actual_batch_size} 个节点对 "
|
||||
f"(限制: {max_merge_batch_size})"
|
||||
)
|
||||
|
||||
# 步骤5:批量融合节点,每 10% 记录进度
|
||||
merged_count = 0
|
||||
failed_count = 0
|
||||
skipped_count = 0 # 跳过的节点对数量(节点已被处理)
|
||||
progress_interval = max(1, actual_batch_size // 10) # 每 10% 记录一次
|
||||
|
||||
# 跟踪已处理的节点 ID,避免重复处理
|
||||
processed_statement_ids = set()
|
||||
processed_entity_ids = set()
|
||||
|
||||
# 预先过滤掉重复的节点对
|
||||
unique_pairs = []
|
||||
for pair in pairs_to_process:
|
||||
statement_id = pair['statement_id']
|
||||
entity_id = pair['entity_id']
|
||||
|
||||
# 如果节点已被标记为处理,跳过
|
||||
if statement_id in processed_statement_ids or entity_id in processed_entity_ids:
|
||||
skipped_count += 1
|
||||
logger.debug(
|
||||
f"预过滤:跳过重复节点对 Statement[{statement_id}] + Entity[{entity_id}]"
|
||||
)
|
||||
continue
|
||||
|
||||
# 标记节点为已处理
|
||||
processed_statement_ids.add(statement_id)
|
||||
processed_entity_ids.add(entity_id)
|
||||
unique_pairs.append(pair)
|
||||
|
||||
logger.info(
|
||||
f"预过滤完成:原始 {actual_batch_size} 对,去重后 {len(unique_pairs)} 对,"
|
||||
f"跳过 {skipped_count} 对重复节点"
|
||||
)
|
||||
|
||||
# 更新实际处理的批次大小
|
||||
actual_batch_size = len(unique_pairs)
|
||||
progress_interval = max(1, actual_batch_size // 10) # 重新计算进度间隔
|
||||
|
||||
for idx, pair in enumerate(unique_pairs, start=1):
|
||||
statement_id = pair['statement_id']
|
||||
entity_id = pair['entity_id']
|
||||
|
||||
try:
|
||||
# 准备节点数据
|
||||
statement_node = {
|
||||
'statement_id': statement_id,
|
||||
'statement_text': pair['statement_text'],
|
||||
'statement_activation': pair['statement_activation'],
|
||||
'statement_importance': pair['statement_importance'],
|
||||
'group_id': group_id
|
||||
}
|
||||
|
||||
entity_node = {
|
||||
'entity_id': entity_id,
|
||||
'entity_name': pair['entity_name'],
|
||||
'entity_type': pair['entity_type'],
|
||||
'entity_activation': pair['entity_activation'],
|
||||
'entity_importance': pair['entity_importance'],
|
||||
'group_id': group_id
|
||||
}
|
||||
|
||||
# 融合节点
|
||||
await self.forgetting_strategy.merge_nodes_to_summary(
|
||||
statement_node=statement_node,
|
||||
entity_node=entity_node,
|
||||
config_id=config_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
merged_count += 1
|
||||
|
||||
# 进度跟踪:每 10% 记录一次
|
||||
if actual_batch_size > 0 and (idx % progress_interval == 0 or idx == actual_batch_size):
|
||||
progress_pct = (idx / actual_batch_size) * 100
|
||||
logger.info(
|
||||
f"遗忘进度: {idx}/{actual_batch_size} "
|
||||
f"({progress_pct:.1f}%), "
|
||||
f"已融合: {merged_count}, 失败: {failed_count}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
# 检查是否是节点不存在的错误
|
||||
if "nodes may not exist" in str(e):
|
||||
logger.warning(
|
||||
f"节点对 ({idx}/{actual_batch_size}) 的节点不存在(可能已被其他操作删除): "
|
||||
f"Statement[{statement_id}] + Entity[{entity_id}]"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"融合节点对失败 ({idx}/{actual_batch_size}): "
|
||||
f"Statement[{statement_id}] + Entity[{entity_id}], "
|
||||
f"错误: {str(e)}"
|
||||
)
|
||||
# 继续处理剩余节点
|
||||
continue
|
||||
|
||||
# 步骤6:统计遗忘后的节点数量
|
||||
nodes_after = await self._count_knowledge_nodes(group_id)
|
||||
logger.info(f"遗忘后节点总数: {nodes_after}")
|
||||
|
||||
# 步骤7:生成遗忘报告
|
||||
end_time = datetime.now()
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
|
||||
# 计算节点减少率
|
||||
if nodes_before > 0:
|
||||
reduction_rate = (nodes_before - nodes_after) / nodes_before
|
||||
else:
|
||||
reduction_rate = 0.0
|
||||
|
||||
# 计算成功率
|
||||
if actual_batch_size > 0:
|
||||
success_rate = merged_count / actual_batch_size
|
||||
else:
|
||||
success_rate = 1.0
|
||||
|
||||
report = {
|
||||
'merged_count': merged_count,
|
||||
'nodes_before': nodes_before,
|
||||
'nodes_after': nodes_after,
|
||||
'reduction_rate': reduction_rate,
|
||||
'duration_seconds': duration,
|
||||
'start_time': start_time_iso,
|
||||
'end_time': end_time.isoformat(),
|
||||
'failed_count': failed_count,
|
||||
'success_rate': success_rate
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"遗忘周期完成: "
|
||||
f"融合 {merged_count} 对节点, "
|
||||
f"失败 {failed_count} 对, "
|
||||
f"节点减少 {nodes_before - nodes_after} 个 "
|
||||
f"({reduction_rate:.2%}), "
|
||||
f"耗时 {duration:.2f} 秒"
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"遗忘周期执行失败: {str(e)}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
self.is_running = False
|
||||
|
||||
# ==================== 私有辅助方法 ====================
|
||||
|
||||
async def _count_knowledge_nodes(
|
||||
self,
|
||||
group_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
统计知识层节点总数
|
||||
|
||||
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
|
||||
Returns:
|
||||
int: 知识层节点总数
|
||||
"""
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
|
||||
query += """
|
||||
RETURN count(n) as total
|
||||
"""
|
||||
|
||||
params = {}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
if results:
|
||||
return results[0]['total']
|
||||
return 0
|
||||
@@ -1,643 +0,0 @@
|
||||
"""
|
||||
遗忘策略执行器模块
|
||||
|
||||
本模块实现基于 ACT-R 激活值的遗忘策略,负责:
|
||||
1. 识别低激活值的节点对(Statement-Entity)
|
||||
2. 将低激活值节点融合为 MemorySummary 节点
|
||||
3. 使用 LLM 生成高质量摘要(可选)
|
||||
4. 保留溯源信息并删除原始节点
|
||||
|
||||
Classes:
|
||||
ForgettingStrategy: 遗忘策略执行器,提供节点识别和融合功能
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ForgettingStrategy:
|
||||
"""
|
||||
遗忘策略执行器
|
||||
|
||||
基于 ACT-R 激活值识别和融合低价值记忆节点。
|
||||
实现了完整的遗忘周期:识别 → 融合 → 删除。
|
||||
|
||||
核心功能:
|
||||
1. 识别可遗忘节点:激活值低于阈值且长期未访问的 Statement-Entity 对
|
||||
2. 节点融合:创建 MemorySummary 节点,继承较高的激活值和重要性
|
||||
3. LLM 摘要生成:使用 LLM 生成语义摘要(可降级到简单拼接)
|
||||
4. 溯源保留:记录原始节点 ID,保持可追溯性
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j 连接器实例
|
||||
actr_calculator: ACT-R 激活值计算器实例
|
||||
forgetting_threshold: 遗忘阈值(激活值低于此值的节点可被遗忘)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
actr_calculator: ACTRCalculator,
|
||||
forgetting_threshold: float = 0.3,
|
||||
enable_llm_summary: bool = True
|
||||
):
|
||||
"""
|
||||
初始化遗忘策略执行器
|
||||
|
||||
Args:
|
||||
connector: Neo4j 连接器实例
|
||||
actr_calculator: ACT-R 激活值计算器实例
|
||||
forgetting_threshold: 遗忘阈值(默认 0.3)
|
||||
enable_llm_summary: 是否启用 LLM 摘要生成(默认 True)
|
||||
"""
|
||||
self.connector = connector
|
||||
self.actr_calculator = actr_calculator
|
||||
self.forgetting_threshold = forgetting_threshold
|
||||
self.enable_llm_summary = enable_llm_summary
|
||||
|
||||
logger.info(
|
||||
f"初始化遗忘策略执行器: threshold={forgetting_threshold}, "
|
||||
f"enable_llm_summary={enable_llm_summary}"
|
||||
)
|
||||
|
||||
async def calculate_forgetting_score(
|
||||
self,
|
||||
activation_value: float
|
||||
) -> float:
|
||||
"""
|
||||
计算遗忘分数
|
||||
|
||||
遗忘分数 = 1 - 激活值
|
||||
分数越高,越容易被遗忘。
|
||||
|
||||
注意:激活值已经包含了 importance_score 的权重,
|
||||
因此不需要单独考虑重要性分数。
|
||||
|
||||
Args:
|
||||
activation_value: 节点的激活值(0-1)
|
||||
|
||||
Returns:
|
||||
float: 遗忘分数(0-1),值越高越容易被遗忘
|
||||
"""
|
||||
return 1.0 - activation_value
|
||||
|
||||
async def find_forgettable_nodes(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
min_days_since_access: int = 30
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
识别可遗忘的节点对
|
||||
|
||||
查找满足以下条件的 Statement-Entity 节点对:
|
||||
1. 两个节点的激活值都低于遗忘阈值
|
||||
2. 两个节点都至少 min_days_since_access 天未被访问
|
||||
3. Statement 和 Entity 之间存在关系边
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 可遗忘节点对列表,每个元素包含:
|
||||
- statement_id: Statement 节点 ID
|
||||
- statement_text: Statement 文本内容
|
||||
- statement_activation: Statement 激活值
|
||||
- statement_importance: Statement 重要性分数
|
||||
- statement_last_access: Statement 最后访问时间
|
||||
- entity_id: Entity 节点 ID
|
||||
- entity_name: Entity 名称
|
||||
- entity_type: Entity 类型
|
||||
- entity_activation: Entity 激活值
|
||||
- entity_importance: Entity 重要性分数
|
||||
- entity_last_access: Entity 最后访问时间
|
||||
- avg_activation: 平均激活值(用于排序)
|
||||
"""
|
||||
# 计算时间阈值
|
||||
cutoff_time = datetime.now() - timedelta(days=min_days_since_access)
|
||||
cutoff_time_iso = cutoff_time.isoformat()
|
||||
|
||||
# 构建查询
|
||||
query = """
|
||||
MATCH (s:Statement)-[r]-(e:ExtractedEntity)
|
||||
WHERE s.activation_value IS NOT NULL
|
||||
AND e.activation_value IS NOT NULL
|
||||
AND s.activation_value < $threshold
|
||||
AND e.activation_value < $threshold
|
||||
AND s.last_access_time < $cutoff_time
|
||||
AND e.last_access_time < $cutoff_time
|
||||
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND s.group_id = $group_id AND e.group_id = $group_id"
|
||||
|
||||
query += """
|
||||
RETURN s.id as statement_id,
|
||||
s.statement as statement_text,
|
||||
s.activation_value as statement_activation,
|
||||
s.importance_score as statement_importance,
|
||||
s.last_access_time as statement_last_access,
|
||||
e.id as entity_id,
|
||||
e.name as entity_name,
|
||||
e.entity_type as entity_type,
|
||||
e.activation_value as entity_activation,
|
||||
e.importance_score as entity_importance,
|
||||
e.last_access_time as entity_last_access,
|
||||
(s.activation_value + e.activation_value) / 2.0 as avg_activation
|
||||
ORDER BY avg_activation ASC
|
||||
"""
|
||||
|
||||
params = {
|
||||
'threshold': self.forgetting_threshold,
|
||||
'cutoff_time': cutoff_time_iso
|
||||
}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
logger.info(
|
||||
f"识别到 {len(results)} 个可遗忘节点对 "
|
||||
f"(threshold={self.forgetting_threshold}, "
|
||||
f"min_days={min_days_since_access})"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def merge_nodes_to_summary(
|
||||
self,
|
||||
statement_node: Dict[str, Any],
|
||||
entity_node: Dict[str, Any],
|
||||
config_id: Optional[int] = None,
|
||||
db = None
|
||||
) -> str:
|
||||
"""
|
||||
将 Statement 和 Entity 节点融合为 MemorySummary 节点
|
||||
|
||||
融合过程:
|
||||
1. 生成摘要内容(使用 LLM 或简单拼接)
|
||||
2. 创建 MemorySummary 节点,继承较高的激活值和重要性分数
|
||||
3. 删除原始 Statement 和 Entity 节点
|
||||
4. 保留溯源信息(original_statement_id, original_entity_id)
|
||||
|
||||
Args:
|
||||
statement_node: Statement 节点数据,必须包含:
|
||||
- statement_id: 节点 ID
|
||||
- statement_text: 文本内容
|
||||
- statement_activation: 激活值
|
||||
- statement_importance: 重要性分数
|
||||
entity_node: Entity 节点数据,必须包含:
|
||||
- entity_id: 节点 ID
|
||||
- entity_name: 实体名称
|
||||
- entity_type: 实体类型
|
||||
- entity_activation: 激活值
|
||||
- entity_importance: 重要性分数
|
||||
config_id: 配置ID(可选,用于获取 llm_id)
|
||||
db: 数据库会话(可选,用于获取 llm_id)
|
||||
|
||||
Returns:
|
||||
str: 创建的 MemorySummary 节点 ID
|
||||
|
||||
Raises:
|
||||
ValueError: 如果节点数据不完整
|
||||
RuntimeError: 如果融合操作失败
|
||||
"""
|
||||
# 验证输入数据
|
||||
required_statement_keys = [
|
||||
'statement_id', 'statement_text',
|
||||
'statement_activation', 'statement_importance'
|
||||
]
|
||||
required_entity_keys = [
|
||||
'entity_id', 'entity_name', 'entity_type',
|
||||
'entity_activation', 'entity_importance'
|
||||
]
|
||||
|
||||
for key in required_statement_keys:
|
||||
if key not in statement_node:
|
||||
raise ValueError(f"Statement 节点缺少必需字段: {key}")
|
||||
|
||||
for key in required_entity_keys:
|
||||
if key not in entity_node:
|
||||
raise ValueError(f"Entity 节点缺少必需字段: {key}")
|
||||
|
||||
# 验证实体类型:不允许融合 Person 类型的实体
|
||||
if entity_node.get('entity_type') == 'Person':
|
||||
raise ValueError(
|
||||
f"不允许融合 Person 类型的实体: entity_id={entity_node.get('entity_id')}, "
|
||||
f"entity_name={entity_node.get('entity_name')}"
|
||||
)
|
||||
|
||||
# 提取节点信息
|
||||
statement_id = statement_node['statement_id']
|
||||
statement_text = statement_node['statement_text']
|
||||
statement_activation = statement_node['statement_activation']
|
||||
statement_importance = statement_node['statement_importance']
|
||||
|
||||
entity_id = entity_node['entity_id']
|
||||
entity_name = entity_node['entity_name']
|
||||
entity_type = entity_node['entity_type']
|
||||
entity_activation = entity_node['entity_activation']
|
||||
entity_importance = entity_node['entity_importance']
|
||||
|
||||
# 获取 group_id(从 statement 或 entity 节点)
|
||||
group_id = statement_node.get('group_id') or entity_node.get('group_id')
|
||||
|
||||
# 生成摘要内容
|
||||
summary_text = await self._generate_summary(
|
||||
statement_text=statement_text,
|
||||
entity_name=entity_name,
|
||||
entity_type=entity_type,
|
||||
config_id=config_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
# 生成标题和类型(使用LLM)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import generate_title_and_type_for_summary
|
||||
|
||||
# 获取 LLM 客户端
|
||||
llm_client = None
|
||||
if config_id is not None and db is not None:
|
||||
try:
|
||||
llm_client = await self._get_llm_client(db, config_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取 LLM 客户端失败: {str(e)}")
|
||||
|
||||
# 生成标题和类型
|
||||
try:
|
||||
if llm_client is not None:
|
||||
title, episodic_type = await generate_title_and_type_for_summary(
|
||||
content=summary_text,
|
||||
llm_client=llm_client
|
||||
)
|
||||
logger.info(f"成功为MemorySummary生成标题和类型: title={title}, type={episodic_type}")
|
||||
else:
|
||||
logger.warning("LLM 客户端不可用,使用默认标题和类型")
|
||||
title = "未命名"
|
||||
episodic_type = "conversation"
|
||||
except Exception as e:
|
||||
logger.error(f"生成标题和类型失败,使用默认值: {str(e)}")
|
||||
title = "未命名"
|
||||
episodic_type = "conversation"
|
||||
|
||||
# 计算继承的激活值和重要性(取较高值)
|
||||
inherited_activation = max(statement_activation, entity_activation)
|
||||
inherited_importance = max(statement_importance, entity_importance)
|
||||
|
||||
# 创建 MemorySummary 节点
|
||||
current_time = datetime.now()
|
||||
current_time_iso = current_time.isoformat()
|
||||
|
||||
# 生成新的 MemorySummary ID
|
||||
import uuid
|
||||
summary_id = f"summary_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# 使用事务创建 MemorySummary 并删除原节点
|
||||
async def merge_transaction(tx, **params):
|
||||
"""事务函数:创建摘要节点并删除原节点"""
|
||||
query = """
|
||||
// 首先检查节点是否存在
|
||||
OPTIONAL MATCH (s:Statement {id: $statement_id})
|
||||
OPTIONAL MATCH (e:ExtractedEntity {id: $entity_id})
|
||||
|
||||
// 如果任一节点不存在,直接返回 null(不执行后续操作)
|
||||
WITH s, e
|
||||
WHERE s IS NOT NULL AND e IS NOT NULL
|
||||
|
||||
// 创建 MemorySummary 节点
|
||||
CREATE (ms:MemorySummary {
|
||||
id: $summary_id,
|
||||
summary: $summary_text,
|
||||
name: $title,
|
||||
memory_type: $episodic_type,
|
||||
original_statement_id: $statement_id,
|
||||
original_entity_id: $entity_id,
|
||||
activation_value: $inherited_activation,
|
||||
importance_score: $inherited_importance,
|
||||
access_history: [$current_time],
|
||||
last_access_time: $current_time,
|
||||
access_count: 1,
|
||||
version: 1,
|
||||
group_id: $group_id,
|
||||
created_at: datetime($current_time),
|
||||
merged_at: datetime($current_time)
|
||||
})
|
||||
|
||||
// 转移 Statement 的出边到 MemorySummary(只转移目标节点仍存在的边)
|
||||
WITH ms, s, e
|
||||
CALL (ms, s, e) {
|
||||
OPTIONAL MATCH (s)-[r_out]->(target)
|
||||
WHERE target <> e AND r_out IS NOT NULL AND target IS NOT NULL
|
||||
FOREACH (_ IN CASE WHEN target IS NOT NULL THEN [1] ELSE [] END |
|
||||
MERGE (ms)-[new_rel:DERIVED_FROM]->(target)
|
||||
ON CREATE SET
|
||||
new_rel = properties(r_out),
|
||||
new_rel.original_relationship_type = type(r_out),
|
||||
new_rel.merged_from_statement = true,
|
||||
new_rel.merge_count = 1
|
||||
ON MATCH SET
|
||||
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
|
||||
)
|
||||
}
|
||||
|
||||
// 转移 Statement 的入边到 MemorySummary(只转移源节点仍存在的边)
|
||||
WITH ms, s, e
|
||||
CALL (ms, s, e) {
|
||||
OPTIONAL MATCH (source)-[r_in]->(s)
|
||||
WHERE r_in IS NOT NULL AND source IS NOT NULL
|
||||
FOREACH (_ IN CASE WHEN source IS NOT NULL THEN [1] ELSE [] END |
|
||||
MERGE (source)-[new_rel:DERIVED_FROM]->(ms)
|
||||
ON CREATE SET
|
||||
new_rel = properties(r_in),
|
||||
new_rel.original_relationship_type = type(r_in),
|
||||
new_rel.merged_from_statement = true,
|
||||
new_rel.merge_count = 1
|
||||
ON MATCH SET
|
||||
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
|
||||
)
|
||||
}
|
||||
|
||||
// 转移 Entity 的出边到 MemorySummary(只转移目标节点仍存在的边)
|
||||
WITH ms, s, e
|
||||
CALL (ms, s, e) {
|
||||
OPTIONAL MATCH (e)-[r_out]->(target)
|
||||
WHERE target <> s AND r_out IS NOT NULL AND target IS NOT NULL
|
||||
FOREACH (_ IN CASE WHEN target IS NOT NULL THEN [1] ELSE [] END |
|
||||
MERGE (ms)-[new_rel:DERIVED_FROM]->(target)
|
||||
ON CREATE SET
|
||||
new_rel = properties(r_out),
|
||||
new_rel.original_relationship_type = type(r_out),
|
||||
new_rel.merged_from_entity = true,
|
||||
new_rel.merge_count = 1
|
||||
ON MATCH SET
|
||||
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
|
||||
)
|
||||
}
|
||||
|
||||
// 转移 Entity 的入边到 MemorySummary(只转移源节点仍存在的边)
|
||||
WITH ms, s, e
|
||||
CALL (ms, s, e) {
|
||||
OPTIONAL MATCH (source)-[r_in]->(e)
|
||||
WHERE source <> s AND r_in IS NOT NULL AND source IS NOT NULL
|
||||
FOREACH (_ IN CASE WHEN source IS NOT NULL THEN [1] ELSE [] END |
|
||||
MERGE (source)-[new_rel:DERIVED_FROM]->(ms)
|
||||
ON CREATE SET
|
||||
new_rel = properties(r_in),
|
||||
new_rel.original_relationship_type = type(r_in),
|
||||
new_rel.merged_from_entity = true,
|
||||
new_rel.merge_count = 1
|
||||
ON MATCH SET
|
||||
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
|
||||
)
|
||||
}
|
||||
|
||||
// 删除原始节点
|
||||
WITH ms, s, e
|
||||
DETACH DELETE s, e
|
||||
|
||||
RETURN ms.id as summary_id
|
||||
"""
|
||||
|
||||
result = await tx.run(query, **params)
|
||||
record = await result.single()
|
||||
|
||||
if not record:
|
||||
raise RuntimeError("Failed to create MemorySummary node - nodes may not exist")
|
||||
|
||||
return record['summary_id']
|
||||
|
||||
params = {
|
||||
'summary_id': summary_id,
|
||||
'summary_text': summary_text,
|
||||
'title': title,
|
||||
'episodic_type': episodic_type,
|
||||
'statement_id': statement_id,
|
||||
'entity_id': entity_id,
|
||||
'inherited_activation': inherited_activation,
|
||||
'inherited_importance': inherited_importance,
|
||||
'current_time': current_time_iso,
|
||||
'group_id': group_id
|
||||
}
|
||||
|
||||
try:
|
||||
created_summary_id = await self.connector.execute_write_transaction(
|
||||
merge_transaction,
|
||||
**params
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"成功融合节点: Statement[{statement_id}] + Entity[{entity_id}] "
|
||||
f"-> MemorySummary[{created_summary_id}], "
|
||||
f"activation={inherited_activation:.4f}, "
|
||||
f"importance={inherited_importance:.4f}"
|
||||
)
|
||||
|
||||
return created_summary_id
|
||||
|
||||
except Exception as e:
|
||||
# 记录详细的错误信息,包括异常类型和堆栈
|
||||
import traceback
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(
|
||||
f"融合节点失败: Statement[{statement_id}] + Entity[{entity_id}], "
|
||||
f"错误类型: {type(e).__name__}, "
|
||||
f"错误信息: {str(e)}, "
|
||||
f"详细堆栈:\n{error_details}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"融合节点失败: {str(e)}"
|
||||
) from e
|
||||
|
||||
# ==================== 私有辅助方法 ====================
|
||||
|
||||
async def _generate_summary(
|
||||
self,
|
||||
statement_text: str,
|
||||
entity_name: str,
|
||||
entity_type: str,
|
||||
config_id: Optional[int] = None,
|
||||
db = None
|
||||
) -> str:
|
||||
"""
|
||||
生成摘要内容
|
||||
|
||||
优先使用 LLM 生成高质量摘要,如果 LLM 不可用或失败,
|
||||
则降级到简单文本拼接。
|
||||
|
||||
Args:
|
||||
statement_text: Statement 文本内容
|
||||
entity_name: Entity 名称
|
||||
entity_type: Entity 类型
|
||||
config_id: 配置ID(可选,用于获取 llm_id)
|
||||
db: 数据库会话(可选,用于获取 llm_id)
|
||||
|
||||
Returns:
|
||||
str: 生成的摘要文本(最多 200 个字符)
|
||||
"""
|
||||
# 如果配置禁用 LLM 摘要,直接使用简单拼接
|
||||
if not self.enable_llm_summary:
|
||||
logger.info("LLM 摘要生成已禁用,使用简单拼接")
|
||||
return self._simple_concatenation(
|
||||
statement_text, entity_name, entity_type
|
||||
)
|
||||
|
||||
# 尝试获取 LLM 客户端
|
||||
llm_client = None
|
||||
if config_id is not None and db is not None:
|
||||
try:
|
||||
llm_client = await self._get_llm_client(db, config_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取 LLM 客户端失败: {str(e)}")
|
||||
|
||||
# 如果没有 LLM 客户端,直接使用简单拼接
|
||||
if llm_client is None:
|
||||
logger.info("未能获取 LLM 客户端,使用简单拼接")
|
||||
return self._simple_concatenation(
|
||||
statement_text, entity_name, entity_type
|
||||
)
|
||||
|
||||
# 尝试使用 LLM 生成摘要
|
||||
try:
|
||||
summary = await self._generate_llm_summary(
|
||||
statement_text=statement_text,
|
||||
entity_name=entity_name,
|
||||
entity_type=entity_type,
|
||||
llm_client=llm_client
|
||||
)
|
||||
|
||||
# 限制长度为 200 个字符
|
||||
if len(summary) > 200:
|
||||
summary = f"{summary[:197]}..."
|
||||
|
||||
logger.info(f"使用 LLM 生成摘要: {summary}")
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"LLM 摘要生成失败,降级到简单拼接: {str(e)}"
|
||||
)
|
||||
return self._simple_concatenation(
|
||||
statement_text, entity_name, entity_type
|
||||
)
|
||||
|
||||
async def _get_llm_client(self, db, config_id: int):
|
||||
"""
|
||||
从数据库获取 LLM 客户端
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
Returns:
|
||||
LLM 客户端实例,如果无法获取则返回 None
|
||||
"""
|
||||
try:
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# 从数据库读取配置
|
||||
repository = DataConfigRepository()
|
||||
db_config = repository.get_by_id(db, config_id)
|
||||
|
||||
if db_config is None or db_config.llm_id is None:
|
||||
logger.warning(f"配置 {config_id} 不存在或未设置 llm_id")
|
||||
return None
|
||||
|
||||
# 创建 LLM 客户端
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(str(db_config.llm_id))
|
||||
|
||||
logger.info(f"成功获取 LLM 客户端: config_id={config_id}, llm_id={db_config.llm_id}")
|
||||
return llm_client
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取 LLM 客户端失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _generate_llm_summary(
|
||||
self,
|
||||
statement_text: str,
|
||||
entity_name: str,
|
||||
entity_type: str,
|
||||
llm_client
|
||||
) -> str:
|
||||
"""
|
||||
使用 LLM 生成高质量摘要
|
||||
|
||||
Args:
|
||||
statement_text: Statement 文本内容
|
||||
entity_name: Entity 名称
|
||||
entity_type: Entity 类型
|
||||
llm_client: LLM 客户端实例
|
||||
|
||||
Returns:
|
||||
str: LLM 生成的摘要文本
|
||||
|
||||
Raises:
|
||||
Exception: 如果 LLM 调用失败
|
||||
"""
|
||||
# 构建提示词
|
||||
prompt = f"""请为以下记忆片段生成一个简洁的摘要(不超过200个字符):
|
||||
|
||||
实体名称: {entity_name}
|
||||
实体类型: {entity_type}
|
||||
陈述内容: {statement_text}
|
||||
|
||||
要求:
|
||||
1. 摘要应该保留核心语义信息
|
||||
2. 长度不超过200个字符
|
||||
3. 使用简洁、自然的中文表达
|
||||
4. 只返回摘要文本,不要包含其他内容
|
||||
|
||||
摘要:"""
|
||||
|
||||
# 调用 LLM(直接传递 prompt 字符串)
|
||||
response = await llm_client.chat(prompt)
|
||||
|
||||
# 提取摘要文本
|
||||
if isinstance(response, str):
|
||||
summary = response.strip()
|
||||
elif hasattr(response, 'content'):
|
||||
summary = response.content.strip()
|
||||
else:
|
||||
summary = str(response).strip()
|
||||
|
||||
return summary
|
||||
|
||||
def _simple_concatenation(
|
||||
self,
|
||||
statement_text: str,
|
||||
entity_name: str,
|
||||
entity_type: str
|
||||
) -> str:
|
||||
"""
|
||||
简单文本拼接生成摘要
|
||||
|
||||
降级策略:当 LLM 不可用时使用。
|
||||
格式:[实体类型]实体名称: 陈述内容
|
||||
|
||||
Args:
|
||||
statement_text: Statement 文本内容
|
||||
entity_name: Entity 名称
|
||||
entity_type: Entity 类型
|
||||
|
||||
Returns:
|
||||
str: 拼接的摘要文本(最多 200 个字符)
|
||||
"""
|
||||
# 构建简单摘要
|
||||
summary = f"[{entity_type}]{entity_name}: {statement_text}"
|
||||
|
||||
# 限制长度为 200 个字符(注意:这里的长度是字符数,不是字节数)
|
||||
if len(summary) > 200:
|
||||
# 截断并添加省略号
|
||||
summary = f"{summary[:197]}..."
|
||||
|
||||
return summary
|
||||
|
||||
@@ -2,39 +2,52 @@
|
||||
"memory_verify": {
|
||||
"source_data": [
|
||||
{
|
||||
"statement_name": "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合。"
|
||||
"statement_name": "用户是2023年春天去北京工作的。",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户后来基本一直都在北京上班。"
|
||||
"statement_name": "用户后来基本一直都在北京上班。",
|
||||
"statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户从2023年开始就一直在北京生活。"
|
||||
"statement_name": "用户从2023年开始就一直在北京生活。",
|
||||
"statement_id": "e612a44da4db483993c350df7c97a1a1"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户从来没有长期离开过北京。"
|
||||
"statement_name": "用户从来没有长期离开过北京。",
|
||||
"statement_id": "b3c787a2e33c49f7981accabbbb4538a"
|
||||
},
|
||||
{
|
||||
"statement_name": "由于公司调整,用户在2024年上半年被调到上海待了差不多半年。"
|
||||
"statement_name": "由于公司调整,用户在2024年上半年被调到上海待了差不多半年。",
|
||||
"statement_id": "64cde4230cb24a4da726e7db9e7aa616"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。"
|
||||
"statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。",
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户在入职时使用的身份信息是之前的,身份证号为11010119950308123X。"
|
||||
"statement_name": "用户在入职时使用的身份信息是之前的,身份证号为11010119950308123X。",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户的银行卡号是6222023847595898。"
|
||||
"statement_name": "用户的银行卡号是6222023847595898。",
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户的身份信息和银行卡信息一直没变。"
|
||||
"statement_name": "用户的身份信息和银行卡信息一直没变。",
|
||||
"statement_id": "b3ca618e1e204b83bebd70e75cf2073f"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户认为在上海的那段时间更多算是远程配合。"
|
||||
"statement_name": "用户认为在上海的那段时间更多算是远程配合。",
|
||||
"statement_id": "150af89d2c154e6eb41ff1a91e37f962"
|
||||
}
|
||||
],
|
||||
"databasets": [
|
||||
{
|
||||
"entity1_name": "Person",
|
||||
"description": "表示人类个体的通用类型",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"entity2_name": "用户",
|
||||
"entity2": {
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"name": "用户"
|
||||
@@ -42,6 +55,9 @@
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"entity2_name": "身份信息",
|
||||
"entity2": {
|
||||
"description": "用于个人身份识别的数据",
|
||||
"name": "身份信息"
|
||||
@@ -49,6 +65,9 @@
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"entity2_name": "6222023847595898",
|
||||
"entity2": {
|
||||
"description": "用户的银行卡号码",
|
||||
"name": "6222023847595898"
|
||||
@@ -57,24 +76,33 @@
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"entity2_name": "上海办公室",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"aliases": ["上海办"],
|
||||
"description": "位于上海的工作办公场所",
|
||||
"name": "上海办公室"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"entity2_name": "北京",
|
||||
"entity2": {
|
||||
"aliases": ["京", "京城", "北平"],
|
||||
"description": "中国的首都城市,用户主要工作和生活所在地",
|
||||
"name": "北京"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "11010119950308123X",
|
||||
"description": "具体的身份证号码值",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"entity2_name": "身份证号",
|
||||
"entity2": {
|
||||
"description": "中华人民共和国公民的身份号码",
|
||||
"name": "身份证号"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -387,7 +387,7 @@ class ReflectionEngine:
|
||||
result_data['memory_verifies'] = memory_verifies
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
conflicts_found=''
|
||||
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
||||
|
||||
# Clearn conflict_data,And memory_verify和quality_assessment
|
||||
cleaned_conflict_data = []
|
||||
for item in conflict_data:
|
||||
@@ -396,23 +396,7 @@ class ReflectionEngine:
|
||||
'conflict': item['conflict']
|
||||
}
|
||||
cleaned_conflict_data.append(cleaned_item)
|
||||
cleaned_conflict_data_=[]
|
||||
for item in conflict_data:
|
||||
cleaned_data = []
|
||||
for row in item.get("data", []):
|
||||
# 删除 created_at / expired_at
|
||||
cleaned_row = {
|
||||
k: v
|
||||
for k, v in row.items()
|
||||
if k not in REMOVE_KEYS
|
||||
}
|
||||
cleaned_data.append(cleaned_row)
|
||||
cleaned_item = {
|
||||
"data": cleaned_data,
|
||||
"conflict": item.get("conflict"),
|
||||
}
|
||||
cleaned_conflict_data_.append(cleaned_item)
|
||||
print(cleaned_conflict_data_)
|
||||
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data)
|
||||
if not solved_data:
|
||||
|
||||
@@ -316,96 +316,3 @@ async def render_emotion_suggestions_prompt(
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
async def render_user_summary_prompt(
|
||||
user_id: str,
|
||||
entities: str,
|
||||
statements: str
|
||||
) -> str:
|
||||
"""
|
||||
Renders the user summary prompt using the user_summary.jinja2 template.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
entities: Core entities with frequency information
|
||||
statements: Representative statement samples
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("user_summary.jinja2")
|
||||
rendered_prompt = template.render(
|
||||
user_id=user_id,
|
||||
entities=entities,
|
||||
statements=statements
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
log_prompt_rendering('user summary', rendered_prompt)
|
||||
# 可选:记录模板渲染信息
|
||||
log_template_rendering('user_summary.jinja2', {
|
||||
'user_id': user_id,
|
||||
'entities_len': len(entities),
|
||||
'statements_len': len(statements)
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
async def render_memory_insight_prompt(
|
||||
domain_distribution: str = None,
|
||||
active_periods: str = None,
|
||||
social_connections: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Renders the memory insight prompt using the memory_insight.jinja2 template.
|
||||
|
||||
Args:
|
||||
domain_distribution: 核心领域分布信息
|
||||
active_periods: 活跃时段信息
|
||||
social_connections: 社交关联信息
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("memory_insight.jinja2")
|
||||
rendered_prompt = template.render(
|
||||
domain_distribution=domain_distribution,
|
||||
active_periods=active_periods,
|
||||
social_connections=social_connections
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
log_prompt_rendering('memory insight', rendered_prompt)
|
||||
# 可选:记录模板渲染信息
|
||||
log_template_rendering('memory_insight.jinja2', {
|
||||
'has_domain_distribution': bool(domain_distribution),
|
||||
'has_active_periods': bool(active_periods),
|
||||
'has_social_connections': bool(social_connections)
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
async def render_episodic_title_and_type_prompt(content: str) -> str:
|
||||
"""
|
||||
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
|
||||
|
||||
Args:
|
||||
content: The content of the episodic memory summary to analyze
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("episodic_type_classification.jinja2")
|
||||
rendered_prompt = template.render(content=content)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
log_prompt_rendering('episodic title and type classification', rendered_prompt)
|
||||
# 可选:记录模板渲染信息
|
||||
log_template_rendering('episodic_type_classification.jinja2', {
|
||||
'content_len': len(content) if content else 0
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
=== Task ===
|
||||
Generate a concise title and classify the episodic memory into the most appropriate category.
|
||||
|
||||
=== Requirements ===
|
||||
- Extract a clear, concise title (10-20 characters) that captures the core content
|
||||
- Classify into exactly one category based on the primary theme
|
||||
- Be specific and avoid ambiguity
|
||||
- Output must be valid JSON conforming to the schema below
|
||||
|
||||
=== Input ===
|
||||
{{ content }}
|
||||
|
||||
=== Category Definitions ===
|
||||
|
||||
1. **conversation**: Daily communication, chat, discussion, and social interactions
|
||||
- Keywords: chat, communication, discussion, dialogue, exchange
|
||||
|
||||
2. **project_work**: Work-related tasks, projects, meetings, and collaboration
|
||||
- Keywords: project, task, work, meeting, collaboration, business, client
|
||||
|
||||
3. **learning**: Acquiring new knowledge, skill development, reading, and research
|
||||
- Keywords: learning, reading, research, knowledge, skill, course, training
|
||||
|
||||
4. **decision**: Making important decisions, choices, and planning
|
||||
- Keywords: decision, choice, planning, consideration, evaluation, weighing
|
||||
|
||||
5. **important_event**: Major events, milestones, and special experiences
|
||||
- Keywords: important, major, milestone, special, memorable, celebration
|
||||
|
||||
=== Analysis Steps ===
|
||||
1. Read the episodic memory content carefully
|
||||
2. Identify the core theme and context
|
||||
3. Extract a concise title
|
||||
4. Compare against category definitions and keywords
|
||||
5. Select the best matching category
|
||||
6. If multiple categories apply, choose the primary one
|
||||
|
||||
=== Output Schema ===
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure
|
||||
2. Escape any quotation marks within string values using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
|
||||
Return only a JSON object with title and type fields:
|
||||
{
|
||||
"title": "Generated title here",
|
||||
"type": "Category type here"
|
||||
}
|
||||
|
||||
The type field must be exactly one of:
|
||||
- conversation
|
||||
- project_work
|
||||
- learning
|
||||
- decision
|
||||
- important_event
|
||||
|
||||
@@ -86,5 +86,5 @@
|
||||
- **quality_assessment**:
|
||||
quality_assessment=true时输出评估对象,否则为null(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
|
||||
- **memory_verify**: memory_verify=true时输出隐私检测对象,否则为null
|
||||
(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true\memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容)
|
||||
(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
|
||||
模式参考:{{ json_schema }}
|
||||
@@ -12,34 +12,7 @@ Extract entities and knowledge triplets from the given statement.
|
||||
===Guidelines===
|
||||
|
||||
**Entity Extraction:**
|
||||
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
|
||||
- **Semantic Memory Classification (is_explicit_memory):**
|
||||
* Set to `true` if the entity represents **explicit/semantic memory**:
|
||||
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
|
||||
- **Knowledge:** "Python Programming Language", "Theory of Relativity", "Python编程语言", "相对论"
|
||||
- **Definitions:** "API (Application Programming Interface)", "REST API", "应用程序接口"
|
||||
- **Principles:** "SOLID Principles", "First Law of Thermodynamics", "SOLID原则", "热力学第一定律"
|
||||
- **Theories:** "Evolution Theory", "Quantum Mechanics", "进化论", "量子力学"
|
||||
- **Methods/Techniques:** "Agile Development", "Machine Learning Algorithm", "敏捷开发", "机器学习算法"
|
||||
- **Technical Terms:** "Neural Network", "Database", "神经网络", "数据库"
|
||||
* Set to `false` for:
|
||||
- **People:** "John Smith", "Dr. Wang", "张明", "王博士"
|
||||
- **Organizations:** "Microsoft", "Harvard University", "微软", "哈佛大学"
|
||||
- **Locations:** "Beijing", "Central Park", "北京", "中央公园"
|
||||
- **Events:** "2024 Conference", "Project Meeting", "2024会议", "项目会议"
|
||||
- **Specific objects:** "iPhone 15", "Building A", "iPhone 15", "A栋"
|
||||
- **Example Generation (IMPORTANT for semantic memory entities):**
|
||||
* For entities where `is_explicit_memory=true`, generate a **concise example (around 20 characters)** to help understand the concept
|
||||
* The example should be:
|
||||
- **Specific and concrete**: Use real-world scenarios or applications
|
||||
- **Brief**: Around 20 characters (can be slightly longer if needed for clarity)
|
||||
- **In the same language as the entity name**
|
||||
* Examples:
|
||||
- Entity: "机器学习" → example: "如:用神经网络识别图片中的猫狗"
|
||||
- Entity: "SOLID Principles" → example: "e.g., Single Responsibility, Open-Closed"
|
||||
- Entity: "Photosynthesis" → example: "e.g., plants convert sunlight to energy"
|
||||
- Entity: "人工智能" → example: "如:智能客服、自动驾驶"
|
||||
* For non-semantic entities (`is_explicit_memory=false`), the example field can be empty
|
||||
- Extract entities with their types, context-independent descriptions, and aliases
|
||||
- **Aliases Extraction (Important):**
|
||||
* **CRITICAL: Extract aliases ONLY in the SAME LANGUAGE as the input text**
|
||||
* **DO NOT translate or add aliases in different languages**
|
||||
@@ -111,27 +84,21 @@ Output:
|
||||
"name": "I",
|
||||
"type": "Person",
|
||||
"description": "The user",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
"aliases": []
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "Paris",
|
||||
"type": "Location",
|
||||
"description": "Capital city of France",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
"aliases": []
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "Louvre",
|
||||
"type": "Location",
|
||||
"description": "World-famous museum located in Paris",
|
||||
"example": "",
|
||||
"aliases": ["Louvre Museum"],
|
||||
"is_explicit_memory": false
|
||||
"aliases": ["Louvre Museum"]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -163,27 +130,21 @@ Output:
|
||||
"name": "John Smith",
|
||||
"type": "Person",
|
||||
"description": "Individual person name",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
"aliases": []
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "Google",
|
||||
"type": "Organization",
|
||||
"description": "American technology company",
|
||||
"example": "",
|
||||
"aliases": ["Google LLC", "Alphabet Inc."],
|
||||
"is_explicit_memory": false
|
||||
"aliases": ["Google LLC", "Alphabet Inc."]
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "AI product development",
|
||||
"type": "Concept",
|
||||
"type": "WorkRole",
|
||||
"description": "Artificial intelligence product development work",
|
||||
"example": "e.g., developing chatbots, recommendation systems",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": true
|
||||
"aliases": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -215,27 +176,21 @@ Output:
|
||||
"name": "我",
|
||||
"type": "Person",
|
||||
"description": "用户本人",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
"aliases": []
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "巴黎",
|
||||
"type": "Location",
|
||||
"description": "法国首都城市",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
"aliases": []
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "卢浮宫",
|
||||
"type": "Location",
|
||||
"description": "位于巴黎的世界著名博物馆",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
"aliases": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -267,27 +222,21 @@ Output:
|
||||
"name": "张明",
|
||||
"type": "Person",
|
||||
"description": "个人姓名",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
"aliases": []
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "腾讯",
|
||||
"type": "Organization",
|
||||
"description": "中国科技公司",
|
||||
"example": "",
|
||||
"aliases": ["腾讯控股", "腾讯公司"],
|
||||
"is_explicit_memory": false
|
||||
"aliases": ["腾讯控股", "腾讯公司"]
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "AI产品开发",
|
||||
"type": "Concept",
|
||||
"type": "WorkRole",
|
||||
"description": "人工智能产品研发工作",
|
||||
"example": "如:开发智能客服机器人、推荐系统",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": true
|
||||
"aliases": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -302,9 +251,7 @@ Output:
|
||||
"name": "Tripod",
|
||||
"type": "Equipment",
|
||||
"description": "Photography equipment accessory",
|
||||
"example": "",
|
||||
"aliases": ["Camera Tripod"],
|
||||
"is_explicit_memory": false
|
||||
"aliases": ["Camera Tripod"]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -319,9 +266,7 @@ Output:
|
||||
"name": "三脚架",
|
||||
"type": "Equipment",
|
||||
"description": "摄影器材配件",
|
||||
"example": "",
|
||||
"aliases": ["相机三脚架"],
|
||||
"is_explicit_memory": false
|
||||
"aliases": ["相机三脚架"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
{% macro tidy(name) -%}
|
||||
{{ name.replace('_', ' ')}}
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
===Task===
|
||||
|
||||
Your task is to generate a comprehensive memory insight report based on the provided data analysis. The report should include four distinct sections that capture different aspects of the user's memory patterns and characteristics.
|
||||
|
||||
|
||||
===Inputs===
|
||||
{% if domain_distribution %}
|
||||
- 核心领域分布: {{ domain_distribution }}
|
||||
{% endif %}
|
||||
{% if active_periods %}
|
||||
- 活跃时段: {{ active_periods }}
|
||||
{% endif %}
|
||||
{% if social_connections %}
|
||||
- 社交关联: {{ social_connections }}
|
||||
{% endif %}
|
||||
|
||||
|
||||
===Report Generation Requirements===
|
||||
|
||||
**General Guidelines:**
|
||||
1. Base your analysis ONLY on the provided data - do not speculate or fabricate information
|
||||
2. Use objective third-person descriptions with a professional and analytical tone
|
||||
3. Avoid excessive adjectives and empty phrases
|
||||
4. Strictly follow the output format specified below
|
||||
5. If a dimension lacks data, skip that section or provide a brief note
|
||||
|
||||
**Section-Specific Requirements:**
|
||||
|
||||
1. **总体概述 (Overview)** (100-150 Chinese characters)
|
||||
- Focus on: Overall analysis of user profile based on interaction logs
|
||||
- Describe the user's main role, work network, and collaboration spirit
|
||||
- Use professional, data-driven language style
|
||||
- Example reference: "通过对156次交互日志的深度分析,系统发现三层一位主要用户档案和数据分析的产品经理。他的工作网络体现出鲜明的目标导向和团队协作精神。"
|
||||
|
||||
2. **行为模式 (Behavior Pattern)** (80-120 Chinese characters)
|
||||
- Focus on: Work patterns, time regularity, and behavioral characteristics
|
||||
- Describe weekly work patterns and time preferences
|
||||
- Use objective, analytical language
|
||||
- Example reference: "张三的工作模式呈现出鲜明的周期性:周一通常用于规划和会议,周三周四专注于产品设计和用户研究,周五进行总结和复盘。他倾向于在上午进行头脑风暴,下午处理执行性工作。"
|
||||
|
||||
3. **关键发现 (Key Findings)** (3-4 bullet points, 30-50 characters each)
|
||||
- Focus on: Specific, insightful observations about user behavior and preferences
|
||||
- Use bullet points (•) format
|
||||
- Each finding should be concrete and data-supported
|
||||
- Example reference:
|
||||
"• 在产品决策中,张三总是优先考虑用户反应,这在68%的决策记录中得到体现
|
||||
• 他善于使用数据可视化工具来支持论点,这种习惯在项目管理中发挥了重要作用
|
||||
• 团队成员对他的评价中,"思路清晰"和"思路敏捷"两个关键词出现频率最高
|
||||
• 他对AI机器学习领域保持持续关注,近3个月参加了7次相关培训"
|
||||
|
||||
4. **成长轨迹 (Growth Trajectory)** (100-150 Chinese characters)
|
||||
- Focus on: User's growth journey, key milestones, and capability improvements
|
||||
- Organize content chronologically
|
||||
- Highlight role changes and achievements
|
||||
- Use positive, encouraging tone
|
||||
- Example reference: "从入职时的产品经理成长为高级产品经理,张三在产品单独、团队管理和技术理解三个方面都有显著提升。特别是在最近一年,他开始独立主导更复杂的项目,展现出更强的战略思维能力。他的成长轨迹显示出对新技术的持续学习和对产品思维的不断深化。"
|
||||
|
||||
|
||||
===Output Format (MUST STRICTLY FOLLOW)===
|
||||
|
||||
【总体概述】
|
||||
[100-150 characters describing overall user profile and work network based on interaction analysis]
|
||||
|
||||
【行为模式】
|
||||
[80-120 characters describing work patterns, time regularity, and behavioral characteristics]
|
||||
|
||||
【关键发现】
|
||||
• [First key finding with data support, 30-50 characters]
|
||||
• [Second key finding with data support, 30-50 characters]
|
||||
• [Third key finding with data support, 30-50 characters]
|
||||
• [Fourth key finding with data support, 30-50 characters]
|
||||
|
||||
【成长轨迹】
|
||||
[100-150 characters describing growth journey, milestones, and capability improvements]
|
||||
|
||||
|
||||
===Example===
|
||||
|
||||
Example Input:
|
||||
- 核心领域分布: 产品管理(38%), 数据分析(24%), 团队协作(21%)
|
||||
- 活跃时段: 用户在每年的 4 和 10 月最为活跃
|
||||
- 社交关联: 与用户"李明"拥有最多共同记忆(47条),时间范围主要在 2020-2023
|
||||
|
||||
Example Output:
|
||||
【总体概述】
|
||||
通过对156次交互日志的深度分析,系统发现张三是一位主要从事用户档案和数据分析的产品经理。他的工作网络体现出鲜明的目标导向和团队协作精神,在产品管理、数据分析和团队协作三个领域都有深入的实践。
|
||||
|
||||
【行为模式】
|
||||
张三的工作模式呈现出鲜明的周期性:周一通常用于规划和会议,周三周四专注于产品设计和用户研究,周五进行总结和复盘。他倾向于在上午进行头脑风暴,下午处理执行性工作。每年4月和10月是他最活跃的时期。
|
||||
|
||||
【关键发现】
|
||||
• 在产品决策中,张三总是优先考虑用户反应,这在68%的决策记录中得到体现
|
||||
• 他善于使用数据可视化工具来支持论点,这种习惯在项目管理中发挥了重要作用
|
||||
• 团队成员对他的评价中,"思路清晰"和"思路敏捷"两个关键词出现频率最高
|
||||
• 他对AI机器学习领域保持持续关注,近3个月参加了7次相关培训
|
||||
|
||||
【成长轨迹】
|
||||
从入职时的产品经理成长为高级产品经理,张三在产品规划、团队管理和技术理解三个方面都有显著提升。特别是在最近一年,他开始独立主导更复杂的项目,展现出更强的战略思维能力。他与李明的47条共同记忆见证了他的成长历程。
|
||||
|
||||
===End of Example===
|
||||
|
||||
|
||||
===Reflection Process===
|
||||
|
||||
After generating the report, perform the following self-review steps:
|
||||
|
||||
**Step 1: Data Grounding Check**
|
||||
- Verify all statements are supported by the provided data
|
||||
- Ensure no fabricated or speculated information is included
|
||||
- Confirm all claims can be traced back to the input data
|
||||
|
||||
**Step 2: Format Compliance**
|
||||
- Verify each section follows the specified format with section headers
|
||||
- Check character count limits for each section
|
||||
- Ensure proper use of section markers (【】)
|
||||
- Verify bullet points format for Key Findings section
|
||||
|
||||
**Step 3: Tone and Style Review**
|
||||
- Confirm objective third-person perspective is maintained
|
||||
- Check for excessive adjectives or empty phrases
|
||||
- Verify professional and analytical tone throughout
|
||||
|
||||
**Step 4: Completeness Check**
|
||||
- Ensure all four sections are present and complete
|
||||
- Verify each section addresses its specific focus area
|
||||
- Confirm the report provides actionable insights
|
||||
|
||||
|
||||
===Output Requirements===
|
||||
|
||||
**LANGUAGE REQUIREMENT:**
|
||||
- The output language should ALWAYS be Chinese (Simplified)
|
||||
- All section content must be in Chinese
|
||||
- Section headers must use the specified Chinese format: 【总体概述】【行为模式】【关键发现】【成长轨迹】
|
||||
|
||||
**FORMAT REQUIREMENT:**
|
||||
- Each section must start with its header on a new line
|
||||
- Content follows immediately after the header
|
||||
- Sections are separated by blank lines
|
||||
- Key Findings section must use bullet points (•)
|
||||
- Strictly adhere to character limits for each section
|
||||
|
||||
**CONTENT REQUIREMENT:**
|
||||
- Only use provided data points
|
||||
- Do not fabricate or speculate information
|
||||
- If data is insufficient for a section, provide a brief note or skip
|
||||
- Maintain professional, analytical tone throughout
|
||||
@@ -9,7 +9,9 @@
|
||||
|
||||
## 任务目标
|
||||
作为数据冲突解决专家,分析冲突原因,按类型分组处理,为每种冲突生成独立解决方案。
|
||||
|
||||
**数据关系**: statement_databasets中的statement_id对应data中的记录,statement_created_at为用户输入时间。
|
||||
|
||||
**处理模式**:
|
||||
- memory_verify=false: 仅处理数据冲突
|
||||
- memory_verify=true: 处理数据冲突 + 隐私脱敏
|
||||
@@ -109,8 +111,7 @@
|
||||
- 隐私保护优先: 所有输出记录必须完成隐私脱敏
|
||||
- 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %}
|
||||
- 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空
|
||||
- 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true、memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容,
|
||||
,如果是FACT,只记录事实冲突相关的数据;如果是TIME,只记录时间冲突相关的数据;如果是HYBRID,则记录所有冲突相关的数据
|
||||
- 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true)等原数据字段以及涉及需要修改的字段以及内容
|
||||
|
||||
**变更记录格式**:
|
||||
```json
|
||||
@@ -157,9 +158,8 @@
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "该冲突类型的原因分析,如果是FACT就是存在事实冲突,分析该冲突原因,如果是TIME就是存在时间冲突,分析该冲突原因,如果是HYBRID,可以输出存在时间与事实的混合冲突再添加上原因分析,
|
||||
不可以随意分配冲突类型以及原因,不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种",
|
||||
"solution": "该冲突类型的解决方案(不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种)"
|
||||
"reason": "该冲突类型的原因分析",
|
||||
"solution": "该冲突类型的解决方案"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "被设为失效的记忆id",
|
||||
@@ -182,5 +182,4 @@
|
||||
- **resolved.change**: 包含详细变更信息
|
||||
- 无需修改的冲突类型resolved为null
|
||||
- 与baseline不匹配的冲突类型不包含在results中
|
||||
模式参考: {{ json_schema }}
|
||||
|
||||
模式参考: {{ json_schema }}
|
||||
@@ -1,114 +0,0 @@
|
||||
{% macro tidy(name) -%}
|
||||
{{ name.replace('_', ' ')}}
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
===Task===
|
||||
|
||||
Your task is to generate a comprehensive user profile based on the provided entities and statements. The profile should include four distinct sections that capture different aspects of the user's identity and characteristics.
|
||||
|
||||
|
||||
===Inputs===
|
||||
{% if user_id %}
|
||||
- User ID: {{ user_id }}
|
||||
{% endif %}
|
||||
{% if entities %}
|
||||
- Core Entities & Frequency: {{ entities }}
|
||||
{% endif %}
|
||||
{% if statements %}
|
||||
- Representative Statement Samples: {{ statements }}
|
||||
{% endif %}
|
||||
|
||||
|
||||
===Profile Generation Requirements===
|
||||
|
||||
**General Guidelines:**
|
||||
1. Base your analysis ONLY on the provided data - do not speculate or fabricate information
|
||||
2. Use objective third-person descriptions with a restrained and neutral tone
|
||||
3. Avoid excessive adjectives and empty phrases
|
||||
4. Strictly follow the output format specified below
|
||||
|
||||
**Section-Specific Requirements:**
|
||||
|
||||
1. **Basic Introduction** (4-5 sentences, max 150 Chinese characters)
|
||||
- Focus on: identity, occupation, location, and other basic demographic information
|
||||
- Provide factual background about who the user is
|
||||
|
||||
2. **Personality Traits** (2-3 sentences, max 80 Chinese characters)
|
||||
- Focus on: personality characteristics, behavioral habits, communication style
|
||||
- Describe observable patterns in how the user interacts and behaves
|
||||
|
||||
3. **Core Values** (1-2 sentences, max 50 Chinese characters)
|
||||
- Focus on: values, beliefs, goals, and aspirations
|
||||
- Capture what matters most to the user and what drives their decisions
|
||||
|
||||
4. **One-Sentence Summary** (1 sentence, max 40 Chinese characters)
|
||||
- Provide a highly condensed characterization of the user's core traits
|
||||
- Similar to a personal tagline or motto that captures their essence
|
||||
|
||||
|
||||
===Output Format (MUST STRICTLY FOLLOW)===
|
||||
|
||||
【基本介绍】
|
||||
[4-5 sentences describing the user's basic identity, occupation, and location]
|
||||
|
||||
【性格特点】
|
||||
[2-3 sentences describing the user's personality traits, behavioral habits, and communication style]
|
||||
|
||||
【核心价值观】
|
||||
[1-2 sentences describing the user's values, beliefs, and goals]
|
||||
|
||||
【一句话总结】
|
||||
[1 sentence providing a highly condensed summary of the user's core characteristics]
|
||||
|
||||
|
||||
===Example===
|
||||
|
||||
Example Input:
|
||||
- User ID: user_12345
|
||||
- Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7)
|
||||
- Representative Statement Samples: 我在深圳从事产品经理工作已经5年了 | 我相信好的产品源于对用户需求的深刻理解 | 我喜欢在团队中起到协调作用 | 数据驱动决策是我的工作原则
|
||||
|
||||
Example Output:
|
||||
【基本介绍】
|
||||
我是张三,一名充满热情的高级产品经理。在过去的5年里,我专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。我相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
|
||||
|
||||
【性格特点】
|
||||
性格开朗,善于沟通,注重细节。喜欢在团队中起到协调作用,帮助大家达成共识。面对挑战时保持乐观,相信每个问题都有解决方案。
|
||||
|
||||
【核心价值观】
|
||||
用户至上、数据驱动、持续学习、团队协作
|
||||
|
||||
【一句话总结】
|
||||
"让每一个产品决策都充满温度。"
|
||||
|
||||
===End of Example===
|
||||
|
||||
|
||||
===Internal Quality Checks (DO NOT OUTPUT)===
|
||||
|
||||
Before generating your final output, internally verify:
|
||||
1. All content is grounded in provided data (no fabrication)
|
||||
2. Format follows the specified structure with correct headers
|
||||
3. Tone is objective, third-person, and neutral
|
||||
4. All four sections are complete and within character limits
|
||||
|
||||
**IMPORTANT: These checks are for your internal use only. DO NOT include them in your output.**
|
||||
|
||||
|
||||
===Output Requirements===
|
||||
|
||||
**CRITICAL: Your response must ONLY contain the four sections below. Do not include any reflection, self-review, or meta-commentary.**
|
||||
|
||||
**LANGUAGE REQUIREMENT:**
|
||||
- The output language should ALWAYS be Chinese (Simplified)
|
||||
- All section content must be in Chinese
|
||||
- Section headers must use the specified Chinese format: 【基本介绍】【性格特点】【核心价值观】【一句话总结】
|
||||
|
||||
**FORMAT REQUIREMENT:**
|
||||
- Each section must start with its header on a new line
|
||||
- Content follows immediately after the header
|
||||
- Sections are separated by blank lines
|
||||
- Strictly adhere to character limits for each section
|
||||
- **DO NOT include any text after the 【一句话总结】 section**
|
||||
- **DO NOT output reflection steps, self-review, or verification notes**
|
||||
@@ -54,9 +54,9 @@ async def get_reflexion_data(host_id: uuid.UUID) -> List[Any]:
|
||||
Returns:
|
||||
符合反思范围的记忆数据列表。
|
||||
"""
|
||||
if REFLEXION_RANGE == "partial":
|
||||
if REFLEXION_RANGE == "retrieval":
|
||||
return await get_data(host_id)
|
||||
elif REFLEXION_RANGE == "all":
|
||||
elif REFLEXION_RANGE == "database":
|
||||
return []
|
||||
else:
|
||||
raise ValueError(f"未知的反思范围: {REFLEXION_RANGE}")
|
||||
|
||||
@@ -64,8 +64,8 @@ def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese"
|
||||
|
||||
def by_textln(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None, **kwargs):
|
||||
textln_api = os.environ.get("TEXTLN_APISERVER", "https://api.textin.com/ai/service/v1/pdf_to_markdown")
|
||||
app_id = os.environ.get("TEXTLN_APP_ID", "")
|
||||
secret_code = os.environ.get("TEXTLN_SECRET_CODE", "")
|
||||
app_id = os.environ.get("TEXTLN_APP_ID", "fa3f24380683ad53e6c620c0f0878a09")
|
||||
secret_code = os.environ.get("TEXTLN_SECRET_CODE", "6130caac9aabc6eb26433758d7898f4a")
|
||||
pdf_parser = TextLnParser(textln_api=textln_api, app_id=app_id, secret_code=secret_code)
|
||||
|
||||
sections, tables = pdf_parser.parse_pdf(
|
||||
@@ -672,7 +672,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
excel_parser = ExcelParser()
|
||||
if parser_config.get("html4excel"):
|
||||
sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
|
||||
parser_config["chunk_token_num"] = 0
|
||||
else:
|
||||
sections = [(_, "") for _ in excel_parser(binary) if _]
|
||||
parser_config["chunk_token_num"] = 12800
|
||||
|
||||
@@ -5,7 +5,6 @@ from io import BytesIO
|
||||
|
||||
import pandas as pd
|
||||
from openpyxl import Workbook, load_workbook
|
||||
from PIL import Image
|
||||
|
||||
from app.core.rag.nlp import find_codec
|
||||
|
||||
@@ -29,7 +28,7 @@ class RAGExcelParser:
|
||||
|
||||
try:
|
||||
file_like_object.seek(0)
|
||||
df = pd.read_csv(file_like_object, on_bad_lines='skip')
|
||||
df = pd.read_csv(file_like_object)
|
||||
return RAGExcelParser._dataframe_to_workbook(df)
|
||||
|
||||
except Exception as e_csv:
|
||||
@@ -66,6 +65,7 @@ class RAGExcelParser:
|
||||
# if contains multiple sheets use _dataframes_to_workbook
|
||||
if isinstance(df, dict) and len(df) > 1:
|
||||
return RAGExcelParser._dataframes_to_workbook(df)
|
||||
|
||||
df = RAGExcelParser._clean_dataframe(df)
|
||||
wb = Workbook()
|
||||
ws = wb.active
|
||||
@@ -77,14 +77,15 @@ class RAGExcelParser:
|
||||
for row_num, row in enumerate(df.values, 2):
|
||||
for col_num, value in enumerate(row, 1):
|
||||
ws.cell(row=row_num, column=col_num, value=value)
|
||||
return wb
|
||||
|
||||
return wb
|
||||
|
||||
@staticmethod
|
||||
def _dataframes_to_workbook(dfs: dict):
|
||||
wb = Workbook()
|
||||
default_sheet = wb.active
|
||||
wb.remove(default_sheet)
|
||||
|
||||
|
||||
for sheet_name, df in dfs.items():
|
||||
df = RAGExcelParser._clean_dataframe(df)
|
||||
ws = wb.create_sheet(title=sheet_name)
|
||||
@@ -95,52 +96,6 @@ class RAGExcelParser:
|
||||
ws.cell(row=row_num, column=col_num, value=value)
|
||||
return wb
|
||||
|
||||
@staticmethod
|
||||
def _extract_images_from_worksheet(ws, sheetname=None):
|
||||
"""
|
||||
Extract images from a worksheet and enrich them with vision-based descriptions.
|
||||
|
||||
Returns: List[dict]
|
||||
"""
|
||||
images = getattr(ws, "_images", [])
|
||||
if not images:
|
||||
return []
|
||||
|
||||
raw_items = []
|
||||
|
||||
for img in images:
|
||||
try:
|
||||
img_bytes = img._data()
|
||||
pil_img = Image.open(BytesIO(img_bytes)).convert("RGB")
|
||||
|
||||
anchor = img.anchor
|
||||
if hasattr(anchor, "_from") and hasattr(anchor, "_to"):
|
||||
r1, c1 = anchor._from.row + 1, anchor._from.col + 1
|
||||
r2, c2 = anchor._to.row + 1, anchor._to.col + 1
|
||||
if r1 == r2 and c1 == c2:
|
||||
span = "single_cell"
|
||||
else:
|
||||
span = "multi_cell"
|
||||
else:
|
||||
r1, c1 = anchor._from.row + 1, anchor._from.col + 1
|
||||
r2, c2 = r1, c1
|
||||
span = "single_cell"
|
||||
|
||||
item = {
|
||||
"sheet": sheetname or ws.title,
|
||||
"image": pil_img,
|
||||
"image_description": "",
|
||||
"row_from": r1,
|
||||
"col_from": c1,
|
||||
"row_to": r2,
|
||||
"col_to": c2,
|
||||
"span_type": span,
|
||||
}
|
||||
raw_items.append(item)
|
||||
except Exception:
|
||||
continue
|
||||
return raw_items
|
||||
|
||||
def html(self, fnm, chunk_rows=256):
|
||||
from html import escape
|
||||
|
||||
@@ -173,7 +128,7 @@ class RAGExcelParser:
|
||||
tb = ""
|
||||
tb += f"<table><caption>{sheetname}</caption>"
|
||||
tb += tb_rows_0
|
||||
for r in list(rows[1 + chunk_i * chunk_rows: min(1 + (chunk_i + 1) * chunk_rows, len(rows))]):
|
||||
for r in list(rows[1 + chunk_i * chunk_rows : min(1 + (chunk_i + 1) * chunk_rows, len(rows))]):
|
||||
tb += "<tr>"
|
||||
for i, c in enumerate(r):
|
||||
if c.value is None:
|
||||
@@ -196,7 +151,7 @@ class RAGExcelParser:
|
||||
except Exception as e:
|
||||
logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file")
|
||||
file_like_object.seek(0)
|
||||
df = pd.read_csv(file_like_object, on_bad_lines='skip')
|
||||
df = pd.read_csv(file_like_object)
|
||||
df = df.replace(r"^\s*$", "", regex=True)
|
||||
return df.to_markdown(index=False)
|
||||
|
||||
@@ -214,35 +169,19 @@ class RAGExcelParser:
|
||||
continue
|
||||
if not rows:
|
||||
continue
|
||||
# 获取表头
|
||||
ti = list(rows[0])
|
||||
header_fields = []
|
||||
for cell in ti:
|
||||
if cell.value: # 只添加有值的表头
|
||||
header_fields.append(str(cell.value))
|
||||
|
||||
# 如果有数据行,处理数据行;否则只处理表头
|
||||
data_rows = rows[1:]
|
||||
if data_rows:
|
||||
for r in data_rows:
|
||||
fields = []
|
||||
for i, c in enumerate(r):
|
||||
if not c.value:
|
||||
continue
|
||||
t = str(ti[i].value) if i < len(ti) else ""
|
||||
t += (":" if t else "") + str(c.value)
|
||||
fields.append(t)
|
||||
line = "; ".join(fields)
|
||||
if sheetname.lower().find("sheet") < 0:
|
||||
line += " ——" + sheetname
|
||||
res.append(line)
|
||||
else:
|
||||
# 只有表头的情况
|
||||
if header_fields:
|
||||
line = "; ".join(header_fields)
|
||||
if sheetname.lower().find("sheet") < 0:
|
||||
line += " ——" + sheetname
|
||||
res.append(line)
|
||||
for r in list(rows[1:]):
|
||||
fields = []
|
||||
for i, c in enumerate(r):
|
||||
if not c.value:
|
||||
continue
|
||||
t = str(ti[i].value) if i < len(ti) else ""
|
||||
t += (":" if t else "") + str(c.value)
|
||||
fields.append(t)
|
||||
line = "; ".join(fields)
|
||||
if sheetname.lower().find("sheet") < 0:
|
||||
line += " ——" + sheetname
|
||||
res.append(line)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
@@ -250,14 +189,14 @@ class RAGExcelParser:
|
||||
if fnm.split(".")[-1].lower().find("xls") >= 0:
|
||||
wb = RAGExcelParser._load_excel_to_workbook(BytesIO(binary))
|
||||
total = 0
|
||||
|
||||
|
||||
for sheetname in wb.sheetnames:
|
||||
try:
|
||||
ws = wb[sheetname]
|
||||
total += len(list(ws.rows))
|
||||
except Exception as e:
|
||||
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
|
||||
continue
|
||||
try:
|
||||
ws = wb[sheetname]
|
||||
total += len(list(ws.rows))
|
||||
except Exception as e:
|
||||
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
|
||||
continue
|
||||
return total
|
||||
|
||||
if fnm.split(".")[-1].lower() in ["csv", "txt"]:
|
||||
|
||||
@@ -196,7 +196,7 @@ class EntityResolution(Extractor):
|
||||
ans_list = []
|
||||
records = [r.strip() for r in results.split(record_delimiter)]
|
||||
for record in records:
|
||||
pattern_int = fr"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
|
||||
pattern_int = f"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
|
||||
match_int = re.search(pattern_int, record)
|
||||
res_int = int(str(match_int.group(1) if match_int else '0'))
|
||||
if res_int > records_length:
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
import json_repair
|
||||
import pandas as pd
|
||||
import time
|
||||
import trio
|
||||
|
||||
from app.core.rag.common.misc_utils import get_uuid
|
||||
@@ -263,21 +262,21 @@ class KGSearch(Dealer):
|
||||
relas = ""
|
||||
|
||||
return {
|
||||
"page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms, comm_topn, max_token),
|
||||
"vector": None,
|
||||
"metadata": {
|
||||
"doc_id": get_uuid(),
|
||||
"file_id": "",
|
||||
"file_name": "Related content in Knowledge Graph",
|
||||
"file_created_at": int(time.time() * 1000),
|
||||
"chunk_id": get_uuid(),
|
||||
"content_ltks": "",
|
||||
"page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms,
|
||||
comm_topn, max_token),
|
||||
"document_id": "",
|
||||
"knowledge_id": kb_ids,
|
||||
"sort_id": 0,
|
||||
"status": 1,
|
||||
"score": 1
|
||||
},
|
||||
"children": None
|
||||
}
|
||||
"docnm_kwd": "Related content in Knowledge Graph",
|
||||
"kb_id": kb_ids,
|
||||
"important_kwd": [],
|
||||
"image_id": "",
|
||||
"similarity": 1.,
|
||||
"vector_similarity": 1.,
|
||||
"term_similarity": 0,
|
||||
"vector": [],
|
||||
"positions": [],
|
||||
}
|
||||
|
||||
def _community_retrieval_(self, entities, condition, kb_ids, idxnms, topn, max_token):
|
||||
## Community retrieval
|
||||
|
||||
@@ -448,7 +448,7 @@ if __name__ == "__main__":
|
||||
# 准备配置vision_model信息
|
||||
# 初始化 QWenCV
|
||||
vision_model = QWenCV(
|
||||
key="",
|
||||
key="sk-8e9e40cd171749858ce2d3722ea75669",
|
||||
model_name="qwen-vl-max",
|
||||
lang="Chinese", # 默认使用中文
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
|
||||
@@ -26,8 +26,6 @@ from app.core.rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr
|
||||
from app.core.rag.common.string_utils import remove_redundant_spaces
|
||||
from app.core.rag.common.float_utils import get_float
|
||||
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
|
||||
|
||||
def knowledge_retrieval(
|
||||
@@ -50,7 +48,6 @@ def knowledge_retrieval(
|
||||
- merge_strategy: "weight" or other strategies
|
||||
- reranker_id: UUID of the reranker to use
|
||||
- reranker_top_k: int
|
||||
- use_graph: bool, whether to use a graph
|
||||
|
||||
Returns:
|
||||
Rearranged document block list (in descending order of relevance)
|
||||
@@ -62,7 +59,6 @@ def knowledge_retrieval(
|
||||
merge_strategy = config.get("merge_strategy", "weight")
|
||||
reranker_id = config.get("reranker_id")
|
||||
reranker_top_k = config.get("reranker_top_k", 1024)
|
||||
use_graph = config.get("use_graph", "false").lower() == "true"
|
||||
|
||||
file_names_filter = []
|
||||
if user_ids:
|
||||
@@ -71,10 +67,6 @@ def knowledge_retrieval(
|
||||
if not knowledge_bases:
|
||||
return []
|
||||
|
||||
kb_ids = []
|
||||
workspace_ids = []
|
||||
chat_model = None
|
||||
embedding_model = None
|
||||
all_results = []
|
||||
# Search each knowledge base
|
||||
for kb_config in knowledge_bases:
|
||||
@@ -95,22 +87,6 @@ def knowledge_retrieval(
|
||||
else:
|
||||
continue
|
||||
|
||||
if str(db_knowledge.id) not in kb_ids:
|
||||
kb_ids.append(str(db_knowledge.id))
|
||||
if str(db_knowledge.workspace_id) not in workspace_ids:
|
||||
workspace_ids.append(str(db_knowledge.workspace_id))
|
||||
if not chat_model:
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
if not embedding_model:
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
)
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
# Retrieve according to the configured retrieval type
|
||||
match kb_config["retrieve_type"]:
|
||||
@@ -160,12 +136,6 @@ def knowledge_retrieval(
|
||||
# Use the specified reranker for re-ranking
|
||||
if reranker_id:
|
||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||
# use graph
|
||||
if use_graph:
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
all_results.insert(0, doc)
|
||||
return all_results
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -213,7 +213,7 @@ class ESConnection(DocStoreConnection):
|
||||
m.topn * 2,
|
||||
query_vector=list(m.embedding_data),
|
||||
filter=bqry.to_dict(),
|
||||
# similarity=similarity
|
||||
similarity=similarity,
|
||||
)
|
||||
|
||||
if bqry and rank_feature:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""工具管理核心模块"""
|
||||
|
||||
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
|
||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
from .base import BaseTool, ToolResult, ToolParameter
|
||||
from .langchain_adapter import LangchainAdapter
|
||||
|
||||
# 可选导入,避免导入错误
|
||||
try:
|
||||
|
||||
@@ -191,14 +191,10 @@ class BaseTool(ABC):
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def to_langchain_tool(self, operation: Optional[str] = None):
|
||||
"""转换为Langchain工具格式
|
||||
|
||||
Args:
|
||||
operation: 特定操作(适用于有操作的工具)
|
||||
"""
|
||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
return LangchainAdapter.convert_tool(self, operation)
|
||||
def to_langchain_tool(self):
|
||||
"""转换为Langchain工具格式"""
|
||||
from .langchain_adapter import LangchainAdapter
|
||||
return LangchainAdapter.convert_tool(self)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"
|
||||
@@ -1,11 +1,11 @@
|
||||
"""内置工具模块"""
|
||||
|
||||
from app.core.tools.builtin.base import BuiltinTool
|
||||
from app.core.tools.builtin.datetime_tool import DateTimeTool
|
||||
from app.core.tools.builtin.json_tool import JsonTool
|
||||
from app.core.tools.builtin.baidu_search_tool import BaiduSearchTool
|
||||
from app.core.tools.builtin.mineru_tool import MinerUTool
|
||||
from app.core.tools.builtin.textin_tool import TextInTool
|
||||
from .base import BuiltinTool
|
||||
from .datetime_tool import DateTimeTool
|
||||
from .json_tool import JsonTool
|
||||
from .baidu_search_tool import BaiduSearchTool
|
||||
from .mineru_tool import MinerUTool
|
||||
from .textin_tool import TextInTool
|
||||
|
||||
__all__ = [
|
||||
"BuiltinTool",
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Dict, Any
|
||||
import aiohttp
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.builtin.base import BuiltinTool
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class BaiduSearchTool(BuiltinTool):
|
||||
@@ -110,7 +110,7 @@ class BaiduSearchTool(BuiltinTool):
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result["results"],
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List
|
||||
import pytz
|
||||
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.builtin.base import BuiltinTool
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class DateTimeTool(BuiltinTool):
|
||||
@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
|
||||
type=ParameterType.STRING,
|
||||
description="操作类型",
|
||||
required=True,
|
||||
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
|
||||
enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
@@ -95,7 +95,7 @@ class DateTimeTool(BuiltinTool):
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result["result_data"],
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
@@ -123,14 +123,12 @@ class DateTimeTool(BuiltinTool):
|
||||
utc_now = datetime.now(timezone.utc)
|
||||
|
||||
return {
|
||||
"datetime": now.strftime(output_format),
|
||||
"timestamp": int(now.timestamp()),
|
||||
"timezone": timezone_str,
|
||||
"iso_format": now.isoformat(),
|
||||
"result_data": {
|
||||
"datetime": now.strftime(output_format),
|
||||
"timestamp": int(now.timestamp()),
|
||||
"timestamp_ms": int(now.timestamp() * 1000),
|
||||
"utc_datetime": utc_now.strftime(output_format),
|
||||
}
|
||||
"timestamp_ms": int(now.timestamp() * 1000),
|
||||
"utc_datetime": utc_now.strftime(output_format)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -150,8 +148,7 @@ class DateTimeTool(BuiltinTool):
|
||||
"original": input_value,
|
||||
"formatted": dt.strftime(output_format),
|
||||
"timestamp": int(dt.timestamp()),
|
||||
"iso_format": dt.isoformat(),
|
||||
"result_data": dt.strftime(output_format)
|
||||
"iso_format": dt.isoformat()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -192,8 +189,7 @@ class DateTimeTool(BuiltinTool):
|
||||
"original_timezone": from_timezone,
|
||||
"converted": converted_dt.strftime(output_format),
|
||||
"converted_timezone": to_timezone,
|
||||
"timestamp": int(converted_dt.timestamp()),
|
||||
"result_data": converted_dt.strftime(output_format)
|
||||
"timestamp": int(converted_dt.timestamp())
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -223,8 +219,7 @@ class DateTimeTool(BuiltinTool):
|
||||
"timestamp": timestamp,
|
||||
"datetime": dt.strftime(output_format),
|
||||
"timezone": timezone_str,
|
||||
"iso_format": dt.isoformat(),
|
||||
"result_data": dt.strftime(output_format)
|
||||
"iso_format": dt.isoformat()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -254,8 +249,7 @@ class DateTimeTool(BuiltinTool):
|
||||
"datetime": input_value,
|
||||
"timezone": timezone_str,
|
||||
"timestamp": int(dt.timestamp()),
|
||||
"iso_format": dt.isoformat(),
|
||||
"result_data": int(dt.timestamp())
|
||||
"iso_format": dt.isoformat()
|
||||
}
|
||||
|
||||
def _calculate_datetime(self, kwargs) -> dict:
|
||||
@@ -293,8 +287,7 @@ class DateTimeTool(BuiltinTool):
|
||||
"calculation": calculation,
|
||||
"result": calculated_dt.strftime(output_format),
|
||||
"timezone": timezone_str,
|
||||
"timestamp": int(calculated_dt.timestamp()),
|
||||
"result_data": calculated_dt.strftime(output_format)
|
||||
"timestamp": int(calculated_dt.timestamp())
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -7,7 +7,7 @@ import xml.etree.ElementTree as ET
|
||||
from xml.dom import minidom
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.builtin.base import BuiltinTool
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class JsonTool(BuiltinTool):
|
||||
@@ -29,7 +29,8 @@ class JsonTool(BuiltinTool):
|
||||
type=ParameterType.STRING,
|
||||
description="操作类型",
|
||||
required=True,
|
||||
enum=["insert", "replace", "delete", "parse"]
|
||||
enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge",
|
||||
"extract", "insert", "replace", "delete", "parse"]
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_data",
|
||||
@@ -69,7 +70,7 @@ class JsonTool(BuiltinTool):
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(用于insert、replace、delete、parse操作,如:$.user.name或users[0].name)",
|
||||
description="JSON路径表达式(用于extract、insert、replace、delete、parse操作,如:$.user.name或users[0].name)",
|
||||
required=False
|
||||
),
|
||||
ToolParameter(
|
||||
@@ -136,7 +137,7 @@ class JsonTool(BuiltinTool):
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result["result_data"],
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
@@ -671,8 +672,7 @@ class JsonTool(BuiltinTool):
|
||||
"success": True,
|
||||
"value": current,
|
||||
"value_type": type(current).__name__,
|
||||
"value_json": json.dumps(current, indent=2, ensure_ascii=False) if isinstance(current, (dict, list)) else str(current),
|
||||
"result_data": json.dumps(current, indent=2, ensure_ascii=False) if isinstance(current, (dict, list)) else str(current)
|
||||
"value_json": json.dumps(current, indent=2, ensure_ascii=False) if isinstance(current, (dict, list)) else str(current)
|
||||
}
|
||||
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
@@ -681,8 +681,7 @@ class JsonTool(BuiltinTool):
|
||||
"json_path": json_path,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"value": None,
|
||||
"result_data": None
|
||||
"value": None
|
||||
}
|
||||
|
||||
def _analyze_json_structure(self, data: Any, depth: int = 0) -> Dict[str, Any]:
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Dict, Any
|
||||
import aiohttp
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.builtin.base import BuiltinTool
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class MinerUTool(BuiltinTool):
|
||||
|
||||
@@ -1,216 +0,0 @@
|
||||
"""操作工具 - 为特定操作创建的工具包装器"""
|
||||
from typing import List
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.models import ToolType
|
||||
|
||||
|
||||
class OperationTool(BaseTool):
|
||||
"""操作工具 - 包装基础工具的特定操作"""
|
||||
|
||||
def __init__(self, base_tool: BaseTool, operation: str):
|
||||
self.base_tool = base_tool
|
||||
self.operation = operation
|
||||
super().__init__(base_tool.tool_id, base_tool.config)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"{self.base_tool.name}_{self.operation}"
|
||||
|
||||
@property
|
||||
def tool_type(self) -> ToolType:
|
||||
"""工具类型"""
|
||||
return ToolType.BUILTIN
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return f"{self.base_tool.description} - {self.operation}"
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""返回特定操作的参数"""
|
||||
if self.base_tool.name == 'datetime_tool':
|
||||
return self._get_datetime_params()
|
||||
elif self.base_tool.name == 'json_tool':
|
||||
return self._get_json_params()
|
||||
else:
|
||||
# 默认返回除operation外的所有参数
|
||||
return [p for p in self.base_tool.parameters if p.name != "operation"]
|
||||
|
||||
def _get_datetime_params(self) -> List[ToolParameter]:
|
||||
"""获取datetime_tool特定操作的参数"""
|
||||
if self.operation == "now":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="to_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
]
|
||||
elif self.operation == "format":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
type=ParameterType.STRING,
|
||||
description="输入值(时间字符串或时间戳)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
]
|
||||
elif self.operation == "convert_timezone":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
type=ParameterType.STRING,
|
||||
description="输入值(时间字符串或时间戳)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="from_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="源时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
),
|
||||
ToolParameter(
|
||||
name="to_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
)
|
||||
]
|
||||
elif self.operation == "timestamp_to_datetime":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
type=ParameterType.STRING,
|
||||
description="输入值(时间字符串或时间戳)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="to_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_json_params(self) -> List[ToolParameter]:
|
||||
"""获取json_tool特定操作的参数"""
|
||||
base_params = [
|
||||
ToolParameter(
|
||||
name="input_data",
|
||||
type=ParameterType.STRING,
|
||||
description="输入数据(JSON字符串、YAML字符串或XML字符串)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
|
||||
if self.operation == "insert":
|
||||
return base_params + [
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="new_value",
|
||||
type=ParameterType.STRING,
|
||||
description="新值(用于insert操作)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
elif self.operation == "replace":
|
||||
return base_params + [
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="old_text",
|
||||
type=ParameterType.STRING,
|
||||
description="要替换的原文本(用于replace操作)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="new_text",
|
||||
type=ParameterType.STRING,
|
||||
description="替换后的新文本(用于replace操作)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
elif self.operation == "delete":
|
||||
return base_params + [
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
elif self.operation == "parse":
|
||||
return base_params + [
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
else:
|
||||
return base_params
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行特定操作"""
|
||||
# 添加operation参数
|
||||
kwargs["operation"] = self.operation
|
||||
return await self.base_tool.execute(**kwargs)
|
||||
@@ -4,7 +4,7 @@ from typing import List, Dict, Any
|
||||
import aiohttp
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.builtin.base import BuiltinTool
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class TextInTool(BuiltinTool):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""自定义工具模块"""
|
||||
|
||||
from app.core.tools.custom.base import CustomTool
|
||||
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
|
||||
from app.core.tools.custom.auth_manager import AuthManager
|
||||
from .base import CustomTool
|
||||
from .schema_parser import OpenAPISchemaParser
|
||||
from .auth_manager import AuthManager
|
||||
|
||||
__all__ = [
|
||||
"CustomTool",
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""自定义工具基类"""
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
import aiohttp
|
||||
@@ -136,13 +135,6 @@ class CustomTool(BaseTool):
|
||||
|
||||
if not self.schema_content:
|
||||
return operations
|
||||
|
||||
if isinstance(self.schema_content, str):
|
||||
try:
|
||||
self.schema_content = json.loads(self.schema_content)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无效的OpenAPI schema: {self.schema_content}")
|
||||
return operations
|
||||
|
||||
paths = self.schema_content.get("paths", {})
|
||||
|
||||
|
||||
@@ -10,6 +10,9 @@ from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
# 为了兼容性,创建别名
|
||||
# SchemaParser = OpenAPISchemaParser = None
|
||||
|
||||
|
||||
class OpenAPISchemaParser:
|
||||
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
|
||||
@@ -210,9 +213,7 @@ class OpenAPISchemaParser:
|
||||
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
summary = operation.get("summary", "")
|
||||
|
||||
|
||||
# 生成操作ID
|
||||
operation_id = operation.get("operationId")
|
||||
if not operation_id:
|
||||
@@ -222,7 +223,7 @@ class OpenAPISchemaParser:
|
||||
operations[operation_id] = {
|
||||
"method": method.upper(),
|
||||
"path": path,
|
||||
"summary": summary if summary else operation_id,
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": self._extract_parameters(operation),
|
||||
"request_body": self._extract_request_body(operation),
|
||||
|
||||
@@ -21,35 +21,24 @@ class LangchainToolWrapper(LangchainBaseTool):
|
||||
|
||||
# 内部工具实例
|
||||
tool_instance: BaseTool = Field(..., description="内部工具实例")
|
||||
# 特定操作(用于自定义工具)
|
||||
operation: Optional[str] = Field(None, description="特定操作")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, tool_instance: BaseTool, operation: Optional[str] = None, **kwargs):
|
||||
def __init__(self, tool_instance: BaseTool, **kwargs):
|
||||
"""初始化Langchain工具包装器
|
||||
|
||||
Args:
|
||||
tool_instance: 内部工具实例
|
||||
operation: 特定操作(用于自定义工具)
|
||||
"""
|
||||
# 动态创建参数schema
|
||||
args_schema = LangchainAdapter._create_pydantic_schema(
|
||||
tool_instance.parameters, operation
|
||||
)
|
||||
|
||||
# 构建工具名称
|
||||
tool_name = tool_instance.name
|
||||
if operation:
|
||||
tool_name = f"{tool_instance.name}_{operation}"
|
||||
args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters)
|
||||
|
||||
super().__init__(
|
||||
name=tool_name,
|
||||
name=tool_instance.name,
|
||||
description=tool_instance.description,
|
||||
args_schema=args_schema,
|
||||
tool_instance=tool_instance,
|
||||
operation=operation,
|
||||
_tool_instance=tool_instance,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -69,12 +58,8 @@ class LangchainToolWrapper(LangchainBaseTool):
|
||||
) -> str:
|
||||
"""异步执行工具"""
|
||||
try:
|
||||
# 如果有特定操作,添加到参数中
|
||||
if self.operation:
|
||||
kwargs["operation"] = self.operation
|
||||
|
||||
# 执行内部工具
|
||||
result = await self.tool_instance.safe_execute(**kwargs)
|
||||
result = await self._tool_instance.safe_execute(**kwargs)
|
||||
|
||||
# 转换结果为Langchain格式
|
||||
return LangchainAdapter._format_result_for_langchain(result)
|
||||
@@ -88,82 +73,24 @@ class LangchainAdapter:
|
||||
"""Langchain适配器 - 负责工具格式转换和标准化"""
|
||||
|
||||
@staticmethod
|
||||
def convert_tool(tool: BaseTool, operation: Optional[str] = None) -> LangchainToolWrapper:
|
||||
def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
|
||||
"""将内部工具转换为Langchain工具
|
||||
|
||||
Args:
|
||||
tool: 内部工具实例
|
||||
operation: 特定操作(适用于有操作的工具)或MCP工具名称
|
||||
|
||||
Returns:
|
||||
Langchain兼容的工具包装器
|
||||
"""
|
||||
try:
|
||||
# 处理MCP工具的特定工具名称
|
||||
if hasattr(tool, 'tool_type') and tool.tool_type.value == "mcp" and operation:
|
||||
# 为MCP工具创建特定工具名称的实例
|
||||
mcp_tool = LangchainAdapter._create_mcp_tool_with_name(tool, operation)
|
||||
wrapper = LangchainToolWrapper(tool_instance=mcp_tool)
|
||||
logger.debug(f"MCP工具转换成功: {tool.name}_{operation} -> Langchain格式")
|
||||
return wrapper
|
||||
elif operation and LangchainAdapter._tool_supports_operations(tool):
|
||||
# 为支持多操作的工具创建特定操作实例
|
||||
if tool.tool_type.value == "custom":
|
||||
# 自定义工具直接传递operation参数
|
||||
wrapper = LangchainToolWrapper(tool_instance=tool, operation=operation)
|
||||
else:
|
||||
# 内置工具使用OperationTool包装
|
||||
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
|
||||
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
|
||||
logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式")
|
||||
return wrapper
|
||||
else:
|
||||
# 单个工具
|
||||
wrapper = LangchainToolWrapper(tool_instance=tool)
|
||||
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
|
||||
return wrapper
|
||||
wrapper = LangchainToolWrapper(tool_instance=tool)
|
||||
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
|
||||
return wrapper
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _tool_supports_operations(tool: BaseTool) -> bool:
|
||||
"""检查工具是否支持多操作"""
|
||||
# 内置工具中支持操作的工具
|
||||
builtin_operation_tools = ['datetime_tool', 'json_tool']
|
||||
|
||||
# 检查内置工具
|
||||
if tool.tool_type.value == "builtin" and tool.name in builtin_operation_tools:
|
||||
return True
|
||||
|
||||
# 检查自定义工具(自定义工具通过解析OpenAPI schema支持多操作)
|
||||
if tool.tool_type.value == "custom":
|
||||
# 检查工具是否有多个操作
|
||||
if hasattr(tool, '_parsed_operations') and len(tool._parsed_operations) > 1:
|
||||
return True
|
||||
# 或者检查参数中是否有operation参数
|
||||
for param in tool.parameters:
|
||||
if param.name == "operation" and param.enum:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool:
|
||||
"""为特定操作创建工具实例"""
|
||||
if base_tool.tool_type.value == "builtin":
|
||||
from app.core.tools.builtin.operation_tool import OperationTool
|
||||
return OperationTool(base_tool, operation)
|
||||
else:
|
||||
raise ValueError(f"不支持的工具类型: {base_tool.tool_type.value}")
|
||||
|
||||
@staticmethod
|
||||
def _create_mcp_tool_with_name(mcp_tool: BaseTool, tool_name: str) -> BaseTool:
|
||||
"""为MCP工具创建指定工具名称的实例"""
|
||||
mcp_tool.set_current_tool(tool_name)
|
||||
return mcp_tool
|
||||
|
||||
@staticmethod
|
||||
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
|
||||
"""批量转换工具
|
||||
@@ -183,19 +110,15 @@ class LangchainAdapter:
|
||||
except Exception as e:
|
||||
logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
|
||||
|
||||
logger.info(f"批量转换完成: {len(converted_tools)} 个工具")
|
||||
logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
|
||||
return converted_tools
|
||||
|
||||
@staticmethod
|
||||
def _create_pydantic_schema(
|
||||
parameters: List[ToolParameter],
|
||||
operation: Optional[str] = None
|
||||
) -> Type[BaseModel]:
|
||||
def _create_pydantic_schema(parameters: List[ToolParameter]) -> Type[BaseModel]:
|
||||
"""根据工具参数创建Pydantic schema
|
||||
|
||||
Args:
|
||||
parameters: 工具参数列表
|
||||
operation: 特定操作(用于过滤参数)
|
||||
|
||||
Returns:
|
||||
Pydantic模型类
|
||||
@@ -204,12 +127,7 @@ class LangchainAdapter:
|
||||
fields = {}
|
||||
annotations = {}
|
||||
|
||||
# 如果指定了operation,过滤掉operation参数
|
||||
filtered_params = parameters
|
||||
if operation:
|
||||
filtered_params = [p for p in parameters if p.name != "operation"]
|
||||
|
||||
for param in filtered_params:
|
||||
for param in parameters:
|
||||
# 确定Python类型
|
||||
python_type = LangchainAdapter._get_python_type(param.type)
|
||||
|
||||
@@ -232,7 +150,7 @@ class LangchainAdapter:
|
||||
# 添加验证约束
|
||||
if param.enum:
|
||||
# 枚举值约束
|
||||
field_kwargs["pattern"] = f"^({'|'.join(map(str, param.enum))})$"
|
||||
field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$"
|
||||
|
||||
if param.minimum is not None:
|
||||
field_kwargs["ge"] = param.minimum
|
||||
@@ -241,7 +159,7 @@ class LangchainAdapter:
|
||||
field_kwargs["le"] = param.maximum
|
||||
|
||||
if param.pattern:
|
||||
field_kwargs["pattern"] = param.pattern
|
||||
field_kwargs["regex"] = param.pattern
|
||||
|
||||
fields[param.name] = Field(**field_kwargs)
|
||||
annotations[param.name] = python_type
|
||||
@@ -251,10 +169,9 @@ class LangchainAdapter:
|
||||
"ToolArgsSchema",
|
||||
(BaseModel,),
|
||||
{
|
||||
"__module__": __name__,
|
||||
"__annotations__": annotations,
|
||||
"model_config": {"extra": "forbid"},
|
||||
**fields
|
||||
**fields,
|
||||
"Config": type("Config", (), {"extra": "forbid"})
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -1,20 +1,12 @@
|
||||
"""MCP 工具模块 - Model Context Protocol 支持"""
|
||||
"""MCP工具模块"""
|
||||
|
||||
# 主要类导出
|
||||
from .base import MCPTool, MCPToolManager, MCPError
|
||||
from .client import SimpleMCPClient, MCPConnectionError
|
||||
from .base import MCPTool
|
||||
from .client import MCPClient, MCPConnectionPool
|
||||
from .service_manager import MCPServiceManager
|
||||
|
||||
__all__ = [
|
||||
# 核心类
|
||||
"MCPTool",
|
||||
"MCPToolManager",
|
||||
"MCPError",
|
||||
|
||||
# 客户端类
|
||||
"SimpleMCPClient",
|
||||
"MCPConnectionError",
|
||||
|
||||
# 服务管理(简化版)
|
||||
"MCPClient",
|
||||
"MCPConnectionPool",
|
||||
"MCPServiceManager"
|
||||
]
|
||||
@@ -1,9 +1,11 @@
|
||||
"""MCP工具基类 - 整合版本"""
|
||||
"""MCP工具基类"""
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
from typing import Dict, Any, List
|
||||
import aiohttp
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.base import BaseTool
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -13,188 +15,215 @@ class MCPTool(BaseTool):
|
||||
"""MCP工具 - Model Context Protocol工具"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
||||
"""初始化MCP工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
config: 工具配置
|
||||
"""
|
||||
super().__init__(tool_id, config)
|
||||
self.server_url = config.get("server_url", "")
|
||||
self.connection_config = config.get("connection_config", {})
|
||||
self.available_tools = config.get("available_tools", [])
|
||||
self._client = None
|
||||
self._connected = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""工具名称"""
|
||||
return f"mcp_tool_{self.tool_id[:8]}"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""工具描述"""
|
||||
return f"MCP工具 - 连接到 {self.server_url}"
|
||||
|
||||
@property
|
||||
def tool_type(self) -> ToolType:
|
||||
"""工具类型"""
|
||||
return ToolType.MCP
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""根据工具名称返回对应参数"""
|
||||
# 如果有指定的工具名称,从 available_tools 中获取参数
|
||||
tool_name = getattr(self, '_current_tool_name', None)
|
||||
if tool_name and self.available_tools:
|
||||
for tool_info in self.available_tools:
|
||||
if tool_info.get("tool_name") == tool_name:
|
||||
arguments = tool_info.get("arguments", {})
|
||||
return self._generate_parameters_from_schema(arguments)
|
||||
"""工具参数定义"""
|
||||
params = []
|
||||
|
||||
# 默认返回通用参数
|
||||
return [
|
||||
ToolParameter(
|
||||
# 添加工具选择参数
|
||||
if len(self.available_tools) > 1:
|
||||
params.append(ToolParameter(
|
||||
name="tool_name",
|
||||
type=ParameterType.STRING,
|
||||
description="要执行的工具名称",
|
||||
required=True
|
||||
),
|
||||
description="要调用的MCP工具名称",
|
||||
required=True,
|
||||
enum=self.available_tools
|
||||
))
|
||||
|
||||
# 添加通用参数
|
||||
params.extend([
|
||||
ToolParameter(
|
||||
name="arguments",
|
||||
type=ParameterType.OBJECT,
|
||||
description="工具参数",
|
||||
description="工具参数(JSON对象)",
|
||||
required=False,
|
||||
default={}
|
||||
),
|
||||
ToolParameter(
|
||||
name="timeout",
|
||||
type=ParameterType.INTEGER,
|
||||
description="超时时间(秒)",
|
||||
required=False,
|
||||
default=30,
|
||||
minimum=1,
|
||||
maximum=300
|
||||
)
|
||||
]
|
||||
|
||||
def _generate_parameters_from_schema(self, arguments: Dict[str, Any]) -> List[ToolParameter]:
|
||||
"""从参数schema生成参数列表"""
|
||||
properties = arguments.get("properties", {})
|
||||
required_fields = arguments.get("required", [])
|
||||
|
||||
params = []
|
||||
for param_name, param_def in properties.items():
|
||||
param_type = self._convert_json_type_to_parameter_type(param_def.get("type", "string"))
|
||||
|
||||
params.append(ToolParameter(
|
||||
name=param_name,
|
||||
type=param_type,
|
||||
description=param_def.get("description", f"参数: {param_name}"),
|
||||
required=param_name in required_fields,
|
||||
default=param_def.get("default"),
|
||||
enum=param_def.get("enum"),
|
||||
minimum=param_def.get("minimum"),
|
||||
maximum=param_def.get("maximum")
|
||||
))
|
||||
])
|
||||
|
||||
return params
|
||||
|
||||
def _convert_json_type_to_parameter_type(self, json_type: str) -> ParameterType:
|
||||
"""转换JSON Schema类型到ParameterType"""
|
||||
type_mapping = {
|
||||
"string": ParameterType.STRING,
|
||||
"integer": ParameterType.INTEGER,
|
||||
"number": ParameterType.NUMBER,
|
||||
"boolean": ParameterType.BOOLEAN,
|
||||
"array": ParameterType.ARRAY,
|
||||
"object": ParameterType.OBJECT
|
||||
}
|
||||
return type_mapping.get(json_type, ParameterType.STRING)
|
||||
|
||||
def set_current_tool(self, tool_name: str):
|
||||
"""设置当前工具名称,用于获取特定参数"""
|
||||
self._current_tool_name = tool_name
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行MCP工具"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确保连接
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
# 确定要调用的工具
|
||||
tool_name = kwargs.get("tool_name")
|
||||
if not tool_name and len(self.available_tools) == 1:
|
||||
tool_name = self.available_tools[0]
|
||||
|
||||
if not tool_name:
|
||||
raise Exception("未指定工具名称")
|
||||
raise ValueError("必须指定要调用的MCP工具名称")
|
||||
|
||||
if tool_name not in self.available_tools:
|
||||
raise ValueError(f"MCP工具不存在: {tool_name}")
|
||||
|
||||
# 获取参数
|
||||
arguments = kwargs.get("arguments", {})
|
||||
timeout = kwargs.get("timeout", 30)
|
||||
|
||||
from .client import SimpleMCPClient
|
||||
# 调用MCP工具
|
||||
result = await self._call_mcp_tool(tool_name, arguments, timeout)
|
||||
|
||||
client = SimpleMCPClient(self.server_url, self.connection_config)
|
||||
|
||||
async with client:
|
||||
result = await client.call_tool(tool_name, arguments)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
logger.error(f"MCP工具执行失败: {kwargs.get('tool_name', 'unknown')}, 错误: {e}")
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="MCP_EXECUTION_ERROR",
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
|
||||
class MCPError(Exception):
|
||||
"""MCP 错误基类"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPToolManager:
|
||||
"""MCP 工具管理器 - 简化版本"""
|
||||
|
||||
def __init__(self, db=None):
|
||||
self.db = db
|
||||
self._tool_cache: Dict[str, Dict[str, Any]] = {} # server_url -> tools_info
|
||||
|
||||
async def discover_tools(
|
||||
self,
|
||||
server_url: str,
|
||||
connection_config: Dict[str, Any] = None
|
||||
) -> tuple[bool, List[Dict[str, Any]], str | None]:
|
||||
"""发现 MCP 服务器上的工具"""
|
||||
try:
|
||||
from .client import SimpleMCPClient
|
||||
|
||||
client = SimpleMCPClient(server_url, connection_config)
|
||||
|
||||
async with client:
|
||||
tools = await client.list_tools()
|
||||
|
||||
# 缓存工具信息
|
||||
self._tool_cache[server_url] = {
|
||||
"tools": tools,
|
||||
"connection_config": connection_config,
|
||||
"last_updated": time.time()
|
||||
}
|
||||
|
||||
logger.info(f"发现 {len(tools)} 个MCP工具: {server_url}")
|
||||
return True, tools, None
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"发现工具失败: {e}"
|
||||
logger.error(error_msg)
|
||||
return False, [], error_msg
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="MCP_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
async def test_tool_connection(
|
||||
self,
|
||||
server_url: str,
|
||||
connection_config: Dict[str, Any] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""测试工具连接"""
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器"""
|
||||
try:
|
||||
from .client import SimpleMCPClient
|
||||
from .client import MCPClient
|
||||
|
||||
client = SimpleMCPClient(server_url, connection_config)
|
||||
if self._connected:
|
||||
return True
|
||||
|
||||
self._client = MCPClient(self.server_url, self.connection_config)
|
||||
|
||||
if await self._client.connect():
|
||||
self._connected = True
|
||||
# 更新可用工具列表
|
||||
await self._update_available_tools()
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"MCP服务器连接失败: {self.server_url}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}")
|
||||
self._connected = False
|
||||
return False
|
||||
|
||||
async def _update_available_tools(self):
|
||||
"""更新可用工具列表"""
|
||||
try:
|
||||
if self._client and self._connected:
|
||||
tools = await self._client.list_tools()
|
||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
|
||||
except Exception as e:
|
||||
logger.error(f"更新MCP工具列表失败: {e}")
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接"""
|
||||
try:
|
||||
if self._client:
|
||||
await self._client.disconnect()
|
||||
self._client = None
|
||||
|
||||
self._connected = False
|
||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
def get_health_status(self) -> Dict[str, Any]:
|
||||
"""获取MCP服务健康状态"""
|
||||
return {
|
||||
"connected": self._connected,
|
||||
"server_url": self.server_url,
|
||||
"available_tools": self.available_tools,
|
||||
"last_check": time.time()
|
||||
}
|
||||
|
||||
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
|
||||
"""调用MCP工具"""
|
||||
if not self._client or not self._connected:
|
||||
raise Exception("MCP客户端未连接")
|
||||
|
||||
try:
|
||||
result = await self._client.call_tool(tool_name, arguments, timeout)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
async def list_available_tools(self) -> List[Dict[str, Any]]:
|
||||
"""列出可用的MCP工具"""
|
||||
try:
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
if self._client:
|
||||
tools = await self._client.list_tools()
|
||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
return tools
|
||||
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取MCP工具列表失败: {e}")
|
||||
return []
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试MCP连接"""
|
||||
try:
|
||||
# 这里应该实现同步的连接测试
|
||||
# 为了简化,返回基本信息
|
||||
return {
|
||||
"success": bool(self.server_url),
|
||||
"server_url": self.server_url,
|
||||
"connected": self._connected,
|
||||
"available_tools_count": len(self.available_tools),
|
||||
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
|
||||
}
|
||||
|
||||
async with client:
|
||||
tools = await client.list_tools()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"tools_count": len(tools),
|
||||
"tools": [tool.get("name") for tool in tools],
|
||||
"message": "连接成功"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": "连接失败"
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
"""MCP客户端 - 简化版本"""
|
||||
"""MCP客户端 - Model Context Protocol客户端实现"""
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
@@ -17,260 +18,139 @@ class MCPConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SimpleMCPClient:
|
||||
"""简化的 MCP 客户端"""
|
||||
class MCPProtocolError(Exception):
|
||||
"""MCP协议错误"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""MCP客户端 - 支持HTTP和WebSocket连接"""
|
||||
|
||||
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
||||
"""初始化MCP客户端
|
||||
|
||||
Args:
|
||||
server_url: MCP服务器URL
|
||||
connection_config: 连接配置
|
||||
"""
|
||||
self.server_url = server_url
|
||||
self.connection_config = connection_config or {}
|
||||
self.timeout = self.connection_config.get("timeout", 30)
|
||||
|
||||
# 确定连接类型
|
||||
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
||||
self.is_sse = "/sse" in server_url.lower()
|
||||
# 解析URL确定连接类型
|
||||
parsed_url = urlparse(server_url)
|
||||
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
|
||||
|
||||
# 连接状态
|
||||
self._connected = False
|
||||
self._websocket = None
|
||||
self._session = None
|
||||
|
||||
# 请求管理
|
||||
self._request_id = 0
|
||||
self._pending_requests = {}
|
||||
self._server_capabilities = {}
|
||||
self._endpoint_url = None # SSE endpoint URL
|
||||
self._sse_task = None
|
||||
self._pending_requests: Dict[str, asyncio.Future] = {}
|
||||
|
||||
# 连接池配置
|
||||
self.max_connections = self.connection_config.get("max_connections", 10)
|
||||
self.connection_timeout = self.connection_config.get("timeout", 30)
|
||||
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
|
||||
self.retry_delay = self.connection_config.get("retry_delay", 1)
|
||||
|
||||
# 健康检查
|
||||
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
|
||||
self._health_check_task = None
|
||||
self._last_health_check = None
|
||||
|
||||
# 事件回调
|
||||
self._on_connect_callbacks: List[Callable] = []
|
||||
self._on_disconnect_callbacks: List[Callable] = []
|
||||
self._on_error_callbacks: List[Callable] = []
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.disconnect()
|
||||
|
||||
async def connect(self):
|
||||
"""建立连接"""
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器
|
||||
|
||||
Returns:
|
||||
连接是否成功
|
||||
"""
|
||||
try:
|
||||
if self.is_websocket:
|
||||
await self._connect_websocket()
|
||||
if self._connected:
|
||||
return True
|
||||
|
||||
logger.info(f"连接MCP服务器: {self.server_url}")
|
||||
|
||||
if self.connection_type == "websocket":
|
||||
success = await self._connect_websocket()
|
||||
else:
|
||||
await self._connect_http()
|
||||
success = await self._connect_http()
|
||||
|
||||
if success:
|
||||
self._connected = True
|
||||
await self._start_health_check()
|
||||
await self._notify_connect_callbacks()
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
|
||||
raise MCPConnectionError(f"连接失败: {e}")
|
||||
logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
|
||||
await self._notify_error_callbacks(e)
|
||||
return False
|
||||
|
||||
async def disconnect(self):
|
||||
"""断开连接"""
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接
|
||||
|
||||
Returns:
|
||||
断开是否成功
|
||||
"""
|
||||
try:
|
||||
if self._sse_task:
|
||||
self._sse_task.cancel()
|
||||
if self._websocket:
|
||||
if not self._connected:
|
||||
return True
|
||||
|
||||
logger.info(f"断开MCP服务器连接: {self.server_url}")
|
||||
|
||||
# 停止健康检查
|
||||
await self._stop_health_check()
|
||||
|
||||
# 取消所有待处理的请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
|
||||
# 断开连接
|
||||
if self.connection_type == "websocket" and self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
if self._session:
|
||||
elif self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
self._connected = False
|
||||
await self._notify_disconnect_callbacks()
|
||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开连接失败: {e}")
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""WebSocket 连接"""
|
||||
headers = self._build_headers()
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
extra_headers=headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
asyncio.create_task(self._handle_websocket_messages())
|
||||
await self._send_initialize()
|
||||
|
||||
async def _connect_http(self):
|
||||
"""HTTP 连接"""
|
||||
headers = self._build_headers()
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||
self._session = aiohttp.ClientSession(headers=headers, timeout=timeout)
|
||||
|
||||
if self.is_sse:
|
||||
await self._initialize_sse_session()
|
||||
elif "modelscope.net" in self.server_url:
|
||||
await self._initialize_modelscope_session()
|
||||
|
||||
async def _initialize_sse_session(self):
|
||||
"""初始化 SSE MCP 会话 - 参考 Dify 实现"""
|
||||
try:
|
||||
# 建立 SSE 连接
|
||||
response = await self._session.get(self.server_url)
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
||||
|
||||
# 启动 SSE 读取任务
|
||||
self._sse_task = asyncio.create_task(self._read_sse_stream(response))
|
||||
|
||||
# 等待获取 endpoint URL
|
||||
for _ in range(10):
|
||||
if self._endpoint_url:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if not self._endpoint_url:
|
||||
raise MCPConnectionError("未能获取 endpoint URL")
|
||||
|
||||
# 发送 initialize 请求到 endpoint
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
init_response = await self._send_sse_request(init_request)
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
result = init_response.get("result", {})
|
||||
self._server_capabilities = result.get("capabilities", {})
|
||||
|
||||
# 发送 initialized 通知
|
||||
await self._send_sse_notification({"jsonrpc": "2.0", "method": "notifications/initialized"})
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"初始化连接失败: {e}")
|
||||
|
||||
async def _read_sse_stream(self, response):
|
||||
"""读取 SSE 流"""
|
||||
try:
|
||||
async for line in response.content:
|
||||
line = line.decode('utf-8').strip()
|
||||
|
||||
if line.startswith('event:'):
|
||||
continue
|
||||
|
||||
if line.startswith('data:'):
|
||||
data = line[5:].strip() # 去除 'data:' 后的空格
|
||||
if not data or data == '[DONE]':
|
||||
continue
|
||||
|
||||
try:
|
||||
# 处理 endpoint 事件(相对路径或绝对路径)
|
||||
if not self._endpoint_url:
|
||||
# 如果是相对路径,拼接成完整 URL
|
||||
if data.startswith('/'):
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
parsed = urlparse(self.server_url)
|
||||
self._endpoint_url = f"{parsed.scheme}://{parsed.netloc}{data}"
|
||||
else:
|
||||
self._endpoint_url = data
|
||||
logger.info(f"获取到 endpoint URL: {self._endpoint_url}")
|
||||
continue
|
||||
|
||||
# 处理 message 事件
|
||||
message = json.loads(data)
|
||||
request_id = message.get("id")
|
||||
if request_id and request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"SSE 流读取错误: {e}")
|
||||
|
||||
async def _send_sse_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""通过 SSE endpoint 发送请求"""
|
||||
if not self._endpoint_url:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
request_id = request["id"]
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
async with self._session.post(self._endpoint_url, json=request) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await asyncio.wait_for(future, timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError("请求超时")
|
||||
|
||||
async def _send_sse_notification(self, notification: Dict[str, Any]):
|
||||
"""发送通知(无需响应)"""
|
||||
if not self._endpoint_url:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
async with self._session.post(self._endpoint_url, json=notification) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"通知发送失败: {response.status}")
|
||||
|
||||
async def _initialize_modelscope_session(self):
|
||||
"""初始化 ModelScope MCP 会话"""
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(self.server_url, json=init_request) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
||||
|
||||
init_response = await response.json()
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
|
||||
if session_id:
|
||||
self._session.headers.update({"Mcp-Session-Id": session_id})
|
||||
|
||||
initialized_notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
}
|
||||
|
||||
async with self._session.post(self.server_url, json=initialized_notification):
|
||||
pass
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"初始化连接失败: {e}")
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
# 基础 headers
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream"
|
||||
}
|
||||
|
||||
# 合并 connection_config 中的自定义 headers
|
||||
custom_headers = self.connection_config.get("headers", {})
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
# 处理认证配置(认证 headers 优先级更高)
|
||||
auth_config = self.connection_config.get("auth_config", {})
|
||||
def _build_auth_headers(self) -> Dict[str, str]:
|
||||
"""构建认证头"""
|
||||
headers = {}
|
||||
auth_type = self.connection_config.get("auth_type", "none")
|
||||
auth_config = self.connection_config.get("auth_config", {})
|
||||
|
||||
if auth_type == "bearer_token":
|
||||
if auth_type == "api_key":
|
||||
api_key = auth_config.get("api_key")
|
||||
key_name = auth_config.get("key_name", "X-API-Key")
|
||||
if api_key:
|
||||
headers[key_name] = api_key
|
||||
|
||||
elif auth_type == "bearer_token":
|
||||
token = auth_config.get("token")
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
elif auth_type == "api_key":
|
||||
key = auth_config.get("api_key")
|
||||
header_name = auth_config.get("key_name", "X-API-Key")
|
||||
if key:
|
||||
headers[header_name] = key
|
||||
|
||||
elif auth_type == "basic_auth":
|
||||
username = auth_config.get("username")
|
||||
password = auth_config.get("password")
|
||||
@@ -281,99 +161,504 @@ class SimpleMCPClient:
|
||||
|
||||
return headers
|
||||
|
||||
async def _send_initialize(self):
|
||||
"""发送初始化消息(WebSocket)"""
|
||||
init_message = {
|
||||
async def _connect_websocket(self) -> bool:
|
||||
"""建立WebSocket连接"""
|
||||
try:
|
||||
# WebSocket连接配置
|
||||
extra_headers = self.connection_config.get("headers", {})
|
||||
auth_headers = self._build_auth_headers()
|
||||
extra_headers.update(auth_headers)
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
extra_headers=extra_headers,
|
||||
timeout=self.connection_timeout
|
||||
)
|
||||
|
||||
# 启动消息监听
|
||||
asyncio.create_task(self._websocket_message_handler())
|
||||
|
||||
# 发送初始化消息
|
||||
init_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {}
|
||||
},
|
||||
"clientInfo": {
|
||||
"name": "ToolManagementSystem",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(init_message))
|
||||
|
||||
# 等待初始化响应
|
||||
response = await asyncio.wait_for(
|
||||
self._websocket.recv(),
|
||||
timeout=self.connection_timeout
|
||||
)
|
||||
|
||||
init_response = json.loads(response)
|
||||
if "error" in init_response:
|
||||
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket连接失败: {e}")
|
||||
return False
|
||||
|
||||
async def _connect_http(self) -> bool:
|
||||
"""建立HTTP连接"""
|
||||
try:
|
||||
# HTTP会话配置
|
||||
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
|
||||
headers = self.connection_config.get("headers", {})
|
||||
auth_headers = self._build_auth_headers()
|
||||
headers.update(auth_headers)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# 测试连接
|
||||
test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
|
||||
|
||||
async with self._session.get(test_url) as response:
|
||||
if response.status == 200:
|
||||
return True
|
||||
else:
|
||||
# 尝试根路径
|
||||
async with self._session.get(self.server_url) as root_response:
|
||||
return root_response.status < 400
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"HTTP连接失败: {e}")
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
return False
|
||||
|
||||
async def _websocket_message_handler(self):
|
||||
"""WebSocket消息处理器"""
|
||||
try:
|
||||
while self._websocket and not self._websocket.closed:
|
||||
try:
|
||||
message = await self._websocket.recv()
|
||||
await self._handle_message(json.loads(message))
|
||||
except ConnectionClosed:
|
||||
break
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析WebSocket消息失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理WebSocket消息失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket消息处理器异常: {e}")
|
||||
finally:
|
||||
self._connected = False
|
||||
await self._notify_disconnect_callbacks()
|
||||
|
||||
async def _handle_message(self, message: Dict[str, Any]):
|
||||
"""处理收到的消息"""
|
||||
try:
|
||||
# 检查是否是响应消息
|
||||
if "id" in message:
|
||||
request_id = str(message["id"])
|
||||
if request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
|
||||
# 处理通知消息
|
||||
elif "method" in message:
|
||||
await self._handle_notification(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def _handle_notification(message: Dict[str, Any]):
|
||||
"""处理通知消息"""
|
||||
method = message.get("method")
|
||||
params = message.get("params", {})
|
||||
|
||||
logger.debug(f"收到MCP通知: {method}, 参数: {params}")
|
||||
|
||||
# 这里可以根据需要处理特定的通知
|
||||
# 例如:工具列表更新、服务器状态变化等
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
|
||||
"""调用MCP工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: 连接错误
|
||||
MCPProtocolError: 协议错误
|
||||
"""
|
||||
if not self._connected:
|
||||
raise MCPConnectionError("MCP客户端未连接")
|
||||
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "initialize",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(init_message))
|
||||
response = await self._websocket.recv()
|
||||
response_data = json.loads(response)
|
||||
|
||||
if "error" in response_data:
|
||||
raise MCPConnectionError(f"初始化失败: {response_data['error']}")
|
||||
|
||||
result = response_data.get("result", {})
|
||||
self._server_capabilities = result.get("capabilities", {})
|
||||
|
||||
await self._websocket.send(json.dumps({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
}))
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response.get("result", {})
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPProtocolError(f"工具调用超时: {tool_name}")
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""获取工具列表"""
|
||||
request = {
|
||||
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取可用工具列表
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: 连接错误
|
||||
MCPProtocolError: 协议错误
|
||||
"""
|
||||
if not self._connected:
|
||||
raise MCPConnectionError("MCP客户端未连接")
|
||||
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
await self._websocket.send(json.dumps(request))
|
||||
response = await self._websocket.recv()
|
||||
response_data = json.loads(response)
|
||||
elif self.is_sse:
|
||||
response_data = await self._send_sse_request(request)
|
||||
else:
|
||||
async with self._session.post(self.server_url, json=request) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if "error" in response_data:
|
||||
raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}")
|
||||
|
||||
result = response_data.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
"""调用工具"""
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {"name": tool_name, "arguments": arguments}
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
await self._websocket.send(json.dumps(request))
|
||||
response = await self._websocket.recv()
|
||||
response_data = json.loads(response)
|
||||
elif self.is_sse:
|
||||
response_data = await self._send_sse_request(request)
|
||||
else:
|
||||
async with self._session.post(self.server_url, json=request) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if "error" in response_data:
|
||||
error = response_data["error"]
|
||||
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response_data.get("result", {})
|
||||
|
||||
def _get_request_id(self) -> int:
|
||||
"""生成请求 ID"""
|
||||
self._request_id += 1
|
||||
return self._request_id
|
||||
|
||||
async def _handle_websocket_messages(self):
|
||||
"""处理 WebSocket 消息"""
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
data = json.loads(message)
|
||||
request_id = data.get("id")
|
||||
if request_id and request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
except ConnectionClosed:
|
||||
logger.info("WebSocket 连接已关闭")
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if response.get("error", None) is not None:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
result = response.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPProtocolError("获取工具列表超时")
|
||||
|
||||
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
"""发送请求并等待响应
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
if self.connection_type == "websocket":
|
||||
request_id = str(request_data["id"])
|
||||
return await self._send_websocket_request(request_data, request_id, timeout)
|
||||
else:
|
||||
return await self._send_http_request(request_data, timeout)
|
||||
|
||||
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
|
||||
"""发送WebSocket请求"""
|
||||
if not self._websocket or self._websocket.closed:
|
||||
raise MCPConnectionError("WebSocket连接已断开")
|
||||
|
||||
# 创建Future等待响应
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
await self._websocket.send(json.dumps(request_data))
|
||||
|
||||
# 等待响应
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
await self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 消息处理错误: {e}")
|
||||
await self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
|
||||
|
||||
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
"""发送HTTP请求"""
|
||||
if not self._session:
|
||||
raise MCPConnectionError("HTTP会话未建立")
|
||||
|
||||
try:
|
||||
url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
|
||||
|
||||
async with self._session.post(
|
||||
url,
|
||||
json=request_data,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=request_data,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as root_response:
|
||||
if root_response.status != 200:
|
||||
error_text = await root_response.text()
|
||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"HTTP请求失败: {e}")
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""执行健康检查
|
||||
|
||||
Returns:
|
||||
健康状态信息
|
||||
"""
|
||||
try:
|
||||
if not self._connected:
|
||||
return {
|
||||
"healthy": False,
|
||||
"error": "未连接",
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
# 发送ping请求
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "ping"
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
response = await self._send_request(request_data, timeout=5)
|
||||
response_time = round((time.time() - start_time) * 1000)
|
||||
|
||||
self._last_health_check = round(time.time() * 1000)
|
||||
|
||||
return {
|
||||
"healthy": True,
|
||||
"response_time": response_time,
|
||||
"timestamp": self._last_health_check,
|
||||
"server_info": response.get("result", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"healthy": False,
|
||||
"error": str(e),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
async def _start_health_check(self):
|
||||
"""启动健康检查任务"""
|
||||
if self.health_check_interval > 0:
|
||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
||||
|
||||
async def _stop_health_check(self):
|
||||
"""停止健康检查任务"""
|
||||
if self._health_check_task:
|
||||
self._health_check_task.cancel()
|
||||
try:
|
||||
await self._health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._health_check_task = None
|
||||
|
||||
async def _health_check_loop(self):
|
||||
"""健康检查循环"""
|
||||
try:
|
||||
while self._connected:
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
|
||||
if self._connected:
|
||||
health_status = await self.health_check()
|
||||
if not health_status["healthy"]:
|
||||
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
|
||||
# 可以在这里实现重连逻辑
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查循环异常: {e}")
|
||||
|
||||
def _get_next_request_id(self) -> str:
|
||||
"""获取下一个请求ID"""
|
||||
self._request_id += 1
|
||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
||||
|
||||
# 事件回调管理
|
||||
def on_connect(self, callback: Callable):
|
||||
"""注册连接回调"""
|
||||
self._on_connect_callbacks.append(callback)
|
||||
|
||||
def on_disconnect(self, callback: Callable):
|
||||
"""注册断开连接回调"""
|
||||
self._on_disconnect_callbacks.append(callback)
|
||||
|
||||
def on_error(self, callback: Callable):
|
||||
"""注册错误回调"""
|
||||
self._on_error_callbacks.append(callback)
|
||||
|
||||
async def _notify_connect_callbacks(self):
|
||||
"""通知连接回调"""
|
||||
for callback in self._on_connect_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback()
|
||||
else:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.error(f"连接回调执行失败: {e}")
|
||||
|
||||
async def _notify_disconnect_callbacks(self):
|
||||
"""通知断开连接回调"""
|
||||
for callback in self._on_disconnect_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback()
|
||||
else:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.error(f"断开连接回调执行失败: {e}")
|
||||
|
||||
async def _notify_error_callbacks(self, error: Exception):
|
||||
"""通知错误回调"""
|
||||
for callback in self._on_error_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(error)
|
||||
else:
|
||||
callback(error)
|
||||
except Exception as e:
|
||||
logger.error(f"错误回调执行失败: {e}")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""检查是否已连接"""
|
||||
return self._connected
|
||||
|
||||
@property
|
||||
def last_health_check(self) -> Optional[float]:
|
||||
"""获取最后一次健康检查时间"""
|
||||
return self._last_health_check
|
||||
|
||||
def get_connection_info(self) -> Dict[str, Any]:
|
||||
"""获取连接信息"""
|
||||
return {
|
||||
"server_url": self.server_url,
|
||||
"connection_type": self.connection_type,
|
||||
"connected": self._connected,
|
||||
"last_health_check": self._last_health_check,
|
||||
"pending_requests": len(self._pending_requests),
|
||||
"config": self.connection_config
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
class MCPConnectionPool:
|
||||
"""MCP连接池 - 管理多个MCP客户端连接"""
|
||||
|
||||
def __init__(self, max_connections: int = 10):
|
||||
"""初始化连接池
|
||||
|
||||
Args:
|
||||
max_connections: 最大连接数
|
||||
"""
|
||||
self.max_connections = max_connections
|
||||
self._clients: Dict[str, MCPClient] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
|
||||
"""获取或创建MCP客户端
|
||||
|
||||
Args:
|
||||
server_url: 服务器URL
|
||||
connection_config: 连接配置
|
||||
|
||||
Returns:
|
||||
MCP客户端实例
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_url in self._clients:
|
||||
client = self._clients[server_url]
|
||||
if client.is_connected:
|
||||
return client
|
||||
else:
|
||||
# 尝试重连
|
||||
if await client.connect():
|
||||
return client
|
||||
else:
|
||||
# 移除失效的客户端
|
||||
del self._clients[server_url]
|
||||
|
||||
# 检查连接数限制
|
||||
if len(self._clients) >= self.max_connections:
|
||||
# 移除最旧的连接
|
||||
oldest_url = next(iter(self._clients))
|
||||
await self._clients[oldest_url].disconnect()
|
||||
del self._clients[oldest_url]
|
||||
|
||||
# 创建新客户端
|
||||
client = MCPClient(server_url, connection_config)
|
||||
if await client.connect():
|
||||
self._clients[server_url] = client
|
||||
return client
|
||||
else:
|
||||
raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")
|
||||
|
||||
async def disconnect_all(self):
|
||||
"""断开所有连接"""
|
||||
async with self._lock:
|
||||
for client in self._clients.values():
|
||||
await client.disconnect()
|
||||
self._clients.clear()
|
||||
|
||||
def get_pool_status(self) -> Dict[str, Any]:
|
||||
"""获取连接池状态"""
|
||||
return {
|
||||
"total_connections": len(self._clients),
|
||||
"max_connections": self.max_connections,
|
||||
"connections": {
|
||||
url: client.get_connection_info()
|
||||
for url, client in self._clients.items()
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user