Compare commits

1 Commits

Author SHA1 Message Date
Ke Sun
5b13b4a949 GitHub release (#20)
* feat(web): remove mock data
* feat(knowledgeBase): Refactor document list API and improve polling logic

- Update getDocumentList API to accept kb_id as separate parameter instead of extracting from query object
- Fix parameter name from auto_question to auto_questions in parser config
- Add progress field initialization in document update params
- Improve polling logic to handle both auto-return and manual stay scenarios with proper loading state management
- Add console logging for debugging polling status and document processing
- Reduce polling interval from 5000ms to 3000ms for faster status updates
- Enhance cleanup logic with route change detection to prevent memory leaks
- Add record parameter to progress render function for better data access
- Refactor confirm dialog callbacks to properly manage loading state timing
- Ensure loading indicator displays correctly when user chooses to stay on page

* feat(web): Add Workflow

* feat(web): Workflow

* feat(web): node show id; update reflection engine example

* feat(components): Add markdown editing capability and enhance component styling

- Add editable mode to Markdown component with edit/save/cancel buttons
- Import EditOutlined, SaveOutlined, CloseOutlined icons from ant-design
- Add useState, useRef, useEffect hooks for managing edit state
- Add editable, onContentChange, and onSave props to RbMarkdownProps interface
- Create RbModal component with new index.css stylesheet for modal styling
- Add index.css stylesheet to KnowledgeBase components for consistent styling
- Update i18n translations in en.ts and zh.ts for new UI elements
- Refactor Markdown component handlers to accept and spread additional props
- Update InsertModal and RecallTestResult components for improved UX
- Fix prop spreading in component handlers to maintain compatibility with Ant Design components

* feat(web): Graph user memory update

* feat(web): update routes.json

* fix(web): workflow bug

* fix(web): workflow variable

* fix(web): workflow properties

* feat(web): workflow support lexical editor

* feat(web): workflow support lexical editor

* feat(web): update reflection engine result

* feat(web): workflow's chat support abort output

* fix:git commit

* fix:vite config

* fix:breadcrumbs

* feat(i18n): add document processing confirmation dialog translations

- Add "processingDocuments" translation key for loading state message in English and Chinese
- Add "startUploadConfirmTitle" translation for confirmation dialog title
- Add "startUploadConfirmContent" translation for confirmation dialog description
- Add "returnToList" translation for returning to list page action
- Add "stayOnPage" translation for staying on current page action
- Support user choice to either return to list or stay on page during background document processing

* fix(web): user memory detail

* feat(web): order

* fix: breadcrumb changes

* feat(web): 1. user memory; 2. update workspace's model config

* feat(web): update zh.ts / en.ts

* fix(web): update user profile

* feat(web): Agent add ai prompt

* feat(web): Agent add ai prompt

* feat(web): add pricing menu

* feat(knowledgeBase): add media file validation and PDF enhancement method selection

- Add i18n translations for file size and duration validation errors in English and Chinese
- Implement media file validation with 256MB size limit and 150-second duration limit
- Add support for audio and video file formats (mp3, mp4, mov, wav) in dataset creation
- Add checkMediaDuration helper function to validate media file duration using HTML5 media API
- Add PDF enhancement method selection dropdown with options (DeepDoc, MinerU, TextLN)
- Change default PDF enhancement setting from disabled to enabled
- Update file type array to include media formats
- Add error messaging for file size and duration validation failures
- Improve UI spacing for file parsing settings section

* feat(knowledgeBase): add media dataset support and improve file handling

- Add media dataset translations in English and Chinese locales
- Add "mediaDataSet" and "uploadMedia" i18n keys for UI labels
- Enable media dataset creation option in Private component by uncommenting menu item
- Import and display image icon for media dataset menu option
- Refactor file ID handling in CreateDataset to support both string and array types
- Improve fileIds initialization logic to handle mixed input types
- Update CreateImageDataset component to use file chunking workflow
- Add navigation to parameter settings step after file upload
- Pass file IDs to dataset creation flow for media processing
- Add message API and navigate hook for improved UX feedback

* fix(knowledgeBase): improve navigation and folder tree refresh logic

- Add path comparison check in breadcrumb navigation to avoid unnecessary route changes when already on target page
- Implement delayed folder tree refresh with setTimeout to ensure state reset completes before refreshing
- Add manual table refresh trigger to ensure data updates after navigation
- Reset expanded keys in FolderTree component during load to ensure consistent state from root directory
- Add expanded keys reset in breadcrumb navigation to prevent stale expansion state
- Improve navigation state handling by using replace flag only when on target path to reduce history stack pollution

* fix:pdfEnhancementEnabled

* feat(web): add tool management

* fix(web): get the parent domain name adaptation IP

* fix(web): Conversation add initialValue

* feat(web): workflow’s Editor Variable support Tag

* fix(web): pricing UI

* feat(web): JSON Tool update

* fix(web): update get llm,chat model list function

* fix(web): time tool / cluster chat

* fix(web): time tool add time zone

* feat(web): neo4j type user memory detail

* fix(web): update parseSchema api param

* feat: workflow add knowledge-retrieval node

* feat(knowledgeBase): enhance file upload and dataset creation with abort support and improved UX

- Add AbortSignal support to uploadFile API for cancellable uploads
- Implement custom onRemove callback in UploadFiles component with confirmation dialog
- Add i18n translations for file removal confirmation and error messages
- Update supported file types documentation to include IMAGE and MEDIA formats
- Improve file removal UI with cursor pointer styling
- Refactor getModelList API to remove unused type parameter
- Add Form import and UploadFile type for better type safety in CreateDataset
- Enhance error handling and user feedback for file operations

* feat(web): MCP add bearer token auth type

* fix(web): UI update

---------

Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: yujiangping <yujiangping@taofen8.com>
Co-authored-by: 赵莹 <zhaoying@redbearai.com>
Co-authored-by: vrhs@163.com <accounts_660b6454a0eb398d3f8d2c76@mail.teambition.com>
2025-12-30 18:37:40 +08:00
459 changed files with 7429 additions and 39252 deletions

View File

@@ -1,25 +0,0 @@
from pydantic import BaseModel, Field
from sqlalchemy import TypeDecorator, JSON


class PydanticType(TypeDecorator):
    impl = JSON

    def __init__(self, pydantic_model: type[BaseModel]):
        super().__init__()
        self.model = pydantic_model

    def process_bind_param(self, value, dialect):
        # Write path: Model -> dict
        if value is None:
            return None
        if isinstance(value, self.model):
            return value.dict()
        return value  # already a dict, let it through

    def process_result_value(self, value, dialect):
        # Read path: dict -> Model
        if value is None:
            return None
        # return self.model.parse_obj(value) # pydantic v1
        return self.model.model_validate(value)  # pydantic v2

View File

@@ -85,8 +85,6 @@ health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=30) # runs every 30 seconds (settings.REFLECTION_INTERVAL_TIME)
forgetting_cycle_schedule = timedelta(hours=24) # runs the forgetting cycle every 24 hours
# Build the periodic task configuration
beat_schedule_config = {
@@ -105,13 +103,6 @@ beat_schedule_config = {
"schedule": memory_cache_regeneration_schedule,
"args": (),
},
"run-forgetting-cycle": {
"task": "app.tasks.run_forgetting_cycle_task",
"schedule": forgetting_cycle_schedule,
"kwargs": {
"config_id": None, # use the default configuration; can be set via an environment variable
},
},
}
# If a default workspace ID is configured, add the total-memory statistics task
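The hunks above remove the `run-forgetting-cycle` entry from the beat schedule outright. As a minimal sketch only, and not the project's actual code, the same entry could instead be gated behind a flag. The function name `build_beat_schedule`, the flag `include_forgetting_cycle`, and the `regenerate-memory-cache` entry are hypothetical; only the `run-forgetting-cycle` entry mirrors the diff.

```python
from datetime import timedelta


def build_beat_schedule(include_forgetting_cycle: bool) -> dict:
    """Build a Celery-beat-style schedule as a plain dict (no Celery import needed)."""
    schedule = {
        # Hypothetical entry, shown only so the dict is not empty.
        "regenerate-memory-cache": {
            "task": "app.tasks.regenerate_memory_cache",  # hypothetical task name
            "schedule": timedelta(hours=24),
            "args": (),
        },
    }
    if include_forgetting_cycle:
        # Mirrors the entry that the diff deletes.
        schedule["run-forgetting-cycle"] = {
            "task": "app.tasks.run_forgetting_cycle_task",
            "schedule": timedelta(hours=24),
            "kwargs": {"config_id": None},
        }
    return schedule
```

Calling `build_beat_schedule(False)` yields a schedule without the forgetting cycle, which matches the state after this commit.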

View File

@@ -4,47 +4,37 @@
Authentication: JWT Token
"""
from fastapi import APIRouter
from . import (
api_key_controller,
app_controller,
auth_controller,
chunk_controller,
document_controller,
emotion_config_controller,
emotion_controller,
file_controller,
home_page_controller,
implicit_memory_controller,
knowledge_controller,
knowledgeshare_controller,
memory_agent_controller,
memory_dashboard_controller,
memory_episodic_controller,
memory_explicit_controller,
memory_forget_controller,
memory_reflection_controller,
memory_short_term_controller,
memory_storage_controller,
model_controller,
multi_agent_controller,
prompt_optimizer_controller,
public_share_controller,
release_share_controller,
setup_controller,
task_controller,
test_controller,
tool_controller,
upload_controller,
user_controller,
user_memory_controllers,
workflow_controller,
auth_controller,
workspace_controller,
memory_forget_controller,
home_page_controller,
memory_perceptual_controller,
memory_working_controller,
setup_controller,
file_controller,
document_controller,
knowledge_controller,
chunk_controller,
knowledgeshare_controller,
app_controller,
upload_controller,
memory_agent_controller,
memory_dashboard_controller,
memory_storage_controller,
memory_dashboard_controller,
memory_reflection_controller,
api_key_controller,
release_share_controller,
public_share_controller,
multi_agent_controller,
workflow_controller,
emotion_controller,
emotion_config_controller,
prompt_optimizer_controller,
tool_controller,
)
from . import user_memory_controllers
# Create the management API router
manager_router = APIRouter()
@@ -69,8 +59,6 @@ manager_router.include_router(memory_agent_controller.router)
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(memory_storage_controller.router)
manager_router.include_router(user_memory_controllers.router)
manager_router.include_router(memory_episodic_controller.router)
manager_router.include_router(memory_explicit_controller.router)
manager_router.include_router(api_key_controller.router)
manager_router.include_router(release_share_controller.router)
manager_router.include_router(public_share_controller.router) # public routes (no authentication required)
@@ -81,12 +69,6 @@ manager_router.include_router(emotion_controller.router)
manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router)
manager_router.include_router(memory_reflection_controller.router)
manager_router.include_router(memory_short_term_controller.router)
manager_router.include_router(tool_controller.router)
manager_router.include_router(memory_forget_controller.router)
manager_router.include_router(home_page_controller.router)
manager_router.include_router(implicit_memory_controller.router)
manager_router.include_router(memory_perceptual_controller.router)
manager_router.include_router(memory_working_controller.router)
__all__ = ["manager_router"]

View File

@@ -11,16 +11,15 @@ from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models import User
from app.models.app_model import AppType
from app.models.app_model import AppType, App
from app.repositories import knowledge_repository
from app.repositories.end_user_repository import EndUserRepository
from app.schemas import app_schema
from app.schemas.response_schema import PageData, PageMeta
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
from app.schemas.workflow_schema import WorkflowConfigUpdate
from app.services import app_service, workspace_service
from app.services.agent_config_helper import enrich_agent_config
from app.services.app_service import AppService
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
from app.services.workflow_service import WorkflowService, get_workflow_service
router = APIRouter(prefix="/apps", tags=["Apps"])
@@ -30,9 +29,9 @@ logger = get_business_logger()
@router.post("", summary="创建应用(可选创建 Agent 配置)")
@cur_workspace_access_guard()
def create_app(
payload: app_schema.AppCreate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
payload: app_schema.AppCreate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
app = app_service.create_app(db, user_id=current_user.id, workspace_id=workspace_id, data=payload)
@@ -42,34 +41,22 @@ def create_app(
@router.get("", summary="应用列表(分页)")
@cur_workspace_access_guard()
def list_apps(
type: str | None = None,
visibility: str | None = None,
status: str | None = None,
search: str | None = None,
include_shared: bool = True,
page: int = 1,
pagesize: int = 10,
ids: Optional[str] = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
type: str | None = None,
visibility: str | None = None,
status: str | None = None,
search: str | None = None,
include_shared: bool = True,
page: int = 1,
pagesize: int = 10,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
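"""List apps
- By default, includes apps in this workspace and apps shared with this workspace
- Set include_shared=false to see only this workspace's apps
- When the ids parameter is provided, split it on commas and fetch the specified apps without pagination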
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
# When ids is provided (not None), fetch apps by those ids
if ids is not None:
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
return success(data=items)
# Regular paginated query
items_orm, total = app_service.list_apps(
db,
workspace_id=workspace_id,
@@ -82,17 +69,18 @@ def list_apps(
pagesize=pagesize,
)
# Use AppService's conversion method to set the is_shared field
service = app_service.AppService(db)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
return success(data=PageData(page=meta, items=items))
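The `hasnext` flag above is derived purely from `page`, `pagesize`, and `total`. A small self-contained sketch of that calculation follows; the `PageMeta` dataclass here is a stand-in for the project's schema of the same name, and `build_page_meta` is a hypothetical helper.

```python
from dataclasses import dataclass


@dataclass
class PageMeta:
    """Stand-in for app.schemas.response_schema.PageMeta (fields taken from the diff)."""
    page: int
    pagesize: int
    total: int
    hasnext: bool


def build_page_meta(page: int, pagesize: int, total: int) -> PageMeta:
    # There is a next page only if the items seen so far do not cover the total.
    return PageMeta(page=page, pagesize=pagesize, total=total,
                    hasnext=(page * pagesize) < total)
```

With 25 items and a page size of 10, pages 1 and 2 report `hasnext=True` and page 3 reports `hasnext=False`.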
@router.get("/{app_id}", summary="获取应用详情")
@cur_workspace_access_guard()
def get_app(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""Get app details
@@ -111,10 +99,10 @@ def get_app(
@router.put("/{app_id}", summary="更新应用基本信息")
@cur_workspace_access_guard()
def update_app(
app_id: uuid.UUID,
payload: app_schema.AppUpdate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
payload: app_schema.AppUpdate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
app = app_service.update_app(db, app_id=app_id, data=payload, workspace_id=workspace_id)
@@ -124,9 +112,9 @@ def update_app(
@router.delete("/{app_id}", summary="删除应用")
@cur_workspace_access_guard()
def delete_app(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""Delete an app
@@ -153,10 +141,10 @@ def delete_app(
@router.post("/{app_id}/copy", summary="复制应用")
@cur_workspace_access_guard()
def copy_app(
app_id: uuid.UUID,
new_name: Optional[str] = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
new_name: Optional[str] = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""Copy an app (including basic info and configuration)
@@ -190,10 +178,10 @@ def copy_app(
@router.put("/{app_id}/config", summary="更新 Agent 配置")
@cur_workspace_access_guard()
def update_agent_config(
app_id: uuid.UUID,
payload: app_schema.AgentConfigUpdate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
payload: app_schema.AgentConfigUpdate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
cfg = app_service.update_agent_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
@@ -204,9 +192,9 @@ def update_agent_config(
@router.get("/{app_id}/config", summary="获取 Agent 配置")
@cur_workspace_access_guard()
def get_agent_config(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
@@ -218,10 +206,10 @@ def get_agent_config(
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
@cur_workspace_access_guard()
def publish_app(
app_id: uuid.UUID,
payload: app_schema.PublishRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
payload: app_schema.PublishRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
release = app_service.publish(
@@ -229,7 +217,7 @@ def publish_app(
app_id=app_id,
publisher_id=current_user.id,
workspace_id=workspace_id,
version_name=payload.version_name,
version_name = payload.version_name,
release_notes=payload.release_notes
)
return success(data=app_schema.AppRelease.model_validate(release))
@@ -238,9 +226,9 @@ def publish_app(
@router.get("/{app_id}/release", summary="获取当前发布版本")
@cur_workspace_access_guard()
def get_current_release(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
release = app_service.get_current_release(db, app_id=app_id, workspace_id=workspace_id)
@@ -252,9 +240,9 @@ def get_current_release(
@router.get("/{app_id}/releases", summary="列出历史发布版本(倒序)")
@cur_workspace_access_guard()
def list_releases(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
releases = app_service.list_releases(db, app_id=app_id, workspace_id=workspace_id)
@@ -265,10 +253,10 @@ def list_releases(
@router.post("/{app_id}/rollback/{version}", summary="回滚到指定版本")
@cur_workspace_access_guard()
def rollback(
app_id: uuid.UUID,
version: int,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
version: int,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
release = app_service.rollback(db, app_id=app_id, version=version, workspace_id=workspace_id)
@@ -278,10 +266,10 @@ def rollback(
@router.post("/{app_id}/share", summary="分享应用到其他工作空间")
@cur_workspace_access_guard()
def share_app(
app_id: uuid.UUID,
payload: app_schema.AppShareCreate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
payload: app_schema.AppShareCreate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""Share an app with other workspaces
@@ -306,10 +294,10 @@ def share_app(
@router.delete("/{app_id}/share/{target_workspace_id}", summary="取消应用分享")
@cur_workspace_access_guard()
def unshare_app(
app_id: uuid.UUID,
target_workspace_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
target_workspace_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""Cancel app sharing
@@ -330,9 +318,9 @@ def unshare_app(
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
@cur_workspace_access_guard()
def list_app_shares(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""List all share records for an app
@@ -349,15 +337,14 @@ def list_app_shares(
data = [app_schema.AppShare.model_validate(s) for s in shares]
return success(data=data)
@router.post("/{app_id}/draft/run", summary="试运行 Agent使用当前草稿配置")
@cur_workspace_access_guard()
async def draft_run(
app_id: uuid.UUID,
payload: app_schema.DraftRunRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None
app_id: uuid.UUID,
payload: app_schema.DraftRunRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None
):
"""
Trial-run the Agent using the current draft (unpublished) configuration
@@ -374,7 +361,7 @@ async def draft_run(
workspace_id=workspace_id,
user=current_user
)
if storage_type is None:
if storage_type is None:
storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
@@ -384,9 +371,10 @@ async def draft_run(
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge:
if knowledge:
user_rag_memory_id = str(knowledge.id)
# Validate and prepare up front (done before the streaming response starts)
from app.services.app_service import AppService
from app.services.multi_agent_service import MultiAgentService
@@ -406,22 +394,13 @@ async def draft_run(
# Read-only operation; access to shared apps is allowed
service._validate_app_accessible(app, workspace_id)
if payload.user_id is None:
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id,
other_id=str(current_user.id),
original_user_id=str(current_user.id) # Save original user_id to other_id
)
payload.user_id = str(new_end_user.id)
# Handle the conversation ID (create or validate)
conversation_id = await draft_service._ensure_conversation(
conversation_id=payload.conversation_id,
app_id=app_id,
workspace_id=workspace_id,
user_id=payload.user_id
)
conversation_id=payload.conversation_id,
app_id=app_id,
workspace_id=workspace_id,
user_id=payload.user_id
)
payload.conversation_id = conversation_id
if app.type == AppType.AGENT:
@@ -445,16 +424,17 @@ async def draft_run(
if payload.stream:
async def event_generator():
async for event in draft_service.run_stream(
agent_config=agent_cfg,
model_config=model_config,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
agent_config=agent_cfg,
model_config=model_config,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -526,7 +506,7 @@ async def draft_run(
multi_agent_request = MultiAgentRunRequest(
message=payload.message,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
user_id=payload.user_id,
variables=payload.variables or {},
use_llm_routing=True # LLM routing is enabled by default
)
@@ -548,10 +528,10 @@ async def draft_run(
# Call the multi-agent service's streaming method
async for event in multiservice.run_stream(
app_id=app_id,
request=multi_agent_request,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
app_id=app_id,
request=multi_agent_request,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -591,7 +571,7 @@ async def draft_run(
data=result,
msg="多 Agent 任务执行成功"
)
elif app.type == AppType.WORKFLOW: # workflow
elif app.type == AppType.WORKFLOW: #workflow
config = workflow_service.check_config(app_id)
# 3. Streaming response
if payload.stream:
@@ -612,18 +592,17 @@ async def draft_run(
data: <json_data>
"""
import json
# Call the workflow service's streaming method
async for event in workflow_service.run_stream(
app_id=app_id,
payload=payload,
config=config,
workspace_id=current_user.current_workspace_id
config=config
):
# Extract the event type and data
event_type = event.get("event", "message")
event_data = event.get("data", {})
# Convert to standard SSE format (string)
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
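The hunk above turns each workflow event into one Server-Sent Events frame. A minimal standalone sketch of that formatting step is shown below; `format_sse` is a hypothetical helper name, and the default event type `"message"` and empty-dict data follow the `event.get(...)` calls in the diff.

```python
import json


def format_sse(event: dict) -> str:
    """Serialize one event dict into a single SSE frame (event line, data line, blank line)."""
    event_type = event.get("event", "message")
    event_data = event.get("data", {})
    # json.dumps emits a single line by default, so the payload fits on one data: line.
    return f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
```

For example, `format_sse({"event": "node_start", "data": {"id": 1}})` produces an `event: node_start` line, a `data:` line carrying the JSON payload, and the terminating blank line that SSE clients use to delimit frames.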
@@ -648,7 +627,7 @@ async def draft_run(
}
)
result = await workflow_service.run(app_id, payload, config, current_user.current_workspace_id)
result = await workflow_service.run(app_id, payload,config)
logger.debug(
"工作流试运行返回结果",
@@ -663,13 +642,14 @@ async def draft_run(
)
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
@cur_workspace_access_guard()
async def draft_run_compare(
app_id: uuid.UUID,
payload: app_schema.DraftRunCompareRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
app_id: uuid.UUID,
payload: app_schema.DraftRunCompareRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Multi-model comparison trial run
@@ -694,7 +674,7 @@ async def draft_run_compare(
workspace_id=workspace_id,
user=current_user
)
if storage_type is None:
if storage_type is None:
storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
@@ -703,7 +683,7 @@ async def draft_run_compare(
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge:
if knowledge:
user_rag_memory_id = str(knowledge.id)
logger.info(
@@ -748,23 +728,9 @@ async def draft_run_compare(
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
# Get agent_cfg.model_parameters; convert it to a dict if it is a ModelParameters object
agent_model_params = agent_cfg.model_parameters
if hasattr(agent_model_params, 'model_dump'):
agent_model_params = agent_model_params.model_dump()
elif not isinstance(agent_model_params, dict):
agent_model_params = {}
# Get model_item.model_parameters; convert it to a dict if it is a ModelParameters object
item_model_params = model_item.model_parameters
if hasattr(item_model_params, 'model_dump'):
item_model_params = item_model_params.model_dump()
elif not isinstance(item_model_params, dict):
item_model_params = {}
merged_parameters = {
**(agent_model_params or {}),
**(item_model_params or {})
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
}
model_configs.append({
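The hunk above replaces an explicit "normalize, then merge" step with a direct `**` unpacking of both parameter sets. As a hedged sketch of what the removed normalization did, assuming `model_parameters` may be either a plain dict or an object exposing `model_dump()` (as Pydantic v2 models do): the helper names `to_param_dict` and `merge_model_parameters`, and the `FakeParams` stand-in class, are hypothetical.

```python
from typing import Any


def to_param_dict(params: Any) -> dict:
    """Coerce a parameters value into a plain dict; unknown types become {}."""
    if params is None:
        return {}
    if hasattr(params, "model_dump"):
        return params.model_dump()
    if isinstance(params, dict):
        return dict(params)
    return {}


def merge_model_parameters(agent_params: Any, item_params: Any) -> dict:
    # Later keys win, so per-model overrides replace the agent-level defaults.
    return {**to_param_dict(agent_params), **to_param_dict(item_params)}


class FakeParams:
    """Stand-in for an object with a Pydantic-v2-style model_dump() method."""

    def __init__(self, **values: Any) -> None:
        self._values = values

    def model_dump(self) -> dict:
        return dict(self._values)
```

Note that `**` unpacking requires a mapping, so if either value can be a model object rather than a dict, the shorter form in the diff would likely raise a `TypeError`; whether that case occurs depends on schema code not shown here.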
@@ -781,19 +747,19 @@ async def draft_run_compare(
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
async for event in draft_service.run_compare_stream(
agent_config=agent_cfg,
models=model_configs,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60
agent_config=agent_cfg,
models=model_configs,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60
):
yield event
@@ -855,15 +821,15 @@ async def get_workflow_config(
# The config always exists (a default template is returned when none is stored)
return success(data=WorkflowConfigSchema.model_validate(cfg))
@router.put("/{app_id}/workflow", summary="更新 Workflow 配置")
@cur_workspace_access_guard()
async def update_workflow_config(
app_id: uuid.UUID,
payload: WorkflowConfigUpdate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
app_id: uuid.UUID,
payload: WorkflowConfigUpdate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
):
workspace_id = current_user.current_workspace_id
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
return success(data=WorkflowConfigSchema.model_validate(cfg))

View File

@@ -1,28 +1,24 @@
import os
from typing import Any, Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.core.config import settings
from app.core.logging_config import get_api_logger
from app.core.rag.common.settings import kg_retriever
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success
from app.db import get_db
from app.core.rag.llm.cv_model import QWenCV
from app.dependencies import get_current_user
from app.models import knowledge_model, knowledgeshare_model
from app.models.document_model import Document
from app.models.user_model import User
from app.models.document_model import Document
from app.models import knowledge_model, knowledgeshare_model
from app.core.rag.models.chunk import DocumentChunk
from app.schemas import chunk_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
@@ -145,7 +141,7 @@ async def get_preview_chunks(
}
}
api_logger.info(f"Querying the document block preview list successful: total={total}, returned={len(chunks)} records")
return success(data=jsonable_encoder(result), msg="Querying the document block preview list succeeded")
return success(data=result, msg="Querying the document block preview list succeeded")
@router.get("/{kb_id}/{document_id}/chunks", response_model=ApiResponse)
@@ -203,7 +199,7 @@ async def get_chunks(
"has_next": True if page * pagesize < total else False
}
}
return success(data=jsonable_encoder(result), msg="Query of document chunk list succeeded")
return success(data=result, msg="Query of document chunk list succeeded")
@router.post("/{kb_id}/{document_id}/chunk", response_model=ApiResponse)
@@ -264,7 +260,7 @@ async def create_chunk(
db_document.chunk_num += 1
db.commit()
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
return success(data=chunk, msg="Document chunk creation successful")
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
@@ -291,7 +287,7 @@ async def get_chunk(
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
total, items = vector_service.get_by_segment(doc_id=doc_id)
if total:
return success(data=jsonable_encoder(items[0]), msg="Document chunk query successful")
return success(data=items[0], msg="Document chunk query successful")
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -328,7 +324,7 @@ async def update_chunk(
chunk = items[0]
chunk.page_content = content
vector_service.update_by_segment(chunk)
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
return success(data=chunk, msg="The document chunk has been successfully updated")
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -393,41 +389,36 @@ async def retrieve_chunks(
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
private_items = knowledge_service.get_chunked_knowledgeids(
existing_ids = knowledge_service.get_chunded_knowledgeids(
db=db,
filters=filters,
current_user=current_user
)
private_kb_ids = [item[0] for item in private_items]
private_workspace_ids = [item[1] for item in private_items]
filters = [
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
items = knowledge_service.get_chunked_knowledgeids(
share_ids = knowledge_service.get_chunded_knowledgeids(
db=db,
filters=filters,
current_user=current_user
)
if items:
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
]
share_items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters,
current_user=current_user
)
share_kb_ids = [item[0] for item in share_items]
share_workspace_ids = [item[1] for item in share_items]
private_kb_ids.extend(share_kb_ids)
private_workspace_ids.extend(share_workspace_ids)
if not private_kb_ids:
existing_ids.extend(items)
if not existing_ids:
return success(data=[], msg="retrieval successful")
kb_id = private_kb_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in private_kb_ids]
kb_id = existing_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
indices = ",".join(uuid_strs)
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
@@ -457,21 +448,4 @@ async def retrieve_chunks(
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
# Prepare to configure chat_mdl, embedding_model, vision_model information
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base
)
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
rs.insert(0, doc)
return success(data=jsonable_encoder(rs), msg="retrieval successful")
return success(data=rs, msg="retrieval successful")
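
The `seen_ids` / `unique_rs` lines in the `retrieve_chunks` hunk above deduplicate retrieved chunks by `metadata["doc_id"]` before reranking. A minimal sketch of that step in isolation, assuming each document exposes a `metadata` dict; the `Doc` class and the `dedupe_by_doc_id` name are hypothetical stand-ins, not part of the repository:

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    """Hypothetical stand-in for a retrieved chunk; only `metadata` matters here."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def dedupe_by_doc_id(docs):
    """Keep the first occurrence of each doc_id, preserving retrieval order."""
    seen_ids = set()
    unique = []
    for doc in docs:
        doc_id = doc.metadata["doc_id"]
        if doc_id in seen_ids:
            continue  # a chunk with this doc_id was already kept
        seen_ids.add(doc_id)
        unique.append(doc)
    return unique
```

Keeping the first occurrence means the highest-ranked copy of a chunk survives when the input is already sorted by score.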

@@ -1,26 +1,23 @@
import datetime
import os
from typing import Optional
import datetime
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.controllers import file_controller
from app.core.config import settings
from app.core.logging_config import get_api_logger
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models import document_model
from app.models.user_model import User
from app.models import document_model
from app.schemas import document_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import document_service, file_service, knowledge_service
from app.controllers import file_controller
from app.celery_app import celery_app
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
@@ -109,7 +106,7 @@ async def get_documents(
"has_next": True if page * pagesize < total else False
}
}
return success(data=jsonable_encoder(result), msg="Query of document list succeeded")
return success(data=result, msg="Query of document list succeeded")
@router.post("/document", response_model=ApiResponse)
@@ -127,7 +124,7 @@ async def create_document(
api_logger.debug(f"Start creating a document: {create_data.file_name}")
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
api_logger.info(f"Document created successfully: {db_document.file_name} (ID: {db_document.id})")
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="Document creation successful")
return success(data=document_schema.Document.model_validate(db_document), msg="Document creation successful")
except Exception as e:
api_logger.error(f"Document creation failed: {create_data.file_name} - {str(e)}")
raise
@@ -156,7 +153,7 @@ async def get_document(
)
api_logger.info(f"Document query successful: {db_document.file_name} (ID: {db_document.id})")
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="Successfully obtained document information")
return success(data=document_schema.Document.model_validate(db_document), msg="Successfully obtained document information")
except HTTPException:
raise
except Exception as e:
@@ -224,7 +221,7 @@ async def update_document(
)
# 5. Return the updated document
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="Document information updated successfully")
return success(data=document_schema.Document.model_validate(db_document), msg="Document information updated successfully")
@router.delete("/{document_id}", response_model=ApiResponse)
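
The list endpoints in this diff report `has_next` as `page * pagesize < total`. A minimal sketch of that check as a helper, assuming 1-based page numbers; `page_meta` and the dictionary keys other than `has_next` are hypothetical:

```python
def page_meta(page, pagesize, total):
    """Build pagination metadata; has_next is true while items remain past this page."""
    return {
        "page": page,          # assumed key name
        "pagesize": pagesize,  # assumed key name
        "total": total,        # assumed key name
        "has_next": page * pagesize < total,
    }
```

The comparison already yields a `bool`, so the `True if ... else False` wrapper seen in the hunks is not needed in the sketch.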

@@ -18,7 +18,6 @@ from app.models.user_model import User
from app.schemas.emotion_schema import (
EmotionHealthRequest,
EmotionSuggestionsRequest,
EmotionGenerateSuggestionsRequest,
EmotionTagsRequest,
EmotionWordcloudRequest,
)
@@ -31,7 +30,7 @@ from sqlalchemy.orm import Session
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/emotion-memory",
prefix="/memory/emotion",
tags=["Emotion Analysis"],
dependencies=[Depends(get_current_user)] # All routes require authentication
)
@@ -199,7 +198,7 @@ async def get_emotion_suggestions(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
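):
"""Get personalized emotion suggestions (read from cache)
"""Get personalized emotion suggestions
Args:
request: contains group_id and an optional config_id
@@ -207,72 +206,7 @@ async def get_emotion_suggestions(
current_user: the current user
Returns:
Cached personalized emotion suggestions response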
"""
try:
api_logger.info(
f"User {current_user.username} requested personalized emotion suggestions (cached)",
extra={
"group_id": request.group_id,
"config_id": request.config_id
}
)
# Fetch suggestions from the cache
data = await emotion_service.get_cached_suggestions(
end_user_id=request.group_id,
db=db
)
if data is None:
# The cache is missing or has expired
api_logger.info(
f"Suggestion cache for user {request.group_id} is missing or has expired",
extra={"group_id": request.group_id}
)
return fail(
BizCode.RESOURCE_NOT_FOUND,
"The suggestion cache is missing or has expired; call the /generate_suggestions endpoint to generate new suggestions",
None
)
api_logger.info(
"Personalized suggestions retrieved successfully (cached)",
extra={
"group_id": request.group_id,
"suggestions_count": len(data.get("suggestions", []))
}
)
return success(data=data, msg="Personalized suggestions retrieved successfully (cached)")
except Exception as e:
api_logger.error(
f"Failed to retrieve personalized suggestions: {str(e)}",
extra={"group_id": request.group_id},
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve personalized suggestions: {str(e)}"
)
@router.post("/generate_suggestions", response_model=ApiResponse)
async def generate_emotion_suggestions(
request: EmotionGenerateSuggestionsRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
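):
"""Generate personalized emotion suggestions (calls the LLM and caches the result)
Args:
request: contains group_id, an optional config_id, and force_refresh
db: database session
current_user: the current user
Returns:
Newly generated personalized emotion suggestions response
Personalized emotion suggestions response
"""
try:
# Validate config_id (if provided)
@@ -300,44 +234,36 @@ async def generate_emotion_suggestions(
return fail(BizCode.INVALID_PARAMETER, "Configuration ID validation failed", str(e))
api_logger.info(
f"User {current_user.username} requested generation of personalized emotion suggestions",
f"User {current_user.username} requested personalized emotion suggestions",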
extra={
"group_id": request.group_id,
"config_id": config_id
}
)
# Call the service layer to generate suggestions
# Call the service layer
data = await emotion_service.generate_emotion_suggestions(
end_user_id=request.group_id,
db=db
)
# Save to the cache
await emotion_service.save_suggestions_cache(
end_user_id=request.group_id,
suggestions_data=data,
db=db,
expires_hours=24
)
api_logger.info(
"Personalized suggestions generated successfully",
"Personalized suggestions retrieved successfully",
extra={
"group_id": request.group_id,
"suggestions_count": len(data.get("suggestions", []))
}
)
return success(data=data, msg="Personalized suggestions generated successfully")
return success(data=data, msg="Personalized suggestions retrieved successfully")
except Exception as e:
api_logger.error(
f"Failed to generate personalized suggestions: {str(e)}",
f"Failed to retrieve personalized suggestions: {str(e)}",
extra={"group_id": request.group_id},
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to generate personalized suggestions: {str(e)}"
detail=f"Failed to retrieve personalized suggestions: {str(e)}"
)
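
The hunks above involve a read-through pattern: one endpoint returns cached suggestions or reports that the cache is missing or expired, while the other generates suggestions and stores them with `expires_hours=24`. A minimal in-memory sketch of that expiry behaviour, assuming nothing about how `emotion_service` actually persists data; the `SuggestionCache` class and its injectable `clock` are hypothetical:

```python
import time


class SuggestionCache:
    """Hypothetical in-memory cache; the real service stores entries through `db`."""

    def __init__(self, clock=time.time):
        self._clock = clock   # injectable so expiry can be tested without sleeping
        self._entries = {}    # end_user_id -> (expires_at, data)

    def save(self, end_user_id, data, expires_hours=24):
        """Store data with an absolute expiry timestamp."""
        expires_at = self._clock() + expires_hours * 3600
        self._entries[end_user_id] = (expires_at, data)

    def get(self, end_user_id):
        """Return cached data, or None when the entry is missing or expired."""
        entry = self._entries.get(end_user_id)
        if entry is None:
            return None
        expires_at, data = entry
        if self._clock() >= expires_at:
            del self._entries[end_user_id]  # drop the stale entry
            return None
        return data
```

Returning `None` for both "missing" and "expired" mirrors the single `if data is None` branch in the endpoint shown in the diff.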

@@ -1,25 +1,22 @@
import os
from typing import Any, Optional
from pathlib import Path
import shutil
from typing import Any, Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from fastapi.encoders import jsonable_encoder
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models import file_model
from app.models.user_model import User
from app.models import file_model
from app.schemas import file_schema, document_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import file_service, document_service
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
@@ -96,11 +93,11 @@ async def get_files(
"has_next": True if page * pagesize < total else False
}
}
return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
return success(data=result, msg="Query of file list succeeded")
@router.post("/folder", response_model=ApiResponse)
async def create_folder(
def create_folder(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
folder_name: str = '/',
@@ -124,7 +121,7 @@ async def create_folder(
)
db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user)
api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
return success(data=file_schema.File.model_validate(db_file), msg="Folder creation successful")
except Exception as e:
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
raise
@@ -210,7 +207,7 @@ async def upload_file(
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
api_logger.info(f"File upload successfully: {file.filename} (file_id: {db_file.id}, document_id: {db_document.id})")
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="File upload successful")
return success(data=document_schema.Document.model_validate(db_document), msg="File upload successful")
@router.post("/customtext", response_model=ApiResponse)
@@ -291,7 +288,7 @@ async def custom_text(
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
@router.get("/{file_id}", response_model=Any)
@@ -365,7 +362,7 @@ async def update_file(
# 2. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the file fields: {file_id}")
updated_fields = []
for field, value in update_data.dict(exclude_unset=True).items():
for field, value in update_data.items():
if hasattr(db_file, field):
old_value = getattr(db_file, field)
if old_value != value:
@@ -390,7 +387,7 @@ async def update_file(
)
# 4. Return the updated file
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
return success(data=file_schema.File.model_validate(db_file), msg="File information updated successfully")
@router.delete("/{file_id}", response_model=ApiResponse)
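
The `update_file` hunk above loops over the submitted fields and only touches attributes that exist on the record and whose value differs. A minimal sketch of that loop, assuming the update payload is a plain `dict`; `apply_updates` and the change-log string format are hypothetical:

```python
from types import SimpleNamespace


def apply_updates(target, updates):
    """Set only attributes that exist on `target` and whose value actually changes."""
    updated_fields = []
    for field, value in updates.items():
        if not hasattr(target, field):
            continue  # ignore keys the record does not define
        old_value = getattr(target, field)
        if old_value != value:
            setattr(target, field, value)
            updated_fields.append(f"{field}: {old_value} -> {value}")
    return updated_fields


# Usage: SimpleNamespace stands in for the ORM row here.
record = SimpleNamespace(file_name="a.txt", file_ext=".txt")
print(apply_updates(record, {"file_name": "b.txt", "file_ext": ".txt"}))
```

If the payload is a Pydantic model rather than a dict, it would need to be converted to a dict of explicitly set fields first; which form `update_data` takes is not visible in this hunk.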

@@ -1,44 +0,0 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.services.home_page_service import HomePageService
router = APIRouter(prefix="/home-page", tags=["Home Page"])
@router.get("/statistics", response_model=ApiResponse)
def get_home_statistics(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get home page statistics"""
statistics = HomePageService.get_home_statistics(db, current_user.tenant_id)
return success(data=statistics, msg="Statistics retrieved successfully")
@router.get("/workspaces", response_model=ApiResponse)
def get_workspace_list(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get the workspace list"""
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
return success(data=workspace_list, msg="Workspace list retrieved successfully")
@router.get("/version", response_model=ApiResponse)
def get_system_version():
"""Get the system version number and description"""
current_version = settings.SYSTEM_VERSION
version_info = HomePageService.load_version_introduction(current_version)
return success(
data={
"version": current_version,
"introduction": version_info.get("introduction"),
"introduction_en": version_info.get("introduction_en")
},
msg="System version retrieved successfully"
)

@@ -1,431 +0,0 @@
from datetime import datetime
from typing import Optional
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import (
cur_workspace_access_guard,
get_current_user,
)
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.schemas.implicit_memory_schema import GenerateProfileRequest
from app.services.implicit_memory_service import ImplicitMemoryService
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/implicit-memory",
tags=["Implicit Memory"],
)
def handle_implicit_memory_error(e: Exception, operation: str, user_id: str = None) -> dict:
"""
Centralized error handling for implicit memory operations.
Args:
e: The exception that occurred
operation: Description of the operation that failed
user_id: Optional user ID for logging context
Returns:
Standardized error response
"""
error_context = f"user_id={user_id}" if user_id else "unknown user"
if isinstance(e, ValueError):
if "user" in str(e).lower() and "not found" in str(e).lower():
api_logger.warning(f"Invalid user ID for {operation}: {error_context}")
return fail(BizCode.INVALID_USER_ID, "Invalid user ID", str(e))
elif "insufficient" in str(e).lower() or "no data" in str(e).lower():
api_logger.warning(f"Insufficient data for {operation}: {error_context}")
return fail(BizCode.INSUFFICIENT_DATA, "Insufficient data for analysis", str(e))
else:
api_logger.warning(f"Invalid parameters for {operation}: {error_context}")
return fail(BizCode.INVALID_FILTER_PARAMS, "Invalid parameters", str(e))
elif isinstance(e, KeyError):
api_logger.warning(f"Missing required data for {operation}: {error_context}")
return fail(BizCode.INSUFFICIENT_DATA, "Required data is missing", str(e))
elif isinstance(e, (ConnectionError, TimeoutError)):
api_logger.error(f"Service unavailable for {operation}: {error_context}")
return fail(BizCode.SERVICE_UNAVAILABLE, "Service temporarily unavailable", str(e))
elif "analysis" in str(e).lower() or "llm" in str(e).lower():
api_logger.error(f"Analysis failed for {operation}: {error_context}", exc_info=True)
return fail(BizCode.ANALYSIS_FAILED, "Analysis failed", str(e))
elif "storage" in str(e).lower() or "database" in str(e).lower():
api_logger.error(f"Storage error for {operation}: {error_context}", exc_info=True)
return fail(BizCode.PROFILE_STORAGE_ERROR, "Failed to store data", str(e))
else:
api_logger.error(f"Unexpected error for {operation}: {error_context}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, f"{operation} failed", str(e))
def validate_user_id(user_id: str) -> None:
"""
Validate user ID format and constraints.
Args:
user_id: User ID to validate
Raises:
ValueError: If user ID is invalid
"""
if not user_id or not user_id.strip():
raise ValueError("User ID cannot be empty")
if len(user_id.strip()) < 1:
raise ValueError("User ID is too short")
def validate_date_range(start_date: Optional[datetime], end_date: Optional[datetime]) -> None:
"""
Validate date range parameters.
Args:
start_date: Start date
end_date: End date
Raises:
ValueError: If date range is invalid
"""
if (start_date and not end_date) or (end_date and not start_date):
raise ValueError("Both start_date and end_date must be provided together")
if start_date and end_date and start_date >= end_date:
raise ValueError("start_date must be before end_date")
if start_date and start_date > datetime.now():
raise ValueError("start_date cannot be in the future")
def validate_confidence_threshold(threshold: float) -> None:
"""
Validate confidence threshold parameter.
Args:
threshold: Confidence threshold to validate
Raises:
ValueError: If threshold is invalid
"""
if not 0.0 <= threshold <= 1.0:
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
@router.get("/preferences/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_preference_tags(
user_id: str,
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
start_date: Optional[datetime] = Query(None, description="Filter start date"),
end_date: Optional[datetime] = Query(None, description="Filter end date"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Get user preference tags from cache.
Args:
user_id: Target user ID
confidence_threshold: Minimum confidence score (0.0-1.0)
tag_category: Optional category filter
start_date: Optional start date filter
end_date: Optional end date filter
Returns:
List of preference tags from cache
"""
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)")
try:
# Validate inputs
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
if cached_profile is None:
api_logger.info(f"Profile cache for user {user_id} is missing or has expired")
return fail(
BizCode.RESOURCE_NOT_FOUND,
"The profile cache is missing or has expired; call the /generate_profile endpoint to generate a new profile",
None
)
# Extract preferences from cache
preferences = cached_profile.get("preferences", [])
# Apply filters (client-side filtering on cached data)
filtered_preferences = []
for pref in preferences:
# Filter by confidence threshold
if confidence_threshold is not None and pref.get("confidence_score", 0) < confidence_threshold:
continue
# Filter by category if specified
if tag_category and pref.get("category") != tag_category:
continue
# Filter by date range if specified
if start_date or end_date:
created_at_ts = pref.get("created_at")
if created_at_ts:
created_at = datetime.fromtimestamp(created_at_ts / 1000)
if start_date and created_at < start_date:
continue
if end_date and created_at > end_date:
continue
filtered_preferences.append(pref)
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)")
return success(data=filtered_preferences, msg="Preference tags retrieved successfully (cached)")
except Exception as e:
return handle_implicit_memory_error(e, "preference tag retrieval", user_id)
@router.get("/portrait/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_dimension_portrait(
user_id: str,
include_history: bool = Query(False, description="Include historical trends"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Get user's four-dimension personality portrait from cache.
Args:
user_id: Target user ID
include_history: Whether to include historical trend data (ignored for cached data)
Returns:
Four-dimension personality portrait from cache
"""
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)")
try:
# Validate inputs
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
if cached_profile is None:
api_logger.info(f"Profile cache for user {user_id} is missing or has expired")
return fail(
BizCode.RESOURCE_NOT_FOUND,
"The profile cache is missing or has expired; call the /generate_profile endpoint to generate a new profile",
None
)
# Extract portrait from cache
portrait = cached_profile.get("portrait", {})
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)")
return success(data=portrait, msg="Four-dimension portrait retrieved successfully (cached)")
except Exception as e:
return handle_implicit_memory_error(e, "four-dimension portrait retrieval", user_id)
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_interest_area_distribution(
user_id: str,
include_trends: bool = Query(False, description="Include trend analysis"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Get user's interest area distribution from cache.
Args:
user_id: Target user ID
include_trends: Whether to include trend analysis data (ignored for cached data)
Returns:
Interest area distribution from cache
"""
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)")
try:
# Validate inputs
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
if cached_profile is None:
api_logger.info(f"Profile cache for user {user_id} is missing or has expired")
return fail(
BizCode.RESOURCE_NOT_FOUND,
"The profile cache is missing or has expired; call the /generate_profile endpoint to generate a new profile",
None
)
# Extract interest areas from cache
interest_areas = cached_profile.get("interest_areas", {})
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)")
return success(data=interest_areas, msg="Interest area distribution retrieved successfully (cached)")
except Exception as e:
return handle_implicit_memory_error(e, "interest area distribution retrieval", user_id)
@router.get("/habits/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_behavior_habits(
user_id: str,
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Get user's behavioral habits from cache.
Args:
user_id: Target user ID
confidence_level: Filter by confidence level (high, medium, low)
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
time_period: Filter by time period (current, past)
Returns:
List of behavioral habits from cache
"""
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)")
try:
# Validate inputs
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
if cached_profile is None:
api_logger.info(f"Profile cache for user {user_id} is missing or has expired")
return fail(
BizCode.RESOURCE_NOT_FOUND,
"The profile cache is missing or has expired; call the /generate_profile endpoint to generate a new profile",
None
)
# Extract habits from cache
habits = cached_profile.get("habits", [])
# Apply filters (client-side filtering on cached data)
filtered_habits = []
for habit in habits:
# Filter by confidence level
if confidence_level:
confidence_mapping = {
"high": 85,
"medium": 50,
"low": 20
}
numerical_confidence = confidence_mapping.get(confidence_level.lower())
if habit.get("confidence_level", 0) < numerical_confidence:
continue
# Filter by frequency pattern
if frequency_pattern and habit.get("frequency_pattern") != frequency_pattern:
continue
# Filter by time period
if time_period:
is_current = habit.get("is_current", True)
if time_period.lower() == "current" and not is_current:
continue
elif time_period.lower() == "past" and is_current:
continue
filtered_habits.append(habit)
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)")
return success(data=filtered_habits, msg="Behavior habits retrieved successfully (cached)")
except Exception as e:
return handle_implicit_memory_error(e, "behavior habit retrieval", user_id)
@router.post("/generate_profile", response_model=ApiResponse)
@cur_workspace_access_guard()
async def generate_implicit_memory_profile(
request: GenerateProfileRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Generate complete user profile (all 4 modules) and cache it.
Args:
request: Generate profile request with end_user_id
db: Database session
current_user: Current authenticated user
Returns:
Complete user profile with all modules
"""
end_user_id = request.end_user_id
api_logger.info(f"Generate profile requested for user: {end_user_id}")
try:
# Validate inputs
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Generate complete profile (calls LLM for all 4 modules)
api_logger.info(f"Starting generation of the complete user profile: user={end_user_id}")
profile_data = await service.generate_complete_profile(user_id=end_user_id)
# Save to cache
await service.save_profile_cache(
end_user_id=end_user_id,
profile_data=profile_data,
db=db,
expires_hours=168 # 7 days
)
api_logger.info(f"User profile generated and cached successfully: user={end_user_id}")
# Add metadata
profile_data["end_user_id"] = end_user_id
profile_data["cached"] = False
return success(data=profile_data, msg="User profile generated successfully")
except Exception as e:
api_logger.error(f"Failed to generate user profile: user={end_user_id}, error={str(e)}", exc_info=True)
return handle_implicit_memory_error(e, "user profile generation", end_user_id)
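
The removed `get_behavior_habits` endpoint filters cached habits by a numeric confidence floor (`high` = 85, `medium` = 50, `low` = 20), an exact frequency pattern, and a current/past flag. A minimal sketch of that filtering as a pure function, assuming habits are plain dicts with the keys seen in the diff; `filter_habits` and `CONFIDENCE_FLOOR` are hypothetical names:

```python
# Thresholds copied from the confidence_mapping in the removed endpoint.
CONFIDENCE_FLOOR = {"high": 85, "medium": 50, "low": 20}


def filter_habits(habits, confidence_level=None, frequency_pattern=None, time_period=None):
    """Return the habits that pass every filter that was supplied."""
    filtered = []
    for habit in habits:
        if confidence_level:
            floor = CONFIDENCE_FLOOR[confidence_level.lower()]
            if habit.get("confidence_level", 0) < floor:
                continue
        if frequency_pattern and habit.get("frequency_pattern") != frequency_pattern:
            continue
        if time_period:
            # A habit without the flag is treated as current, as in the diff.
            is_current = habit.get("is_current", True)
            if time_period.lower() == "current" and not is_current:
                continue
            if time_period.lower() == "past" and is_current:
                continue
        filtered.append(habit)
    return filtered
```

Note that each level acts as a minimum, so `low` also admits medium- and high-confidence habits.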

@@ -1,29 +1,26 @@
from typing import Optional
import datetime
import json
from typing import Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.encoders import jsonable_encoder
from sqlalchemy import or_
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.core.rag.common import settings
from app.core.rag.llm.chat_model import Base
from app.core.rag.nlp import rag_tokenizer, search
from app.core.rag.prompts.generator import graph_entity_types
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models import knowledge_model
from app.models.user_model import User
from app.models import knowledge_model, document_model, file_model
from app.schemas import knowledge_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import knowledge_service, document_service
from app.services.model_service import ModelConfigService
from app.core.rag.llm.chat_model import Base
from app.core.rag.prompts.generator import graph_entity_types
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.logging_config import get_api_logger
from app.core.rag.nlp import rag_tokenizer, search
from app.core.rag.common import settings
from app.celery_app import celery_app
# Obtain a dedicated API logger
api_logger = get_api_logger()
@@ -50,45 +47,6 @@ def get_parser_types():
return success(msg="Successfully obtained the knowledge parser type", data=list(knowledge_model.ParserType))
@router.get("/knowledge_graph_entity_types", response_model=ApiResponse)
async def get_knowledge_graph_entity_types(
llm_id: uuid.UUID,
scenario: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
get knowledge graph entity types based on llm_id
"""
api_logger.info(f"Obtain details of the knowledge graph: llm_id={llm_id}, username: {current_user.username}")
try:
# 1. Check whether the model exists
api_logger.debug(f"Check whether the model exists: {llm_id}")
config = ModelConfigService.get_model_by_id(db=db, model_id=llm_id)
if not config:
api_logger.warning(
f"The model does not exist or you do not have permission to access it: llm_id={llm_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The model does not exist or you do not have permission to access it"
)
# 2. Prepare to configure chat_mdl information
chat_model = Base(
key=config.api_keys[0].api_key,
model_name=config.api_keys[0].model_name,
base_url=config.api_keys[0].api_base
)
response = graph_entity_types(chat_model, scenario)
return success(data=response, msg="Successfully obtained knowledge graph entity types")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"get knowledge graph entity types failed: llm_id={llm_id} - {str(e)}")
raise
@router.get("/knowledges", response_model=ApiResponse)
async def get_knowledges(
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id"),
@@ -172,7 +130,7 @@ async def get_knowledges(
"has_next": True if page*pagesize < total else False
}
}
return success(data=jsonable_encoder(result), msg="Query of knowledge base list successful")
return success(data=result, msg="Query of knowledge base list successful")
@router.post("/knowledge", response_model=ApiResponse)
@@ -198,7 +156,7 @@ async def create_knowledge(
)
db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=create_data, current_user=current_user)
api_logger.info(f"The knowledge base has been successfully created: {db_knowledge.name} (ID: {db_knowledge.id})")
return success(data=jsonable_encoder(knowledge_schema.Knowledge.model_validate(db_knowledge)), msg="The knowledge base has been successfully created")
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base has been successfully created")
except Exception as e:
api_logger.error(f"The creation of the knowledge base failed: {create_data.name} - {str(e)}")
raise
@@ -227,7 +185,7 @@ async def get_knowledge(
)
api_logger.info(f"Knowledge base query successful: {db_knowledge.name} (ID: {db_knowledge.id})")
return success(data=jsonable_encoder(knowledge_schema.Knowledge.model_validate(db_knowledge)), msg="Successfully obtained knowledge base information")
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="Successfully obtained knowledge base information")
except HTTPException:
raise
except Exception as e:
@@ -244,7 +202,7 @@ async def update_knowledge(
):
api_logger.info(f"Update knowledge base request: knowledge_id={knowledge_id}, username: {current_user.username}")
db_knowledge = await _update_knowledge(knowledge_id=knowledge_id, update_data=update_data, db=db, current_user=current_user)
return success(data=jsonable_encoder(knowledge_schema.Knowledge.model_validate(db_knowledge)), msg="The knowledge base information has been successfully updated")
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base information has been successfully updated")
async def _update_knowledge(
@@ -421,7 +379,7 @@ async def delete_knowledge_graph(
current_user: User = Depends(get_current_user)
):
"""
delete knowledge graph
Soft-delete knowledge graph
"""
api_logger.info(f"Request to delete knowledge graph: knowledge_id={knowledge_id}, username: {current_user.username}")
@@ -484,3 +442,42 @@ async def rebuild_knowledge_graph(
except Exception as e:
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
raise
@router.get("/{knowledge_id}/knowledge_graph_entity_types", response_model=ApiResponse)
async def get_knowledge_graph_entity_types(
knowledge_id: uuid.UUID,
scenario: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
get knowledge graph entity types based on knowledge_id
"""
api_logger.info(f"Obtain details of the knowledge graph: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Check whether the knowledge base exists
api_logger.debug(f"Check whether the knowledge base exists: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(
f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or you do not have permission to access it"
)
# 2. Prepare to configure chat_mdl information
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
response = graph_entity_types(chat_model, scenario)
return success(data=response, msg="Successfully obtained knowledge graph entity types")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"get knowledge graph entity types failed: knowledge_id={knowledge_id} - {str(e)}")
raise

@@ -1,15 +1,18 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from typing import Optional
from typing import List, Optional
import uuid
from app.repositories.end_user_repository import update_end_user_other_name
import uuid
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.memory_agent_schema import End_User_Information
from app.schemas.response_schema import ApiResponse
from app.schemas.app_schema import App as AppSchema
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
from app.services.memory_agent_service import get_end_users_connected_configs_batch
from app.core.logging_config import get_api_logger
# Get the API-specific logger
@@ -99,8 +102,7 @@ async def get_workspace_end_users(
"""
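Get the list of end users (hosts) in the workspace
The return format matches the end_users field of the original memory_list endpoint,
and includes each user's memory config info (memory_config_id and memory_config_name)
The return format matches the end_users field of the original memory_list endpoint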
"""
workspace_id = current_user.current_workspace_id
# Get the current workspace type
@@ -111,17 +113,6 @@ async def get_workspace_end_users(
workspace_id=workspace_id,
current_user=current_user
)
# Batch-fetch memory config info for all users (optimization: one query instead of N)
end_user_ids = [str(user.id) for user in end_users]
memory_configs_map = {}
if end_user_ids:
try:
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
except Exception as e:
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
# On failure, fall back to an empty dict so the rest of the data is still returned
result = []
for end_user in end_users:
memory_num = {}
@@ -132,25 +123,10 @@ async def get_workspace_end_users(
memory_num = {
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
}
# Look up the config info in the batch query result
user_id = str(end_user.id)
memory_config_info = memory_configs_map.get(user_id, {
"memory_config_id": None,
"memory_config_name": None
})
# Keep only the required fields; drop the error field if present
memory_config = {
"memory_config_id": memory_config_info.get("memory_config_id"),
"memory_config_name": memory_config_info.get("memory_config_name")
}
result.append(
{
'end_user': end_user,
'memory_num': memory_num,
'memory_config': memory_config
'end_user':end_user,
'memory_num':memory_num
}
)
@@ -489,6 +465,7 @@ async def dashboard_data(
if storage_type is None:
storage_type = 'neo4j'
user_rag_memory_id = None
# Decide which data object to return based on storage_type
# If it is 'rag', neo4j_data is null; otherwise rag_data is null
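The batch lookup in `get_workspace_end_users` above reads each user's entry from a single pre-fetched map and falls back to empty values when a user has no entry. A minimal sketch of that pattern follows. It assumes the map is a dict keyed by user ID, which is what the `.get(user_id, {...})` call in the hunk suggests; `attach_memory_configs` and `DEFAULT_CONFIG` are hypothetical names that do not appear in the original code.

```python
from typing import Any, Dict, List

# Default used when a user has no entry in the batch result.
DEFAULT_CONFIG: Dict[str, Any] = {"memory_config_id": None, "memory_config_name": None}


def attach_memory_configs(
    end_user_ids: List[str], configs_map: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Hypothetical helper: pair each user ID with its memory config fields."""
    result = []
    for user_id in end_user_ids:
        info = configs_map.get(user_id, DEFAULT_CONFIG)
        # Keep only the two expected fields; any extra key (such as "error") is dropped.
        result.append({
            "end_user_id": user_id,
            "memory_config": {
                "memory_config_id": info.get("memory_config_id"),
                "memory_config_name": info.get("memory_config_name"),
            },
        })
    return result
```

Because the map is built once before the loop, the per-user work is a dictionary lookup rather than a database query.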

@@ -1,125 +0,0 @@
"""
Episodic memory controller
Provides the episodic memory overview and detail query endpoints
"""
from fastapi import APIRouter, Depends
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_episodic_schema import (
EpisodicMemoryOverviewRequest,
EpisodicMemoryDetailsRequest,
)
from app.services.memory_episodic_service import memory_episodic_service
# Get API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/episodic-memory",
tags=["Episodic Memory"],
)
@router.post("/overview", response_model=ApiResponse)
async def get_episodic_memory_overview_api(
request: EpisodicMemoryOverviewRequest,
current_user: User = Depends(get_current_user),
) -> dict:
"""
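Get the episodic memory overview
Returns the list of all episodic memories for the specified user, including title and creation time.
Supports filtering by time range, episodic type, and title keyword.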
"""
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆总览但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# Validate parameters
valid_time_ranges = ["all", "today", "this_week", "this_month"]
valid_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"]
if request.time_range not in valid_time_ranges:
return fail(BizCode.INVALID_PARAMETER, f"无效的时间范围参数,可选值:{', '.join(valid_time_ranges)}")
if request.episodic_type not in valid_episodic_types:
return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{', '.join(valid_episodic_types)}")
# Normalize title_keyword (strip leading and trailing whitespace)
title_keyword = request.title_keyword.strip() if request.title_keyword else None
api_logger.info(
f"情景记忆总览查询请求: end_user_id={request.end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}, time_range={request.time_range}, episodic_type={request.episodic_type}, "
f"title_keyword={title_keyword}"
)
try:
# Call the service-layer method
result = await memory_episodic_service.get_episodic_memory_overview(
request.end_user_id, request.time_range, request.episodic_type, title_keyword
)
api_logger.info(
f"成功获取情景记忆总览: end_user_id={request.end_user_id}, "
f"total={result['total']}"
)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"情景记忆总览查询失败: end_user_id={request.end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "情景记忆总览查询失败", str(e))
@router.post("/details", response_model=ApiResponse)
async def get_episodic_memory_details_api(
request: EpisodicMemoryDetailsRequest,
current_user: User = Depends(get_current_user),
) -> dict:
"""
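Get episodic memory details
Returns the details of the specified episodic memory, including involved objects, episodic type, content records, and emotion.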
"""
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆详情但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"情景记忆详情查询请求: end_user_id={request.end_user_id}, summary_id={request.summary_id}, "
f"user={current_user.username}, workspace={workspace_id}"
)
try:
# Call the service-layer method
result = await memory_episodic_service.get_episodic_memory_details(
end_user_id=request.end_user_id,
summary_id=request.summary_id
)
api_logger.info(
f"成功获取情景记忆详情: end_user_id={request.end_user_id}, summary_id={request.summary_id}"
)
return success(data=result, msg="查询成功")
except ValueError as e:
# Handle the case where the episodic memory does not exist
api_logger.warning(f"情景记忆不存在: end_user_id={request.end_user_id}, summary_id={request.summary_id}, error={str(e)}")
return fail(BizCode.INVALID_PARAMETER, "情景记忆不存在", str(e))
except Exception as e:
api_logger.error(f"情景记忆详情查询失败: end_user_id={request.end_user_id}, summary_id={request.summary_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "情景记忆详情查询失败", str(e))
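The parameter checks in the overview endpoint above can be tried in isolation. The sketch below is a hypothetical helper, `validate_overview_params`, which is not part of the original code. It copies the allowed values and the `title_keyword` expression from the controller and returns an error string instead of calling `fail`.

```python
from typing import Optional, Tuple

# Allowed values copied from the controller above.
VALID_TIME_RANGES = ["all", "today", "this_week", "this_month"]
VALID_EPISODIC_TYPES = [
    "all", "conversation", "project_work", "learning", "decision", "important_event",
]


def validate_overview_params(
    time_range: str, episodic_type: str, title_keyword: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
    """Hypothetical helper: return (error_message, normalized_keyword)."""
    if time_range not in VALID_TIME_RANGES:
        return f"invalid time_range, expected one of: {', '.join(VALID_TIME_RANGES)}", None
    if episodic_type not in VALID_EPISODIC_TYPES:
        return f"invalid episodic_type, expected one of: {', '.join(VALID_EPISODIC_TYPES)}", None
    # Same expression as the controller: None and "" become None,
    # anything else is stripped of surrounding whitespace.
    keyword = title_keyword.strip() if title_keyword else None
    return None, keyword
```

Note that a whitespace-only keyword is truthy, so this expression turns it into an empty string rather than `None`; whether the service layer treats the two the same way is not visible in this diff.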

@@ -1,115 +0,0 @@
"""
Explicit memory controller
Handles the explicit memory API endpoints, including queries for episodic and semantic memories.
"""
from fastapi import APIRouter, Depends
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.services.memory_explicit_service import MemoryExplicitService
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_explicit_schema import (
ExplicitMemoryOverviewRequest,
ExplicitMemoryDetailsRequest,
)
from app.dependencies import get_current_user
from app.models.user_model import User
# Get API logger
api_logger = get_api_logger()
# Initialize service
memory_explicit_service = MemoryExplicitService()
router = APIRouter(
prefix="/memory/explicit-memory",
tags=["Explicit Memory"],
)
@router.post("/overview", response_model=ApiResponse)
async def get_explicit_memory_overview_api(
request: ExplicitMemoryOverviewRequest,
current_user: User = Depends(get_current_user),
) -> dict:
"""
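Get the explicit memory overview
Returns the list of all explicit memories for the specified user, including title, full content, creation time, and emotion information.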
"""
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询显性记忆总览但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"显性记忆总览查询请求: end_user_id={request.end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}"
)
try:
# Call the service-layer method
result = await memory_explicit_service.get_explicit_memory_overview(
request.end_user_id
)
api_logger.info(
f"成功获取显性记忆总览: end_user_id={request.end_user_id}, "
f"total={result['total']}"
)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"显性记忆总览查询失败: end_user_id={request.end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
@router.post("/details", response_model=ApiResponse)
async def get_explicit_memory_details_api(
request: ExplicitMemoryDetailsRequest,
current_user: User = Depends(get_current_user),
) -> dict:
"""
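Get explicit memory details
Returns the details of an episodic or semantic memory based on memory_id.
- Episodic memory: includes title, content, emotion, and creation time
- Semantic memory: includes name, core definition, detailed notes, and creation time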
"""
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询显性记忆详情但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"显性记忆详情查询请求: end_user_id={request.end_user_id}, memory_id={request.memory_id}, "
f"user={current_user.username}, workspace={workspace_id}"
)
try:
# Call the service-layer method
result = await memory_explicit_service.get_explicit_memory_details(
end_user_id=request.end_user_id,
memory_id=request.memory_id
)
api_logger.info(
f"成功获取显性记忆详情: end_user_id={request.end_user_id}, memory_id={request.memory_id}, "
f"memory_type={result.get('memory_type')}"
)
return success(data=result, msg="查询成功")
except ValueError as e:
# Handle the case where the memory does not exist
api_logger.warning(f"显性记忆不存在: end_user_id={request.end_user_id}, memory_id={request.memory_id}, error={str(e)}")
return fail(BizCode.INVALID_PARAMETER, "显性记忆不存在", str(e))
except Exception as e:
api_logger.error(f"显性记忆详情查询失败: end_user_id={request.end_user_id}, memory_id={request.memory_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "显性记忆详情查询失败", str(e))
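The details endpoints above share one error-handling shape: a `ValueError` from the service is reported as an invalid-parameter failure, and any other exception as an internal error. A rough sketch of that ordering follows. `run_query` and the dict it returns are hypothetical stand-ins; the real `success` and `fail` helpers and the `BizCode` values are not shown in this diff.

```python
from typing import Any, Callable, Dict


def run_query(fetch: Callable[[], Any]) -> Dict[str, Any]:
    """Hypothetical wrapper that mirrors the except-clause order used above."""
    try:
        return {"ok": True, "data": fetch()}
    except ValueError as exc:
        # The more specific handler must come first, otherwise the
        # generic Exception clause below would catch it.
        return {"ok": False, "code": "INVALID_PARAMETER", "error": str(exc)}
    except Exception as exc:
        return {"ok": False, "code": "INTERNAL_ERROR", "error": str(exc)}
```

Since `ValueError` is a subclass of `Exception`, swapping the two clauses would report every missing record as an internal error.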

@@ -1,363 +0,0 @@
"""
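Forgetting engine controller module
This module provides the REST API endpoints for the forgetting engine, including:
1. Manually triggering a forgetting cycle
2. Reading and updating the configuration
3. Retrieving statistics
4. Retrieving forgetting-curve data
All endpoints require user authentication and are automatically scoped to the current workspace.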
"""
from typing import Optional
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.memory_storage_schema import (
ForgettingTriggerRequest,
ForgettingConfigResponse,
ForgettingConfigUpdateRequest,
ForgettingStatsResponse,
ForgettingReportResponse,
ForgettingCurveRequest,
ForgettingCurveResponse,
ForgettingCurvePoint,
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_forget_service import MemoryForgetService
# Get the API-specific logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/forget-memory",
tags=["Memory Forgetting Engine"],
dependencies=[Depends(get_current_user)] # All routes require authentication
)
# Initialize the service
forget_service = MemoryForgetService()
# ==================== API endpoints ====================
@router.post("/trigger", response_model=ApiResponse)
async def trigger_forgetting_cycle(
payload: ForgettingTriggerRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
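Manually trigger a forgetting cycle
Runs one complete forgetting cycle, identifying and merging low-activation nodes.
Args:
payload: Trigger request parameters
current_user: Current user
db: Database session
Returns:
ApiResponse: Response containing the forgetting report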
"""
workspace_id = current_user.current_workspace_id
end_user_id = payload.end_user_id # Get end_user_id from the payload
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试触发遗忘周期但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# Look up the associated config_id via end_user_id
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
except Exception as e:
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求触发遗忘周期: "
f"end_user_id={end_user_id}, config_id={config_id}, max_batch={payload.max_merge_batch_size}, "
f"min_days={payload.min_days_since_access}"
)
try:
# Call the service layer to run the forgetting cycle
report = await forget_service.trigger_forgetting_cycle(
db=db,
group_id=end_user_id, # The service-layer parameter is named group_id
max_merge_batch_size=payload.max_merge_batch_size,
min_days_since_access=payload.min_days_since_access,
config_id=config_id
)
# Build the response
response_data = ForgettingReportResponse(**report)
return success(data=response_data.model_dump(), msg="遗忘周期执行成功")
except RuntimeError as e:
api_logger.warning(f"遗忘周期执行被拒绝: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "RuntimeError")
except Exception as e:
api_logger.error(f"触发遗忘周期失败: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "触发遗忘周期失败", str(e))
@router.get("/read_config", response_model=ApiResponse)
async def read_forgetting_config(
config_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
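Get the forgetting engine configuration
Reads the forgetting engine parameters for the specified config ID.
Args:
config_id: Config ID
current_user: Current user
db: Database session
Returns:
ApiResponse: Response containing the configuration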
"""
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}"
)
try:
# Call the service layer to read the configuration
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
# Build the response
response_data = ForgettingConfigResponse(**config)
return success(data=response_data.model_dump(), msg="查询成功")
except ValueError as e:
api_logger.warning(f"配置不存在: config_id={config_id}, 错误: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, f"配置不存在: {config_id}", str(e))
except Exception as e:
api_logger.error(f"读取遗忘引擎配置失败: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
@router.post("/update_config", response_model=ApiResponse)
async def update_forgetting_config(
payload: ForgettingConfigUpdateRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
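Update the forgetting engine configuration
Updates the forgetting engine parameters for the specified config ID.
Args:
payload: Configuration update request
current_user: Current user
db: Database session
Returns:
ApiResponse: Response containing the update result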
"""
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}"
)
try:
# Build the dict of fields to update (excluding None values and config_id)
update_data = {
key: value
for key, value in payload.model_dump(exclude_none=True).items()
if key != 'config_id'
}
# Call the service layer to update the configuration
config = forget_service.update_forgetting_config(
db=db,
config_id=payload.config_id,
update_fields=update_data
)
# Build the response
response_data = ForgettingConfigResponse(**config)
return success(data=response_data.model_dump(), msg="更新成功")
except ValueError as e:
api_logger.warning(f"配置不存在: config_id={payload.config_id}, 错误: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
except Exception as e:
db.rollback()
api_logger.error(f"更新遗忘引擎配置失败: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
@router.get("/stats", response_model=ApiResponse)
async def get_forgetting_stats(
group_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
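Get forgetting engine statistics
Returns knowledge-layer node statistics, the activation value distribution, and related information.
Args:
group_id: Group ID, i.e. end_user_id (optional)
current_user: Current user
db: Database session
Returns:
ApiResponse: Response containing the statistics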
"""
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# If group_id is provided, use it to look up config_id
config_id = None
if group_id:
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
except Exception as e:
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
f"group_id={group_id}, config_id={config_id}"
)
try:
# Call the service layer to get the statistics
stats = await forget_service.get_forgetting_stats(
db=db,
group_id=group_id,
config_id=config_id
)
# Build the response
response_data = ForgettingStatsResponse(**stats)
return success(data=response_data.model_dump(), msg="查询成功")
except Exception as e:
api_logger.error(f"获取遗忘引擎统计失败: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
@router.post("/forgetting_curve", response_model=ApiResponse)
async def get_forgetting_curve(
request: ForgettingCurveRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
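Get forgetting-curve data
Generates forgetting-curve data for visualization, simulating how memory activation decays over time.
Args:
request: Forgetting-curve request parameters
current_user: Current user
db: Database session
Returns:
ApiResponse: Response containing the forgetting-curve data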
"""
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘曲线: "
f"importance_score={request.importance_score}, days={request.days}, config_id={request.config_id}"
)
try:
# Call the service layer to generate the forgetting curve
result = await forget_service.get_forgetting_curve(
db=db,
importance_score=request.importance_score,
days=request.days,
config_id=request.config_id
)
# Convert to the response format
curve_points = [
ForgettingCurvePoint(**point)
for point in result['curve_data']
]
# Build the response
response_data = ForgettingCurveResponse(
curve_data=curve_points,
config=result['config']
)
return success(data=response_data.model_dump(), msg="查询成功")
except Exception as e:
api_logger.error(f"获取遗忘曲线失败: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取遗忘曲线失败", str(e))
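In `update_forgetting_config` above, the update dict is built by dropping unset fields and the `config_id` key, so only values the caller actually supplied reach the service layer. A minimal sketch of that filter follows, using a plain dataclass in place of the Pydantic request model so it runs without third-party packages. The class `ForgettingConfigUpdate`, its two optional fields, and `build_update_fields` are illustrative assumptions; the real fields of `ForgettingConfigUpdateRequest` are not shown in this diff.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class ForgettingConfigUpdate:
    """Hypothetical stand-in for ForgettingConfigUpdateRequest."""
    config_id: int
    max_merge_batch_size: Optional[int] = None   # illustrative field
    min_days_since_access: Optional[int] = None  # illustrative field


def build_update_fields(payload: ForgettingConfigUpdate) -> Dict[str, Any]:
    """Keep only supplied fields, and never treat config_id as an updatable field."""
    return {
        key: value
        for key, value in asdict(payload).items()
        if value is not None and key != "config_id"
    }
```

One consequence of this design is that a field cannot be reset to `None` through this endpoint, since `None` is read as "not provided".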

@@ -1,255 +0,0 @@
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.models.memory_perceptual_model import PerceptualType
from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema,
PerceptualFilter
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_perceptual_service import MemoryPerceptualService
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/perceptual",
tags=["Perceptual Memory System"],
dependencies=[Depends(get_current_user)]
)
@router.get("/{group_id}/count", response_model=ApiResponse)
def get_memory_count(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve perceptual memory statistics for a user group.
Args:
group_id: ID of the user group (usually end_user_id in this context)
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Response containing memory count statistics
"""
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
count_stats = service.get_memory_count(group_id)
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
return success(
data=count_stats,
msg="Memory statistics retrieved successfully"
)
except Exception as e:
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch memory statistics",
)
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
def get_last_visual_memory(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent VISION-type memory for a user.
Args:
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest visual memory
"""
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
visual_memory = service.get_latest_visual_memory(group_id)
if visual_memory is None:
api_logger.info(f"No visual memory found: group_id={group_id}")
return success(
data=None,
msg="No visual memory available"
)
api_logger.info(f"Latest visual memory retrieved successfully: file={visual_memory.get('file_name')}")
return success(
data=visual_memory,
msg="Latest visual memory retrieved successfully"
)
except Exception as e:
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest visual memory",
)
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
def get_last_memory_listen(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent AUDIO-type memory for a user.
Args:
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest audio memory
"""
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
audio_memory = service.get_latest_audio_memory(group_id)
if audio_memory is None:
api_logger.info(f"No audio memory found: group_id={group_id}")
return success(
data=None,
msg="No audio memory available"
)
api_logger.info(f"Latest audio memory retrieved successfully: file={audio_memory.get('file_name')}")
return success(
data=audio_memory,
msg="Latest audio memory retrieved successfully"
)
except Exception as e:
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest audio memory",
)
@router.get("/{group_id}/last_text", response_model=ApiResponse)
def get_last_text_memory(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent TEXT-type memory for a user.
Args:
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest text memory
"""
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
try:
# Call the service layer to get the most recent text memory
service = MemoryPerceptualService(db)
text_memory = service.get_latest_text_memory(group_id)
if text_memory is None:
api_logger.info(f"No text memory found: group_id={group_id}")
return success(
data=None,
msg="No text memory available"
)
api_logger.info(f"Latest text memory retrieved successfully: file={text_memory.get('file_name')}")
return success(
data=text_memory,
msg="Latest text memory retrieved successfully"
)
except Exception as e:
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest text memory",
)
@router.get("/{group_id}/timeline", response_model=ApiResponse)
def get_memory_time_line(
group_id: uuid.UUID,
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve a timeline of perceptual memories for a user group.
Args:
group_id: ID of the user group
perceptual_type: Optional filter for perceptual type
page: Page number for pagination
page_size: Number of items per page
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Timeline data of perceptual memories
"""
api_logger.info(
f"Fetching perceptual memory timeline: user={current_user.username}, "
f"group_id={group_id}, type={perceptual_type}, page={page}"
)
try:
query = PerceptualQuerySchema(
filter=PerceptualFilter(type=perceptual_type),
page=page,
page_size=page_size
)
service = MemoryPerceptualService(db)
timeline_data = service.get_time_line(group_id, query)
api_logger.info(
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
f"returned={len(timeline_data.memories)}"
)
return success(
data=timeline_data.model_dump(),
msg="Perceptual memory timeline retrieved successfully"
)
except Exception as e:
api_logger.error(
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
f"error={str(e)}"
)
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch perceptual memory timeline",
)
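The timeline endpoint above accepts `page` (at least 1) and `page_size` (1 to 100) and passes them to the service. How the service turns them into a query is not visible in this diff; the sketch below shows one common approach, converting the pair into an offset and a limit. `page_bounds` is a hypothetical helper, and the bounds simply repeat the `Query` constraints from the endpoint signature.

```python
from typing import Tuple


def page_bounds(page: int, page_size: int) -> Tuple[int, int]:
    """Hypothetical helper: convert 1-based page/page_size into (offset, limit)."""
    if page < 1:
        raise ValueError("page must be >= 1")
    if not 1 <= page_size <= 100:
        raise ValueError("page_size must be between 1 and 100")
    # Page 1 starts at offset 0, page 2 at offset page_size, and so on.
    offset = (page - 1) * page_size
    return offset, page_size
```

With the endpoint defaults (`page=1`, `page_size=10`) this yields offset 0 and limit 10.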

@@ -1,43 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException, status
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.services.memory_storage_service import search_entity
from app.services.memory_short_service import ShortService,LongService
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from typing import Optional
load_dotenv()
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/short",
tags=["Memory"],
)
@router.get("/short_term")
async def short_term_configs(
end_user_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
# Get short-term memory data
short_term=ShortService(end_user_id)
short_result=short_term.get_short_databasets()
short_count=short_term.get_short_count()
long_term=LongService(end_user_id)
long_result=long_term.get_long_databasets()
entity_result = await search_entity(end_user_id)
result = {
'short_term': short_result,
'long_term': long_result,
'entity': entity_result.get('num', 0),
"retrieval_number":short_count,
"long_term_number":len(long_result)
}
return success(data=result, msg="短期记忆系统数据获取成功")

@@ -1,3 +1,4 @@
import datetime
import os
import uuid
from typing import Optional
@@ -8,7 +9,12 @@ from app.core.memory.utils.self_reflexion_utils import self_reflexion
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.end_user_model import EndUser
from app.models.user_model import User
from app.schemas.end_user_schema import (
EndUserProfileResponse,
EndUserProfileUpdate,
)
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
@@ -16,6 +22,8 @@ from app.schemas.memory_storage_schema import (
ConfigPilotRun,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
GenerateCacheRequest,
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_storage_service import (
@@ -230,8 +238,28 @@ def update_config_extracted(
# --- Forget config params ---
# The forgetting engine config endpoints have moved to memory_forget_controller.py
# Use the new endpoints: /api/memory/forget/read_config and /api/memory/forget/update_config
@router.post("/update_config_forget", response_model=ApiResponse) # Update forgetting engine config parameters (fixed path)
def update_config_forget(
payload: ConfigUpdateForget,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}")
try:
svc = DataConfigService(db)
result = svc.update_forget(payload)
return success(data=result, msg="更新成功")
except Exception as e:
api_logger.error(f"Update config forget failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
@router.get("/read_config_extracted", response_model=ApiResponse) # Read a single config via query parameter (fixed path); delete if it serves no purpose
def read_config_extracted(
@@ -255,6 +283,28 @@ def read_config_extracted(
api_logger.error(f"Read config extracted failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
@router.get("/read_config_forget", response_model=ApiResponse) # Read a single config via query parameter (fixed path); delete if it serves no purpose
def read_config_forget(
config_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}")
try:
svc = DataConfigService(db)
result = svc.get_forget(ConfigKey(config_id=config_id))
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Read config forget failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
@router.get("/read_all_config", response_model=ApiResponse) # Read the list of all config files
def read_all_config(
current_user: User = Depends(get_current_user),

@@ -1,134 +0,0 @@
import uuid
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.schemas.response_schema import ApiResponse
from app.services.conversation_service import ConversationService
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/work",
tags=["Working Memory System"],
dependencies=[Depends(get_current_user)]
)
@router.get("/{group_id}/count", response_model=ApiResponse)
def get_memory_count(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
pass
@router.get("/{group_id}/conversations", response_model=ApiResponse)
def get_conversations(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
Retrieve all conversations for the current user in a specific group.
Args:
group_id (UUID): The group identifier.
current_user (User, optional): The authenticated user.
db (Session, optional): SQLAlchemy session.
Returns:
ApiResponse: Contains a list of conversation IDs.
Notes:
- Initializes the ConversationService with the current DB session.
- Returns only conversation IDs for lightweight response.
- Logs can be added to trace requests in production.
"""
conversation_service = ConversationService(db)
conversations = conversation_service.get_user_conversations(
group_id
)
return success(data=[
{
"id": conversation.id,
"title": conversation.title
} for conversation in conversations
], msg="get conversations success")
@router.get("/{group_id}/messages", response_model=ApiResponse)
def get_messages(
conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
Retrieve the message history for a specific conversation.
Args:
conversation_id (UUID): The ID of the conversation to fetch messages from.
current_user (User, optional): The authenticated user.
db (Session, optional): SQLAlchemy session.
Returns:
ApiResponse: Contains the list of messages in the conversation.
Notes:
- Uses ConversationService to fetch messages.
- Consider paginating results if message history is large.
- Logging can be added for audit and debugging.
"""
conversation_service = ConversationService(db)
messages_obj = conversation_service.get_messages(
conversation_id,
)
messages = [
{
"role": message.role,
"content": message.content,
"created_at": int(message.created_at.timestamp() * 1000),
}
for message in messages_obj
]
return success(data=messages, msg="get conversation history success")
@router.get("/{group_id}/detail", response_model=ApiResponse)
async def get_conversation_detail(
conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
Retrieve detailed information about a specific conversation.
This endpoint will fetch the conversation detail for the user. If the detail
does not exist or is outdated, it will trigger the LLM to generate a new summary.
Args:
conversation_id (UUID): The ID of the conversation.
current_user (User, optional): The authenticated user making the request.
db (Session, optional): SQLAlchemy session.
Returns:
ApiResponse: Contains the conversation detail serialized as a dictionary.
Notes:
- Uses async ConversationService to fetch or generate the conversation detail.
- Handles workspace and user-specific context automatically.
- Logging and exception handling should be implemented for production monitoring.
"""
conversation_service = ConversationService(db)
detail = await conversation_service.get_conversation_detail(
user=current_user,
conversation_id=conversation_id,
workspace_id=current_user.current_workspace_id
)
return success(data=detail.model_dump(), msg="get conversation detail success")
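The `get_messages` handler in the removed file above serializes `created_at` as epoch milliseconds via `int(message.created_at.timestamp() * 1000)`. A minimal sketch of that conversion follows. The helper name `to_epoch_ms` is hypothetical and not part of the codebase, and treating naive datetimes as UTC is an assumption made here for determinism (the original expression would interpret them in local time):

```python
from datetime import datetime, timezone


def to_epoch_ms(dt: datetime) -> int:
    """Convert a datetime to integer milliseconds since the Unix epoch."""
    # Assumption: naive datetimes are treated as UTC in this sketch.
    # datetime.timestamp() would otherwise interpret them in local time.
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return int(dt.timestamp() * 1000)
```

Pinning the timezone before calling `timestamp()` keeps the value stable across servers configured with different local time zones.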

View File

@@ -74,7 +74,7 @@ def get_multi_agent_configs(
"app_id": str(app_id),
"default_model_config_id": None,
"model_parameters": None,
"orchestration_mode": "supervisor",
"orchestration_mode": "conditional",
"sub_agents": [],
"routing_rules": [],
"execution_config": {

View File

@@ -1,9 +1,7 @@
import uuid
import json
from fastapi import APIRouter, Depends, Path
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
@@ -72,12 +70,12 @@ def get_prompt_session(
SessionMessage(role=role, content=content)
for role, content in history
]
result = SessionHistoryResponse(
session_id=session_id,
messages=messages
)
return success(data=result)
@@ -106,32 +104,35 @@ async def get_prompt_opt(
ApiResponse: Contains the optimized prompt, description, and a list of variables.
"""
service = PromptOptimizerService(db)
async def event_generator():
yield "event:start\ndata: {}\n\n"
try:
async for chunk in service.optimize_prompt(
tenant_id=current_user.tenant_id,
model_id=data.model_id,
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
user_require=data.message
):
# chunk is an incremental piece of the prompt
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
except Exception as e:
yield f"event:error\ndata: {json.dumps(
{"error": str(e)}
)}\n\n"
yield "event:end\ndata: {}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
service.create_message(
tenant_id=current_user.tenant_id,
session_id=session_id,
user_id=current_user.id,
role=RoleType.USER,
content=data.message
)
opt_result = await service.optimize_prompt(
tenant_id=current_user.tenant_id,
model_id=data.model_id,
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
user_require=data.message
)
service.create_message(
tenant_id=current_user.tenant_id,
session_id=session_id,
user_id=current_user.id,
role=RoleType.ASSISTANT,
content=opt_result.desc
)
variables = service.parser_prompt_variables(opt_result.prompt)
result = {
"prompt": opt_result.prompt,
"desc": opt_result.desc,
"variables": variables
}
result_schema = OptimizePromptResponse.model_validate(result)
return success(data=result_schema)

View File

@@ -1,7 +1,6 @@
import hashlib
import json
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
@@ -18,8 +17,6 @@ from app.services.auth_service import create_access_token
from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
router = APIRouter(prefix="/public/share", tags=["Public Share"])
logger = get_business_logger()
@@ -268,8 +265,7 @@ def get_conversation(
async def chat(
payload: conversation_schema.ChatRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
db: Session = Depends(get_db)
):
"""Send a message and get a reply
@@ -311,7 +307,7 @@ async def chat(
other_id=other_id,
original_user_id=user_id # Save original user_id to other_id
)
end_user_id = str(new_end_user.id)
appid=share.app_id
"""Get the storage type and the workspace ID"""
@@ -365,9 +361,6 @@ async def chat(
config = release.config or {}
if not config.get("sub_agents"):
raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING)
elif app_type == AppType.WORKFLOW:
# Multi-Agent type: validate the multi-agent configuration
pass
else:
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
@@ -396,96 +389,15 @@ async def chat(
if app_type == AppType.AGENT:
# Streaming response
agent_config = agent_config_4_app_release(release)
if payload.stream:
# async def event_generator():
# async for event in service.chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # use the already-created conversation ID
# user_id=str(new_end_user.id), # convert to string
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
async def event_generator():
async for event in app_chat_service.agnet_chat_stream(
message=payload.message,
conversation_id=conversation.id, # use the already-created conversation ID
user_id=str(new_end_user.id), # convert to string
variables=payload.variables,
web_search=payload.web_search,
config=agent_config,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# Non-streaming response
# result = await service.chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # use the already-created conversation ID
# user_id=str(new_end_user.id), # convert to string
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
result = await app_chat_service.agnet_chat(
message=payload.message,
conversation_id=conversation.id, # use the already-created conversation ID
user_id=str(new_end_user.id), # convert to string
variables=payload.variables,
config=agent_config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
# config = workflow_config_4_app_release(release)
config = multi_agent_config_4_app_release(release)
if payload.stream:
async def event_generator():
async for event in app_chat_service.multi_agent_chat_stream(
async for event in service.chat_stream(
share_token=share_token,
message=payload.message,
conversation_id=conversation.id, # use the already-created conversation ID
user_id=str(new_end_user.id), # convert to string
variables=payload.variables,
config=config,
password=password,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
@@ -503,89 +415,37 @@ async def chat(
}
)
# Multi-agent non-streaming response
result = await app_chat_service.multi_agent_chat(
# Non-streaming response
result = await service.chat(
share_token=share_token,
message=payload.message,
conversation_id=conversation.id, # use the already-created conversation ID
user_id=end_user_id, # convert to string
user_id=str(new_end_user.id), # convert to string
variables=payload.variables,
config=config,
password=password,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.MULTI_AGENT:
# Multi-agent streaming response
# if payload.stream:
# async def event_generator():
# async for event in service.multi_agent_chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # use the already-created conversation ID
# user_id=str(new_end_user.id), # convert to string
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
# # Multi-agent non-streaming response
# result = await service.multi_agent_chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # use the already-created conversation ID
# user_id=str(new_end_user.id), # convert to string
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.WORKFLOW:
config = workflow_config_4_app_release(release)
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
async for event in service.multi_agent_chat_stream(
share_token=share_token,
message=payload.message,
conversation_id=conversation.id, # use the already-created conversation ID
user_id=end_user_id, # convert to string
user_id=str(new_end_user.id), # convert to string
variables=payload.variables,
config=config,
password=password,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
event_type = event.get("event", "message")
event_data = event.get("data", {})
# Convert to standard SSE format (string)
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data, default=str, ensure_ascii=False)}\n\n"
yield sse_message
yield event
return StreamingResponse(
event_generator(),
@@ -598,34 +458,22 @@ async def chat(
)
# Multi-agent non-streaming response
result = await app_chat_service.workflow_chat(
result = await service.multi_agent_chat(
share_token=share_token,
message=payload.message,
conversation_id=conversation.id, # use the already-created conversation ID
user_id=end_user_id, # convert to string
user_id=str(new_end_user.id), # convert to string
variables=payload.variables,
config=config,
password=password,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id
user_rag_memory_id=user_rag_memory_id
)
logger.debug(
"工作流试运行返回结果",
extra={
"result_type": str(type(result)),
"has_response": "response" in result if isinstance(result, dict) else False
}
)
return success(
data=result,
msg="工作流任务执行成功"
)
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
return success(data=conversation_schema.ChatResponse(**result))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
pass
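The workflow streaming branch in the diff above turns each event dict into a Server-Sent Events message before yielding it. A minimal standalone sketch of that formatting step follows. The function name `format_sse` is hypothetical, and the assumption that every event is a dict with optional `event` and `data` keys is taken from the `event.get(...)` calls in the diff rather than from a documented contract:

```python
import json


def format_sse(event: dict) -> str:
    """Render one event dict as a Server-Sent Events text frame."""
    # Fall back to the generic "message" event and an empty payload,
    # mirroring the defaults used in the diff.
    event_type = event.get("event", "message")
    event_data = event.get("data", {})
    # default=str stringifies values json cannot encode natively (e.g. UUIDs);
    # ensure_ascii=False keeps non-ASCII text readable on the wire.
    payload = json.dumps(event_data, default=str, ensure_ascii=False)
    # A blank line terminates each SSE frame.
    return f"event: {event_type}\ndata: {payload}\n\n"
```

Keeping the formatting in one function would let the streaming and test code share it, instead of repeating the f-string in every controller.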

View File

@@ -4,17 +4,14 @@
Authentication: API Key
"""
from fastapi import APIRouter
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
from . import app_api_controller, rag_api_controller, memory_api_controller
# Create the V1 API router
service_router = APIRouter()
# Register sub-routers
service_router.include_router(app_api_controller.router)
service_router.include_router(rag_api_knowledge_controller.router)
service_router.include_router(rag_api_document_controller.router)
service_router.include_router(rag_api_file_controller.router)
service_router.include_router(rag_api_chunk_controller.router)
service_router.include_router(rag_api_controller.router)
service_router.include_router(memory_api_controller.router)
__all__ = ["service_router"]

View File

@@ -1,5 +1,4 @@
"""App service API - API Key authentication"""
import json
from typing import Annotated
from fastapi import APIRouter, Depends, Request, Body
@@ -22,7 +21,7 @@ from app.schemas.api_key_schema import ApiKeyAuth
from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.conversation_service import ConversationService, get_conversation_service
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release
from app.services.app_service import get_app_service, AppService
router = APIRouter(prefix="/app", tags=["V1 - App API"])
@@ -138,10 +137,10 @@ async def chat(
if app_type == AppType.AGENT:
# print("="*50)
# print(app.current_release.default_model_config_id)
print("="*50)
print(app.current_release.default_model_config_id)
agent_config = agent_config_4_app_release(app.current_release)
# print(agent_config.default_model_config_id)
print(agent_config.default_model_config_id)
# Streaming response
if payload.stream:
async def event_generator():
@@ -154,8 +153,7 @@ async def chat(
config=agent_config,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -179,13 +177,12 @@ async def chat(
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
user_rag_memory_id=user_rag_memory_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
# Multi-agent streaming response
config = multi_agent_config_4_app_release(app.current_release)
config = dict_to_multi_agent_config(app.current_release.config,app.id)
if payload.stream:
async def event_generator():
async for event in app_chat_service.multi_agent_chat_stream(
@@ -197,8 +194,8 @@ async def chat(
config=config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -214,6 +211,7 @@ async def chat(
# Multi-agent non-streaming response
result = await app_chat_service.multi_agent_chat(
message=payload.message,
conversation_id=conversation.id, # use the already-created conversation ID
user_id=end_user_id, # convert to string
@@ -228,29 +226,22 @@ async def chat(
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.WORKFLOW:
# Multi-agent streaming response
config = workflow_config_4_app_release(app.current_release)
config = dict_to_workflow_config(app.current_release.config,app.id)
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
message=payload.message,
conversation_id=conversation.id, # use the already-created conversation ID
user_id=new_end_user.id, # convert to string
user_id=end_user_id, # convert to string
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=app.app_id,
workspace_id=workspace_id
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
event_type = event.get("event", "message")
event_data = event.get("data", {})
# Convert to standard SSE format (string)
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
yield event
return StreamingResponse(
event_generator(),
@@ -262,34 +253,23 @@ async def chat(
}
)
# Multi-agent non-streaming response
# Non-streaming response
result = await app_chat_service.workflow_chat(
message=payload.message,
conversation_id=conversation.id, # use the already-created conversation ID
user_id=new_end_user.id, # convert to string
user_id=end_user_id, # convert to string
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=app.app_id,
workspace_id=workspace_id
)
logger.debug(
"工作流试运行返回结果",
extra={
"result_type": str(type(result)),
"has_response": "response" in result if isinstance(result, dict) else False
}
)
return success(
data=result,
msg="工作流任务执行成功"
user_rag_memory_id=user_rag_memory_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
pass

View File

@@ -1,221 +0,0 @@
"""RAG service API - API Key authentication"""
from typing import Any, Optional, Union
import uuid
from fastapi import APIRouter, Body, Depends, Request, status, Query
from sqlalchemy.orm import Session
from app.controllers import chunk_controller
from app.core.api_key_auth import require_api_key
from app.core.logging_config import get_business_logger
from app.core.rag.models.chunk import QAChunk
from app.core.response_utils import success
from app.db import get_db
from app.schemas import chunk_schema
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.response_schema import ApiResponse
from app.services import api_key_service
router = APIRouter(prefix="/chunks", tags=["V1 - RAG API"])
api_logger = get_business_logger()
@router.get("/{kb_id}/{document_id}/previewchunks", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def get_preview_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content")
):
"""
Paginated query of the document chunk preview list
- Supports filtering by document_id
- Supports keyword search over chunk content
- Returns pagination metadata + file list
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await chunk_controller.get_preview_chunks(kb_id=kb_id,
document_id=document_id,
page=page,
pagesize=pagesize,
keywords=keywords,
db=db,
current_user=current_user)
@router.get("/{kb_id}/{document_id}/chunks", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def get_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content")
):
"""
Paginated query of the document chunk list
- Supports filtering by document_id
- Supports keyword search over chunk content
- Returns pagination metadata + file list
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await chunk_controller.get_chunks(kb_id=kb_id,
document_id=document_id,
page=page,
pagesize=pagesize,
keywords=keywords,
db=db,
current_user=current_user)
@router.post("/{kb_id}/{document_id}/chunk", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def create_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
content: Union[str, QAChunk] = Body(..., description="Content can be either a string or a QAChunk object"),
):
"""
create chunk
"""
body = await request.json()
create_data = chunk_schema.ChunkCreate(**body)
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await chunk_controller.create_chunk(kb_id=kb_id,
document_id=document_id,
create_data=create_data,
db=db,
current_user=current_user)
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def get_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
doc_id: str,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
Retrieve document chunk information based on doc_id
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await chunk_controller.get_chunk(kb_id=kb_id,
document_id=document_id,
doc_id=doc_id,
db=db,
current_user=current_user)
@router.put("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def update_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
doc_id: str,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
content: Union[str, QAChunk] = Body(..., description="Content can be either a string or a QAChunk object"),
):
"""
Update document chunk content
"""
body = await request.json()
update_data = chunk_schema.ChunkUpdate(**body)
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await chunk_controller.update_chunk(kb_id=kb_id,
document_id=document_id,
doc_id=doc_id,
update_data=update_data,
db=db,
current_user=current_user)
@router.delete("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def delete_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
doc_id: str,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
delete document chunk
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await chunk_controller.delete_chunk(kb_id=kb_id,
document_id=document_id,
doc_id=doc_id,
db=db,
current_user=current_user)
@router.get("/retrieve_type", response_model=ApiResponse)
def get_retrieve_types():
return success(msg="Successfully obtained the retrieval type", data=list(chunk_schema.RetrieveType))
@router.post("/retrieval", response_model=Any, status_code=status.HTTP_200_OK)
@require_api_key(scopes=["rag"])
async def retrieve_chunks(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
query: str = Body(..., description="question"),
):
"""
retrieve chunk
"""
body = await request.json()
retrieve_data = chunk_schema.ChunkRetrieve(**body)
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await chunk_controller.retrieve_chunks(retrieve_data=retrieve_data,
db=db,
current_user=current_user)
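Every endpoint in the removed chunk controller above repeats the same three lines: look up the API key, take its creator as the acting user, and pin that user's workspace to the key's workspace. A sketch of how that step could be factored into one helper follows. Everything here is illustrative: `resolve_api_key_user`, `FakeApiKeyService`, and the dataclasses are hypothetical stand-ins, not the project's real `User`, `ApiKeyAuth`, or `ApiKeyService` classes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class User:  # hypothetical stand-in for app.models.User
    username: str
    current_workspace_id: Optional[str] = None


@dataclass
class ApiKey:  # hypothetical stand-in for the API key record
    creator: User


@dataclass
class ApiKeyAuth:  # hypothetical stand-in for app.schemas.api_key_schema.ApiKeyAuth
    api_key_id: str
    workspace_id: str


class FakeApiKeyService:
    """In-memory stand-in for the real API key lookup service."""

    _store = {("key-1", "ws-1"): ApiKey(creator=User("alice"))}

    @staticmethod
    def get_api_key(db, api_key_id, workspace_id):
        return FakeApiKeyService._store[(api_key_id, workspace_id)]


def resolve_api_key_user(db, auth: ApiKeyAuth, service=FakeApiKeyService) -> User:
    """Return the API key's creator, scoped to the key's workspace."""
    api_key = service.get_api_key(db, auth.api_key_id, auth.workspace_id)
    user = api_key.creator
    # The controllers act on behalf of the key's creator inside the key's workspace.
    user.current_workspace_id = auth.workspace_id
    return user
```

With such a helper each endpoint body would shrink to one call followed by the delegation to the underlying controller function.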

View File

@@ -0,0 +1,16 @@
"""RAG service API - API Key authentication"""
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
router = APIRouter(prefix="/knowledge", tags=["V1 - RAG API"])
logger = get_business_logger()
@router.get("")
async def list_knowledge():
"""List accessible knowledge bases (placeholder)"""
return success(data=[], msg="RAG API - Coming Soon")

View File

@@ -1,172 +0,0 @@
"""RAG service API - API Key authentication"""
from typing import Optional
import uuid
from fastapi import APIRouter, Body, Depends, Request, Query
from sqlalchemy.orm import Session
from app.controllers import document_controller
from app.core.api_key_auth import require_api_key
from app.core.logging_config import get_business_logger
from app.db import get_db
from app.schemas import document_schema
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.response_schema import ApiResponse
from app.services import api_key_service
router = APIRouter(prefix="/documents", tags=["V1 - RAG API"])
api_logger = get_business_logger()
@router.get("/{kb_id}/documents", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def get_documents(
kb_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id when type is Folder"),
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
document_ids: Optional[str] = Query(None, description="document ids, separated by commas")
):
"""
Paginated query of the document list
- Supports filtering by kb_id and parent_id
- Supports keyword search over file names
- Supports dynamic sorting
- Returns pagination metadata + file list
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await document_controller.get_documents(kb_id=kb_id,
parent_id=parent_id,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
keywords=keywords,
document_ids=document_ids,
db=db,
current_user=current_user)
@router.post("/document", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def create_document(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
kb_id: uuid.UUID = Body(..., description="kb id"),
file_name: str = Body(..., description="file name"),
):
"""
create document
"""
body = await request.json()
create_data = document_schema.DocumentCreate(**body)
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await document_controller.create_document(create_data=create_data,
db=db,
current_user=current_user)
@router.get("/{document_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def get_document(
document_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
Retrieve document information based on document_id
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await document_controller.get_document(document_id=document_id,
db=db,
current_user=current_user)
@router.put("/{document_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def update_document(
document_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
file_name: str = Body(None, description="file name (optional)"),
):
"""
Update document information
"""
body = await request.json()
update_data = document_schema.DocumentUpdate(**body)
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await document_controller.update_document(document_id=document_id,
update_data=update_data,
db=db,
current_user=current_user)
@router.delete("/{document_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def delete_document(
document_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
Delete document
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await document_controller.delete_document(document_id=document_id,
db=db,
current_user=current_user)
@router.post("/{document_id}/chunks", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def parse_documents(
document_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
parse document
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await document_controller.parse_documents(document_id=document_id,
db=db,
current_user=current_user)
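Every handler above repeats the same three lines to turn the API key into a `current_user`. A minimal sketch of factoring that into one helper follows. The `User`, `ApiKey`, and `ApiKeyAuth` classes are stand-ins that only mirror the attributes visible in the diff, and `resolve_api_key_user` and its missing-key guard are hypothetical, not part of the project.

```python
from dataclasses import dataclass
from typing import Callable, Optional
import uuid


@dataclass
class User:
    # Stand-in for the ORM user model; only the field the handlers touch.
    username: str
    current_workspace_id: Optional[uuid.UUID] = None


@dataclass
class ApiKey:
    # Stand-in for the stored API key row, which exposes its creator.
    creator: User


@dataclass
class ApiKeyAuth:
    # Mirrors the two attributes the handlers read from api_key_auth.
    api_key_id: uuid.UUID
    workspace_id: uuid.UUID


def resolve_api_key_user(
    get_api_key: Callable[[uuid.UUID, uuid.UUID], Optional[ApiKey]],
    auth: ApiKeyAuth,
) -> User:
    """Look up the key's creator and pin it to the key's workspace."""
    api_key = get_api_key(auth.api_key_id, auth.workspace_id)
    if api_key is None:
        # Hypothetical guard: the handlers in the diff do not check for a missing key.
        raise LookupError("API key not found")
    user = api_key.creator
    user.current_workspace_id = auth.workspace_id
    return user
```

With a helper like this, each handler body would shrink to one lookup plus the controller call, and a missing key would fail in one place instead of raising `AttributeError` on `api_key.creator`.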

@@ -1,198 +0,0 @@
"""RAG service endpoints - API Key authentication"""
from typing import Any, Optional
import uuid
from fastapi import APIRouter, Body, Depends, Request, Query, File, UploadFile
from sqlalchemy.orm import Session
from app.controllers import file_controller
from app.core.api_key_auth import require_api_key
from app.core.logging_config import get_business_logger
from app.db import get_db
from app.schemas import file_schema
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.response_schema import ApiResponse
from app.services import api_key_service
router = APIRouter(prefix="/files", tags=["V1 - RAG API"])
api_logger = get_business_logger()
@router.get("/{kb_id}/{parent_id}/files", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def get_files(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
):
"""
Paged query file list
- Support filtering by kb_id and parent_id
- Support keyword search for file names
- Support dynamic sorting
- Return paging metadata + file list
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await file_controller.get_files(kb_id=kb_id,
parent_id=parent_id,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
keywords=keywords,
db=db,
current_user=current_user)
@router.post("/folder", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def create_folder(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
folder_name: str = '/'
):
"""
Create a new folder
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await file_controller.create_folder(kb_id=kb_id,
parent_id=parent_id,
folder_name=folder_name,
db=db,
current_user=current_user)
@router.post("/file", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def upload_file(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
file: UploadFile = File(...),
):
"""
upload file
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await file_controller.upload_file(kb_id=kb_id,
parent_id=parent_id,
file=file,
db=db,
current_user=current_user)
@router.post("/customtext", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def custom_text(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
title: str = Body(..., description="title"),
content: str = Body(..., description="content"),
):
"""
custom text
"""
body = await request.json()
create_data = file_schema.CustomTextFileCreate(**body)
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await file_controller.custom_text(kb_id=kb_id,
parent_id=parent_id,
create_data=create_data,
db=db,
current_user=current_user)
@router.get("/{file_id}", response_model=Any)
async def get_file(
file_id: uuid.UUID,
db: Session = Depends(get_db)
) -> Any:
"""
Download the file based on the file_id
- Query file information from the database
- Construct the file path and check if it exists
- Return a FileResponse to download the file
"""
return await file_controller.get_file(file_id=file_id,
db=db)
@router.put("/{file_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def update_file(
file_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
file_name: str = Body(None, description="file name (optional)"),
):
"""
Update file information (such as file name)
- Only specified fields such as file_name are allowed to be modified
"""
body = await request.json()
update_data = file_schema.FileUpdate(**body)
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await file_controller.update_file(file_id=file_id,
update_data=update_data,
db=db,
current_user=current_user)
@router.delete("/{file_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def delete_file(
file_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
Delete a file or folder
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await file_controller.delete_file(file_id=file_id,
db=db,
current_user=current_user)
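The `get_files` handler earlier in this file takes `page` (greater than 0) and `pagesize` (1 to 100) and is documented to return paging metadata plus the file list. A minimal sketch of that slicing logic over an in-memory sequence is shown below. The response field names (`items`, `total`, `pages`, `has_next`) are assumptions for illustration; the diff does not show what `file_controller.get_files` actually returns.

```python
from math import ceil
from typing import Any, Dict, List, Sequence


def paginate(items: Sequence[Any], page: int = 1, pagesize: int = 20) -> Dict[str, Any]:
    """Slice `items` for one page and return it with paging metadata."""
    # Same bounds as the Query(...) constraints in get_files.
    if page < 1:
        raise ValueError("page must be greater than 0")
    if not 1 <= pagesize <= 100:
        raise ValueError("pagesize must be between 1 and 100")
    total = len(items)
    start = (page - 1) * pagesize
    page_items: List[Any] = list(items[start:start + pagesize])
    return {
        "items": page_items,
        "page": {
            "page": page,
            "pagesize": pagesize,
            "total": total,
            "pages": ceil(total / pagesize),
            "has_next": start + pagesize < total,
        },
    }
```

A real implementation would push the offset and limit into the database query rather than slicing a loaded list.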

@@ -1,248 +0,0 @@
"""RAG service endpoints - API Key authentication"""
from typing import Optional, Dict
import uuid
from fastapi import APIRouter, Body, Depends, Request, Query
from sqlalchemy.orm import Session
from app.controllers import knowledge_controller
from app.core.api_key_auth import require_api_key
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.models import knowledge_model
from app.schemas import knowledge_schema
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.response_schema import ApiResponse
from app.services import api_key_service
router = APIRouter(prefix="/knowledges", tags=["V1 - RAG API"])
api_logger = get_business_logger()
@router.get("/knowledgetype", response_model=ApiResponse)
def get_knowledge_types():
return success(msg="Successfully obtained the knowledge type", data=list(knowledge_model.KnowledgeType))
@router.get("/permissiontype", response_model=ApiResponse)
def get_permission_types():
return success(msg="Successfully obtained the knowledge permission type", data=list(knowledge_model.PermissionType))
@router.get("/parsertype", response_model=ApiResponse)
def get_parser_types():
return success(msg="Successfully obtained the knowledge parser type", data=list(knowledge_model.ParserType))
@router.get("/knowledge_graph_entity_types", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def get_knowledge_graph_entity_types(
llm_id: uuid.UUID,
scenario: str,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
get knowledge graph entity types based on llm_id
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await knowledge_controller.get_knowledge_graph_entity_types(llm_id=llm_id,
scenario=scenario,
db=db,
current_user=current_user)
@router.get("/knowledges", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def get_knowledges(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id"),
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (knowledge base name)"),
kb_ids: Optional[str] = Query(None, description="Knowledge base ids, separated by commas")
):
"""
Query the knowledge base list in pages
- Support filtering by parent_id
- Support keyword search for knowledge base names
- Support dynamic sorting
- Return paging metadata + knowledge base list
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await knowledge_controller.get_knowledges(parent_id=parent_id,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
keywords=keywords,
kb_ids=kb_ids,
db=db,
current_user=current_user)
@router.post("/knowledge", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def create_knowledge(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
name: str = Body(..., description="KB name"),
):
"""
create knowledge
"""
body = await request.json()
create_data = knowledge_schema.KnowledgeCreate(**body)
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await knowledge_controller.create_knowledge(create_data=create_data,
db=db,
current_user=current_user)
@router.get("/{knowledge_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def get_knowledge(
knowledge_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
Retrieve knowledge base information based on knowledge_id
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await knowledge_controller.get_knowledge(knowledge_id=knowledge_id,
db=db,
current_user=current_user)
@router.put("/{knowledge_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def update_knowledge(
knowledge_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
name: str = Body(None, description="KB name (optional)"),
):
body = await request.json()
update_data = knowledge_schema.KnowledgeUpdate(**body)
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await knowledge_controller.update_knowledge(knowledge_id=knowledge_id,
update_data=update_data,
db=db,
current_user=current_user)
@router.delete("/{knowledge_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def delete_knowledge(
knowledge_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
Soft-delete knowledge base
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await knowledge_controller.delete_knowledge(knowledge_id=knowledge_id,
db=db,
current_user=current_user)
@router.get("/{knowledge_id}/knowledge_graph", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def get_knowledge_graph(
knowledge_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
Retrieve knowledge graph information based on knowledge_id
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await knowledge_controller.get_knowledge_graph(knowledge_id=knowledge_id,
db=db,
current_user=current_user)
@router.delete("/{knowledge_id}/knowledge_graph", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def delete_knowledge_graph(
knowledge_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
delete knowledge graph
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await knowledge_controller.delete_knowledge_graph(knowledge_id=knowledge_id,
db=db,
current_user=current_user)
@router.post("/{knowledge_id}/knowledge_graph", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def rebuild_knowledge_graph(
knowledge_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
rebuild knowledge graph
"""
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await knowledge_controller.rebuild_knowledge_graph(knowledge_id=knowledge_id,
db=db,
current_user=current_user)
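The `get_knowledges` handler in this file accepts `kb_ids` as a single comma-separated string. A minimal sketch of turning such a string into a list of UUIDs follows. `parse_kb_ids` is a hypothetical helper; how `knowledge_controller.get_knowledges` actually parses the value is not shown in the diff.

```python
import uuid
from typing import List, Optional


def parse_kb_ids(kb_ids: Optional[str]) -> List[uuid.UUID]:
    """Split a comma-separated id string into UUIDs, skipping empty pieces."""
    if not kb_ids:
        return []
    # uuid.UUID raises ValueError for a malformed id, which a caller could map to a 400.
    return [uuid.UUID(part.strip()) for part in kb_ids.split(",") if part.strip()]
```

Parsing up front keeps malformed ids from reaching the database query as raw strings.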

@@ -1,22 +1,23 @@
from fastapi import APIRouter, Depends, status, HTTPException, Body, Path
from fastapi.responses import StreamingResponse
from fastapi import APIRouter, Depends, status, Query, HTTPException
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.core.models import RedBearLLM, RedBearRerank
from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings
from app.db import get_db
from app.models.models_model import ModelApiKey
from app.core.response_utils import success
from app.schemas.response_schema import ApiResponse
from app.schemas.app_schema import AppChatRequest
from app.services.model_service import ModelConfigService
from app.services.handoffs_service import get_handoffs_service_for_app, reset_handoffs_service_cache
from app.services.conversation_service import ConversationService
from app.core.logging_config import get_api_logger
from app.dependencies import get_current_user
from app.models.models_model import ModelApiKey, ModelProvider, ModelType
from app.models.user_model import User
from app.schemas import model_schema
from app.core.response_utils import success
from app.schemas.response_schema import ApiResponse, PageData
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.core.logging_config import get_api_logger
# Get the API-specific logger
api_logger = get_api_logger()
@@ -27,8 +28,6 @@ router = APIRouter(
)
# ==================== Original test endpoints ====================
@router.get("/llm/{model_id}", response_model=ApiResponse)
def test_llm(
model_id: uuid.UUID,
@@ -51,6 +50,7 @@ def test_llm(
template = """Question: {question}
Answer: Let's think step by step."""
# ChatPromptTemplate
prompt = ChatPromptTemplate.from_template(template)
chain = prompt | llm
answer = chain.invoke({"question": "What is LangChain?"})
@@ -80,13 +80,13 @@ def test_embedding(
base_url=apiConfig.api_base
))
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
embeddings = model.embed_documents(data)
print(embeddings)
query = "我想找一个适合学习的地方。"
@@ -114,123 +114,13 @@ def test_rerank(
base_url=apiConfig.api_base
))
query = "最近哪家咖啡店评价最好?"
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
scores = model.rerank(query=query, documents=data, top_n=3)
print(scores)
return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores})
# ==================== Handoffs test endpoints ====================
@router.post("/handoffs/{app_id}")
async def test_handoffs(
app_id: uuid.UUID = Path(..., description="应用 ID"),
request: AppChatRequest = Body(...),
current_user=Depends(get_current_user),
db: Session = Depends(get_db)
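):
"""Test the Agent Handoffs feature
Demonstrates multi-agent collaboration and dynamic switching implemented with LangGraph
- Load the Agent configuration from multi_agent_config in the database
- Switch automatically to a suitable Agent based on the user's question
- Use conversation_id to keep conversation state
- Use the stream parameter to control whether output is streamed
Event types (streaming):
- start: execution started
- agent: current Agent information
- message: streamed message content
- handoff: Agent switch event
- end: execution finished
- error: error information
"""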
try:
workspace_id = current_user.current_workspace_id
# Get or create the conversation
conversation_service = ConversationService(db)
if request.conversation_id:
# Verify that the conversation exists
conversation = conversation_service.get_conversation(uuid.UUID(request.conversation_id))
if not conversation:
raise HTTPException(status_code=404, detail="会话不存在")
conversation_id = str(conversation.id)
else:
# Create a new conversation
conversation = conversation_service.create_or_get_conversation(
app_id=app_id,
workspace_id=workspace_id,
user_id=request.user_id,
is_draft=True
)
conversation_id = str(conversation.id)
# Choose the response mode based on the stream parameter
if request.stream:
# Streaming response
service = get_handoffs_service_for_app(app_id, db, streaming=True)
return StreamingResponse(
service.chat_stream(
message=request.message,
conversation_id=conversation_id
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
else:
# Non-streaming response
service = get_handoffs_service_for_app(app_id, db, streaming=False)
result = await service.chat(
message=request.message,
conversation_id=conversation_id
)
return success(data=result, msg="Handoffs 测试成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Handoffs 测试失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/handoffs/{app_id}/agents", response_model=ApiResponse)
def get_handoff_agents(
app_id: uuid.UUID = Path(..., description="应用 ID"),
db: Session = Depends(get_db),
current_user=Depends(get_current_user)
):
"""Get the list of Handoff Agents for the application"""
try:
service = get_handoffs_service_for_app(app_id, db, streaming=False)
agents = service.get_agents()
return success(data={"agents": agents}, msg="获取 Agent 列表成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
api_logger.error(f"获取 Agent 列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/handoffs/{app_id}/reset")
def reset_handoff_service(
app_id: uuid.UUID = Path(..., description="应用 ID"),
current_user=Depends(get_current_user)
):
"""Reset the Handoff service cache for the given application"""
reset_handoffs_service_cache(app_id)
return success(msg="Handoff 服务已重置")
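The `test_handoffs` endpoint above returns a `text/event-stream` response and documents six event types (`start`, `agent`, `message`, `handoff`, `end`, `error`). A minimal sketch of how such events could be framed as Server-Sent Events is given below. The exact wire format produced by `service.chat_stream` is not shown in the diff, so the JSON payload shape and the helper names `sse_frame` and `to_sse` are assumptions.

```python
import json
from typing import Any, Dict, Iterable, Iterator, Tuple


def sse_frame(event: str, data: Dict[str, Any]) -> str:
    """Format one SSE frame: an event line, a JSON data line, then a blank line."""
    payload = json.dumps(data, ensure_ascii=False)
    return f"event: {event}\ndata: {payload}\n\n"


def to_sse(events: Iterable[Tuple[str, Dict[str, Any]]]) -> Iterator[str]:
    """Turn (event, data) pairs into SSE frames, stopping after `end` or `error`."""
    for name, data in events:
        yield sse_frame(name, data)
        if name in ("end", "error"):
            # Terminal events close the stream; anything after them is dropped.
            break
```

A generator like `to_sse(...)` is the kind of iterable that `StreamingResponse` can consume; the `X-Accel-Buffering: no` header in the handler keeps a reverse proxy from buffering the frames.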

@@ -60,22 +60,6 @@ async def list_tools(
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{tool_id}/methods", response_model=ApiResponse)
async def get_tool_methods(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""Get all methods of a tool"""
try:
methods = await service.get_tool_methods(tool_id, current_user.tenant_id)
if methods is None:
raise HTTPException(status_code=404, detail="工具不存在")
return success(data=methods, msg="获取工具方法成功")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{tool_id}", response_model=ApiResponse)
async def get_tool(
tool_id: str,
@@ -175,8 +159,7 @@ async def execute_tool(
workspace_id=current_user.current_workspace_id,
timeout=request.timeout
)
if not result.success:
raise HTTPException(status_code=400, detail=result["error"])
return success(
data={
"success": result.success,
@@ -215,8 +198,8 @@ async def sync_mcp_tools(
"""Sync the MCP tool list"""
try:
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
if not result.get("success", False):
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
if result["success"] is False:
raise HTTPException(status_code=404, detail=result["message"])
return success(data=result, msg="MCP工具列表同步完成")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
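In the handlers above, an `HTTPException` with a 400 or 404 status is raised inside a `try` block whose only handler is `except Exception`. Because `HTTPException` is itself an `Exception` subclass, that handler catches it and re-raises it as a 500. The sketch below reproduces the effect with a stand-in `HTTPException` class (so it runs without FastAPI) and shows the re-raise pattern that `test_handoffs` uses elsewhere in this diff. The function names are hypothetical.

```python
class HTTPException(Exception):
    # Stand-in for fastapi.HTTPException so the sketch runs without FastAPI installed.
    def __init__(self, status_code: int, detail: str = "") -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


def sync_swallowing(result: dict) -> dict:
    """Same shape as the handler above: the broad except also catches the 404."""
    try:
        if result["success"] is False:
            raise HTTPException(status_code=404, detail=result["message"])
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def sync_preserving(result: dict) -> dict:
    """Re-raise HTTPException first so the intended status code survives."""
    try:
        if result["success"] is False:
            raise HTTPException(status_code=404, detail=result["message"])
        return result
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
```

With the first variant a client always sees a 500 for a failed sync; with the second it sees the 404 the code intended, while unexpected errors still map to 500.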

@@ -11,16 +11,14 @@ from app.db import get_db
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.core.api_key_utils import timestamp_to_datetime
from app.services.user_memory_service import (
UserMemoryService,
analytics_node_statistics,
analytics_memory_types,
analytics_graph_data,
)
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import GenerateCacheRequest
from app.schemas.end_user_schema import (
EndUserProfileResponse,
EndUserProfileUpdate,
@@ -43,27 +41,24 @@ router = APIRouter(
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
async def get_memory_insight_report_api(
end_user_id: str,
end_user_id: str, # use end_user_id
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
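) -> dict:
"""
Get the cached memory insight report
This endpoint only queries memory insight data already cached in the database; it does not generate anything.
To generate a new insight report, use the dedicated generation endpoint.
"""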
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
) -> dict:
"""Get the cached memory insight report"""
api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# Call the service layer to get the cached data
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
if result["is_cached"]:
# Cache exists; return the cached data
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
return success(data=result, msg="查询成功")
else:
# Cache does not exist; return a notice message
api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}")
return success(data=result, msg="数据尚未生成")
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e))
@@ -71,27 +66,24 @@ async def get_memory_insight_report_api(
@router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api(
end_user_id: str,
end_user_id: str, # use end_user_id
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
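) -> dict:
"""
Get the cached user summary
This endpoint only queries user summary data already cached in the database; it does not generate anything.
To generate a new user summary, use the dedicated generation endpoint.
"""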
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
) -> dict:
"""Get the cached user summary"""
api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# Call the service layer to get the cached data
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
if result["is_cached"]:
# Cache exists; return the cached data
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
return success(data=result, msg="查询成功")
else:
# Cache does not exist; return a notice message
api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}")
return success(data=result, msg="数据尚未生成")
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e))
@@ -105,35 +97,35 @@ async def generate_cache_api(
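) -> dict:
"""
Manually trigger cache generation
- If end_user_id is provided, generate only for that user
- If it is not provided, generate for all users in the current workspace
"""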
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
group_id = request.end_user_id
api_logger.info(
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
f"end_user_id={group_id if group_id else '全部用户'}"
)
try:
if group_id:
# Generate for a single user
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
# Generate the memory insight
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
# Generate the user summary
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
# Build the response
result = {
"end_user_id": group_id,
@@ -141,7 +133,7 @@ async def generate_cache_api(
"summary_success": summary_result["success"],
"errors": []
}
# Collect error information
if not insight_result["success"]:
result["errors"].append({
@@ -153,29 +145,29 @@ async def generate_cache_api(
"type": "summary",
"error": summary_result.get("error")
})
# Log the result
if result["insight_success"] and result["summary_success"]:
api_logger.info(f"成功为用户 {group_id} 生成缓存")
else:
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
return success(data=result, msg="生成完成")
else:
# Generate for the whole workspace
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id)
# Log the statistics
api_logger.info(
f"工作空间 {workspace_id} 批量生成完成: "
f"总数={result['total_users']}, 成功={result['successful']}, 失败={result['failed']}"
)
return success(data=result, msg="批量生成完成")
except Exception as e:
api_logger.error(f"缓存生成失败: user={current_user.username}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "缓存生成失败", str(e))
@@ -188,18 +180,18 @@ async def get_node_statistics_api(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
try:
# Call the new memory type statistics function
result = await analytics_memory_types(db, end_user_id)
# Compute the total for logging
total_count = sum(item["count"] for item in result)
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
@@ -219,31 +211,31 @@ async def get_graph_data_api(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询图数据但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# Parameter validation
if limit > 1000:
limit = 1000
api_logger.warning("limit 参数超过最大值,已调整为 1000")
if depth > 3:
depth = 3
api_logger.warning("depth 参数超过最大值,已调整为 3")
# Parse the node_types parameter
node_types_list = None
if node_types:
node_types_list = [t.strip() for t in node_types.split(",") if t.strip()]
api_logger.info(
f"图数据查询请求: end_user_id={end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}, node_types={node_types_list}, limit={limit}, depth={depth}"
)
try:
result = await analytics_graph_data(
db=db,
@@ -253,19 +245,19 @@ async def get_graph_data_api(
depth=depth,
center_node_id=center_node_id
)
# Check whether there is an error message
if "message" in result and result["statistics"]["total_nodes"] == 0:
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}")
return success(data=result, msg=result.get("message", "查询成功"))
api_logger.info(
f"成功获取图数据: end_user_id={end_user_id}, "
f"nodes={result['statistics']['total_nodes']}, "
f"edges={result['statistics']['total_edges']}"
)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"图数据查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
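The `get_graph_data_api` handler above caps `limit` at 1000 and `depth` at 3, then splits `node_types` on commas. A minimal sketch of the same normalization as a standalone, testable function follows. `normalize_graph_params` and the two constants are hypothetical names; the caps and the split logic are taken from the handler.

```python
from typing import List, Optional, Tuple

MAX_LIMIT = 1000  # cap applied to `limit` in the handler
MAX_DEPTH = 3     # cap applied to `depth` in the handler


def normalize_graph_params(
    limit: int, depth: int, node_types: Optional[str]
) -> Tuple[int, int, Optional[List[str]]]:
    """Clamp limit/depth to their maximums and split node_types on commas."""
    limit = min(limit, MAX_LIMIT)
    depth = min(depth, MAX_DEPTH)
    node_types_list: Optional[List[str]] = None
    if node_types:
        # Drop empty pieces such as the one produced by a trailing comma.
        node_types_list = [t.strip() for t in node_types.split(",") if t.strip()]
    return limit, depth, node_types_list
```

Note one edge case this logic shares with the handler: an input such as `" , "` is truthy, so it yields an empty list rather than `None`.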
@@ -278,25 +270,25 @@ async def get_end_user_profile(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}"
)
try:
# Look up the end user
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
# Build the response data
profile_data = EndUserProfileResponse(
id=end_user.id,
@@ -308,10 +300,10 @@ async def get_end_user_profile(
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
)
api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功")
return success(data=profile_data.model_dump(), msg="查询成功")
except Exception as e:
api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
@@ -325,56 +317,49 @@ async def update_end_user_profile(
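) -> dict:
"""
Update an end user's basic information
This endpoint can update the user's name, position, department, contact details, phone number, and hire date.
All fields are optional; only the fields provided are updated.
"""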
workspace_id = current_user.current_workspace_id
end_user_id = profile_update.end_user_id
# Check whether the user has selected a workspace
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}"
)
try:
# Look up the end user
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
# Update fields (only the fields provided, excluding end_user_id)
# Allow None values to reset a field (such as hire_date)
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
# Special handling for hire_date: if a timestamp is provided, convert it to DateTime
if 'hire_date' in update_data:
hire_date_timestamp = update_data['hire_date']
if hire_date_timestamp is not None:
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
# If it is None, keep None (allows clearing the field)
for field, value in update_data.items():
setattr(end_user, field, value)
# Update the updated_at timestamp
end_user.updated_at = datetime.datetime.now()
# Set updatetime_profile to the current time
end_user.updatetime_profile = datetime.datetime.now()
# Set updatetime_profile to the current timestamp (milliseconds)
current_timestamp = int(datetime.datetime.now().timestamp() * 1000)
end_user.updatetime_profile = current_timestamp
# Commit the changes
db.commit()
db.refresh(end_user)
# Build the response data
profile_data = EndUserProfileResponse(
id=end_user.id,
@@ -386,51 +371,11 @@ async def update_end_user_profile(
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
)
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功")
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}, updatetime_profile={current_timestamp}")
return success(data=profile_data.model_dump(), msg="更新成功")
except Exception as e:
db.rollback()
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
async def memory_space_timeline_of_shared_memories(id: str, label: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
MemoryEntity = MemoryEntityService(id, label)
timeline_memories_result = await MemoryEntity.get_timeline_memories_server()
return success(data=timeline_memories_result, msg="共同记忆时间线")
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
async def memory_space_relationship_evolution(id: str, label: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
try:
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
# Fetch emotion data
emotion = MemoryEmotion(id, label)
emotion_result = await emotion.get_emotion()
# Fetch interaction data
interaction = MemoryInteraction(id, label)
interaction_result = await interaction.get_interaction_frequency()
# Close the connections
await emotion.close()
await interaction.close()
result = {
"emotion": emotion_result,
"interaction": interaction_result
}
api_logger.info(f"关系演变查询成功: id={id}, table={label}")
return success(data=result, msg="关系演变")
except Exception as e:
api_logger.error(f"关系演变查询失败: id={id}, table={label}, error={str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "关系演变查询失败", str(e))
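Review note on the hire_date handling in update_end_user_profile above: the diff calls timestamp_to_datetime, but its body is not part of this compare. The following is a minimal sketch of that flow, assuming the client sends a Unix timestamp in milliseconds (the unit the diff uses for updatetime_profile) and that UTC is an acceptable time zone. The function bodies and the names `normalize_update` and `datetime_to_timestamp_ms` are hypothetical, not the project's helpers.

```python
import datetime
from typing import Any, Dict


def timestamp_to_datetime(timestamp_ms: int) -> datetime.datetime:
    """Hypothetical stand-in: millisecond Unix timestamp -> aware UTC datetime."""
    return datetime.datetime.fromtimestamp(
        timestamp_ms / 1000, tz=datetime.timezone.utc
    )


def datetime_to_timestamp_ms(value: datetime.datetime) -> int:
    """Inverse conversion, mirroring int(now.timestamp() * 1000) in the diff."""
    return int(value.timestamp() * 1000)


def normalize_update(update_data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert hire_date only when it is present and not None.

    A None value is kept as-is so the caller can clear the field,
    and a missing key is left missing so the field is not touched.
    """
    result = dict(update_data)
    if result.get("hire_date") is not None:
        result["hire_date"] = timestamp_to_datetime(result["hire_date"])
    return result
```

Keeping "missing" and "explicitly None" distinct is what makes `exclude_unset=True` useful in the handler: only a key that the client actually sent reaches `setattr`.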


@@ -39,11 +39,11 @@ router = APIRouter(prefix="/apps", tags=["workflow"])
@router.post("/{app_id}/workflow")
@cur_workspace_access_guard()
async def create_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
config: WorkflowConfigCreate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
config: WorkflowConfigCreate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""Create a workflow configuration
@@ -96,7 +96,6 @@ async def create_workflow_config(
msg=f"创建工作流配置失败: {str(e)}"
)
#
# @router.get("/{app_id}/workflow")
# async def get_workflow_config(
@@ -200,10 +199,10 @@ async def create_workflow_config(
@router.delete("/{app_id}/workflow")
async def delete_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""Delete a workflow configuration
@@ -244,11 +243,11 @@ async def delete_workflow_config(
@router.post("/{app_id}/workflow/validate")
async def validate_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
):
"""Validate a workflow configuration
@@ -313,12 +312,12 @@ async def validate_workflow_config(
@router.get("/{app_id}/workflow/executions")
async def get_workflow_executions(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0
):
"""List workflow execution records
@@ -366,10 +365,10 @@ async def get_workflow_executions(
@router.get("/workflow/executions/{execution_id}")
async def get_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""Get workflow execution details
@@ -418,14 +417,16 @@ async def get_workflow_execution(
)
# ==================== Workflow execution ====================
@router.post("/{app_id}/workflow/run")
async def run_workflow(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
request: WorkflowExecutionRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
request: WorkflowExecutionRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""Run a workflow
@@ -486,22 +487,22 @@ async def run_workflow(
"""
try:
async for event in await service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True
):
# Extract the event type and data
event_type = event.get("event", "message")
event_data = event.get("data", {})
# Convert to the standard SSE format (string)
# event: <type>
# data: <json>
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
except Exception as e:
logger.error(f"流式执行异常: {e}", exc_info=True)
# Send an error event
@@ -553,10 +554,10 @@ async def run_workflow(
@router.post("/workflow/executions/{execution_id}/cancel")
async def cancel_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""Cancel a workflow execution
@@ -601,7 +602,7 @@ async def cancel_workflow_execution(
except BusinessException as e:
logger.warning(f"取消工作流执行失败: {e.message}")
return fail(code=e.code, msg=e.message)
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
return fail(
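Review note on the streaming branch of run_workflow above: each event is serialized as an `event:` line, a `data:` line, and a blank line. A minimal sketch of that framing as a standalone helper follows; the name `format_sse` and the sample event name `node_start` are hypothetical, and the event shape (a dict with optional `event` and `data` keys) is taken from the `.get()` calls in the diff.

```python
import json
from typing import Any, Dict


def format_sse(event: Dict[str, Any]) -> str:
    """Serialize one workflow event as a Server-Sent Events message.

    json.dumps escapes newlines inside strings, so the payload stays on a
    single "data:" line and cannot break the SSE framing.
    """
    event_type = event.get("event", "message")  # default type, as in the diff
    event_data = event.get("data", {})
    return f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"


if __name__ == "__main__":
    print(format_sse({"event": "node_start", "data": {"id": 1}}), end="")
```

Pulling the formatting into a pure function like this would make the generator body easier to unit-test without running a workflow.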


@@ -11,16 +11,10 @@ import os
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
@@ -165,13 +159,11 @@ class LangChainAgent:
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# logger.info(f'Redis_Agent:{end_user_end};{history}')
messagss_list=[]
retrieved_content=[]
for messages in history:
query = messages.get("Query")
aimessages = messages.get("Answer")
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
retrieved_content.append({query: aimessages})
return messagss_list,retrieved_content
return messagss_list
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
@@ -211,6 +203,7 @@ class LangChainAgent:
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
@@ -228,26 +221,11 @@ class LangChainAgent:
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
db_for_memory = next(get_db())
history_term_memory=await self.term_memory_redis_read(end_user_id)
if memory_flag:
if len(history_term_memory)>=4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
print(retrieved_content)
# Obtain a new database connection for the long-term memory operation
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
except Exception as e:
logger.error(f"Failed to write to LongTermMemory: {e}")
raise
finally:
db_for_memory.close()
history_term_memory=';'.join(history_term_memory)
logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
try:
@@ -336,6 +314,10 @@ class LangChainAgent:
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
@@ -347,24 +329,14 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
history_term_memory = await self.term_memory_redis_read(end_user_id)
if memory_flag:
if len(history_term_memory) >= 4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
db_for_memory = next(get_db())
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
except Exception as e:
logger.error(f"Failed to write to long term memory: {e}")
finally:
db_for_memory.close()
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
try:
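Review note on term_memory_redis_read and its callers above: one side of the diff returns a `(messages, retrieved_content)` tuple and the other a plain list, so callers must index the result on one side and not on the other. A minimal sketch of the plain-list variant follows, assuming each history record is a dict with `Query` and `Answer` keys as in the diff. The function names and the storage type `"neo4j"` in the usage line are hypothetical.

```python
from typing import Dict, List, Optional


def format_history(history: List[Dict[str, str]]) -> List[str]:
    """Render each stored turn with the same template string the diff uses."""
    lines = []
    for record in history:
        query = record.get("Query")
        answer = record.get("Answer")
        lines.append(f"用户:{query}。AI回复:{answer}")
    return lines


def merge_if_long_enough(
    lines: List[str], storage_type: str, min_turns: int = 4
) -> Optional[str]:
    """Join the turns with ';' once at least min_turns exist and storage is not RAG.

    Returns None when the short-term history should not be flushed yet.
    """
    if len(lines) >= min_turns and storage_type != "rag":
        return ";".join(lines)
    return None


if __name__ == "__main__":
    turns = format_history([{"Query": "hi", "Answer": "hello"}] * 4)
    print(merge_if_long_enough(turns, "neo4j"))
```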


@@ -3,7 +3,7 @@ import secrets
from typing import Optional, Union
from datetime import datetime
from app.models.api_key_model import ApiKeyType
from app.schemas.api_key_schema import ApiKeyType
from fastapi import Response
from fastapi.responses import JSONResponse


@@ -7,18 +7,17 @@ from dotenv import load_dotenv
load_dotenv()
class Settings:
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
# Neo4j Configuration (memory system database)
NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687")
NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "")
# Database configuration (Postgres)
DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1")
DB_PORT: int = int(os.getenv("DB_PORT", "5432"))
@@ -38,7 +37,7 @@ class Settings:
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
# ElasticSearch configuration
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
@@ -49,7 +48,7 @@ class Settings:
ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000"))
ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true"
ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10"))
# Xinference configuration
XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1")
@@ -58,17 +57,17 @@ class Settings:
LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true"
LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "")
LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "")
# LLM Request Configuration
LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0"))
LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2"))
# JWT Token Configuration
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random")
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
# Single Sign-On configuration
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
@@ -87,19 +86,19 @@ class Settings:
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
# Server Configuration
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
# ========================================================================
# Internal Configuration (not in .env, used by application code)
# ========================================================================
# Superuser settings (internal defaults)
FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com")
FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin")
FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password")
# Generic File Upload (internal)
GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads")
ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true"
@@ -124,7 +123,7 @@ class Settings:
LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5"))
LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true"
LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true"
# Sensitive Data Filtering
ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true"
@@ -143,6 +142,7 @@ class Settings:
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
@@ -150,27 +150,21 @@ class Settings:
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
# Memory Cache Regeneration Configuration
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
# Tool Management Configuration
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
# workflow config
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
def get_memory_output_path(self, filename: str = "") -> str:
"""
Get the full path for memory module output files.
@@ -185,7 +179,7 @@ class Settings:
if filename:
return str(base_path / filename)
return str(base_path)
def ensure_memory_output_dir(self) -> None:
"""
Ensure the memory output directory exists.
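Review note on the Settings class above: almost every attribute repeats the same `os.getenv(...)` parsing idiom, and REFLECTION_INTERVAL_TIME is annotated `Optional[str]` while being assigned an `int`. A minimal sketch of typed helpers that would keep annotation and value consistent follows. The helpers `env_bool` and `env_int` and the class name `SettingsSketch` are hypothetical and not part of the project.

```python
import os


def env_bool(name: str, default: str = "false") -> bool:
    """Read a boolean flag; only the literal 'true' (any case) counts as True."""
    return os.getenv(name, default).lower() == "true"


def env_int(name: str, default: int) -> int:
    """Read an integer setting, falling back to default when the variable is unset."""
    return int(os.getenv(name, str(default)))


class SettingsSketch:
    # Same variable names and defaults as in the diff; annotations match the values.
    ENABLE_SINGLE_SESSION: bool = env_bool("ENABLE_SINGLE_SESSION")
    REFLECTION_INTERVAL_TIME: int = env_int("REFLECTION_INTERVAL_TIME", 30)
    WORKFLOW_NODE_TIMEOUT: int = env_int("WORKFLOW_NODE_TIMEOUT", 600)
```

As in the original class, the values are read once at import time, so changing the environment afterwards has no effect on the class attributes.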


@@ -82,13 +82,6 @@ class BizCode(IntEnum):
MEMORY_WRITE_FAILED = 9501
MEMORY_READ_FAILED = 9502
MEMORY_CONFIG_NOT_FOUND = 9503
# Implicit Memory API (96xx)
INVALID_USER_ID = 9601
INSUFFICIENT_DATA = 9602
INVALID_FILTER_PARAMS = 9603
ANALYSIS_FAILED = 9604
PROFILE_STORAGE_ERROR = 9605
# System (100xx)
INTERNAL_ERROR = 10001
@@ -110,24 +103,24 @@ HTTP_MAPPING = {
BizCode.TOKEN_EXPIRED: 401,
BizCode.TOKEN_BLACKLISTED: 401,
BizCode.FORBIDDEN: 403,
BizCode.TENANT_NOT_FOUND: 400,
BizCode.TENANT_NOT_FOUND: 404,
BizCode.WORKSPACE_NO_ACCESS: 403,
BizCode.NOT_FOUND: 400,
BizCode.NOT_FOUND: 404,
BizCode.USER_NOT_FOUND: 200,
BizCode.WORKSPACE_NOT_FOUND: 400,
BizCode.MODEL_NOT_FOUND: 400,
BizCode.KNOWLEDGE_NOT_FOUND: 400,
BizCode.DOCUMENT_NOT_FOUND: 400,
BizCode.FILE_NOT_FOUND: 400,
BizCode.APP_NOT_FOUND: 400,
BizCode.RELEASE_NOT_FOUND: 400,
BizCode.WORKSPACE_NOT_FOUND: 404,
BizCode.MODEL_NOT_FOUND: 404,
BizCode.KNOWLEDGE_NOT_FOUND: 404,
BizCode.DOCUMENT_NOT_FOUND: 404,
BizCode.FILE_NOT_FOUND: 404,
BizCode.APP_NOT_FOUND: 404,
BizCode.RELEASE_NOT_FOUND: 404,
BizCode.DUPLICATE_NAME: 409,
BizCode.RESOURCE_ALREADY_EXISTS: 409,
BizCode.VERSION_ALREADY_EXISTS: 409,
BizCode.STATE_CONFLICT: 409,
BizCode.PUBLISH_FAILED: 500,
BizCode.NO_DRAFT_TO_PUBLISH: 400,
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,
BizCode.ROLLBACK_TARGET_NOT_FOUND: 404,
BizCode.APP_TYPE_NOT_SUPPORTED: 400,
BizCode.AGENT_CONFIG_MISSING: 400,
BizCode.SHARE_DISABLED: 403,
@@ -166,13 +159,6 @@ HTTP_MAPPING = {
BizCode.MEMORY_READ_FAILED: 500,
BizCode.MEMORY_CONFIG_NOT_FOUND: 400,
# Implicit Memory API error code mapping
BizCode.INVALID_USER_ID: 400,
BizCode.INSUFFICIENT_DATA: 400,
BizCode.INVALID_FILTER_PARAMS: 400,
BizCode.ANALYSIS_FAILED: 500,
BizCode.PROFILE_STORAGE_ERROR: 500,
BizCode.INTERNAL_ERROR: 500,
BizCode.DB_ERROR: 500,
BizCode.SERVICE_UNAVAILABLE: 503,
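Review note on BizCode and HTTP_MAPPING above: the hunk toggles several `*_NOT_FOUND` codes between 400 and 404, so it helps to see how a business code resolves to an HTTP status. A minimal sketch follows, using three memory codes whose values appear in the diff. The lookup helper `http_status_for` and its `default` parameter are hypothetical, since the diff does not show how the mapping is consumed.

```python
from enum import IntEnum
from typing import Dict


class BizCode(IntEnum):
    # Values copied from the diff; the real enum has many more members.
    MEMORY_WRITE_FAILED = 9501
    MEMORY_READ_FAILED = 9502
    MEMORY_CONFIG_NOT_FOUND = 9503
    INTERNAL_ERROR = 10001


HTTP_MAPPING: Dict[BizCode, int] = {
    BizCode.MEMORY_READ_FAILED: 500,
    BizCode.MEMORY_CONFIG_NOT_FOUND: 400,
    BizCode.INTERNAL_ERROR: 500,
}


def http_status_for(code: BizCode, default: int = 500) -> int:
    """Resolve a business code to an HTTP status, with a fallback for unmapped codes."""
    return HTTP_MAPPING.get(code, default)
```

With a fallback in place, removing a block of codes from the enum (as this compare does for the 96xx range) cannot leave a caller with a KeyError, provided the callers are updated too.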


@@ -106,32 +106,28 @@ class SearchService:
limit: int = 15,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config: "MemoryConfig" = None,
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search with two-stage ranking.
Stage 1: Filter by content relevance (BM25 + Embedding)
Stage 2: Rerank by activation values (ACTR)
Execute hybrid search and return clean content.
Args:
group_id: Group identifier for filtering
group_id: Group identifier for filtering results
question: Search query text
limit: Max results per category (default: 15)
search_type: "hybrid", "keyword", or "embedding" (default: "hybrid")
include: Result types (default: ["statements", "chunks", "entities", "summaries"])
rerank_alpha: BM25 weight (default: 0.6)
activation_boost_factor: Activation impact on memory strength (default: 0.8)
output_path: JSON output path (default: "search_results.json")
return_raw_results: Return full metadata (default: False)
memory_config: MemoryConfig for embedding model
limit: Maximum number of results to return (default: 15)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"])
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
output_path: Path to save search results (default: "search_results.json")
return_raw_results: If True, also return the raw search results as third element (default: False)
memory_config: MemoryConfig object for embedding model. Falls back to self.memory_config if not provided.
Returns:
Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results)
Tuple of (clean_content, cleaned_query, raw_results)
raw_results is None if return_raw_results=False
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
@@ -155,7 +151,6 @@ class SearchService:
output_path=output_path,
memory_config=config,
rerank_alpha=rerank_alpha,
activation_boost_factor=activation_boost_factor,
)
# Extract results based on search type and include parameter
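Review note on rerank_alpha above: the docstring describes it only as the weight for BM25 scores in reranking, and the two sides of the diff use defaults of 0.6 and 0.4. A minimal sketch of what such a weight commonly means follows, assuming a linear blend of BM25 and embedding scores that are both already normalized to [0, 1]. The actual reranking code is not in this hunk, so the formula and the name `combine_scores` are assumptions.

```python
def combine_scores(bm25: float, embedding: float, rerank_alpha: float = 0.4) -> float:
    """Blend keyword and vector relevance (assumed linear form).

    rerank_alpha is the BM25 weight; (1 - rerank_alpha) goes to the embedding score.
    """
    if not 0.0 <= rerank_alpha <= 1.0:
        raise ValueError("rerank_alpha must be within [0, 1]")
    return rerank_alpha * bm25 + (1.0 - rerank_alpha) * embedding


if __name__ == "__main__":
    # A hit that is strong on keywords but weak semantically, under both defaults.
    for alpha in (0.6, 0.4):
        print(alpha, combine_scores(bm25=0.9, embedding=0.2, rerank_alpha=alpha))
```

Under this assumption, moving the default from 0.6 to 0.4 shifts the ranking from keyword-leaning to embedding-leaning, which callers relying on the default would notice.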


@@ -425,9 +425,15 @@ async def Input_Summary(
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
search_service = get_context_resource(ctx, "search_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages) or ""
sessionid = sessionid.replace('call_id_', '')
@@ -533,11 +539,31 @@ async def Input_Summary(
)
retrieve_info, question, raw_results = "", query, []
# Return retrieved information directly without LLM processing
# Use the raw retrieved info as the answer
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
# Render template
system_prompt = await template_service.render_template(
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='input_summary',
query=query,
history=history,
retrieve_info=retrieve_info
)
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=RetrieveSummaryResponse
)
aimessages = structured.data.query_answer or "信息不足,无法回答"
except Exception as e:
logger.error(
f"Input_Summary: response_structured failed, using default answer: {e}",
exc_info=True
)
aimessages = "信息不足,无法回答"
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
# Emit intermediate output for frontend
return {


@@ -5,16 +5,19 @@ This module provides analytics and insights for the memory system.
Available functions:
- get_hot_memory_tags: Get hot memory tags by frequency
- MemoryInsight: Generate memory insight reports
- get_recent_activity_stats: Get recent activity statistics
Note: MemoryInsight and generate_user_summary have been moved to
app.services.user_memory_service for better architecture.
- generate_user_summary: Generate user summary
"""
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.memory_insight import MemoryInsight
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.core.memory.analytics.user_summary import generate_user_summary
__all__ = [
"get_hot_memory_tags",
"MemoryInsight",
"get_recent_activity_stats",
"generate_user_summary",
]


@@ -1,6 +0,0 @@
"""Implicit Memory Module
This module provides behavior analysis capabilities that build comprehensive user profiles
by analyzing memory summary nodes from Neo4j. It creates detailed user portraits across
multiple dimensions, tracks interest distributions, and identifies behavioral habits.
"""


@@ -1 +0,0 @@
"""Analyzers package for implicit memory analysis components."""


@@ -1,271 +0,0 @@
"""Dimension Analyzer for Implicit Memory System
This module implements LLM-based personality dimension analysis from user memory summaries.
It analyzes four key dimensions: creativity, aesthetic, technology, and literature,
providing percentage scores with evidence and reasoning.
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
DimensionPortrait,
DimensionScore,
UserMemorySummary,
)
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class DimensionData(BaseModel):
"""Individual dimension analysis data."""
percentage: float = Field(ge=0.0, le=100.0)
evidence: List[str] = Field(default_factory=list)
reasoning: str = ""
confidence_level: int = 50 # Default to medium confidence
class DimensionAnalysisResponse(BaseModel):
"""Response model for dimension analysis."""
dimensions: Dict[str, DimensionData] = Field(default_factory=dict)
class DimensionAnalyzer:
"""Analyzes user memory summaries to extract personality dimensions."""
# Define the four dimensions we analyze
DIMENSIONS = ["creativity", "aesthetic", "technology", "literature"]
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
"""Initialize the dimension analyzer.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
async def analyze_dimensions(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_portrait: Optional[DimensionPortrait] = None
) -> DimensionPortrait:
"""Analyze user summaries to extract personality dimensions.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_portrait: Optional existing portrait for incremental updates
Returns:
Dimension portrait with four personality dimensions
Raises:
LLMClientException: If LLM analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return self._create_empty_portrait(user_id)
try:
logger.info(f"Analyzing dimensions for user {user_id} with {len(user_summaries)} summaries")
# Use the LLM client wrapper for analysis
response = await self._llm_client.analyze_dimensions(
user_summaries=user_summaries,
user_id=user_id,
model_id=self.llm_model_id
)
# Create dimension scores
dimension_scores = {}
current_time = datetime.now()
for dimension_name in self.DIMENSIONS:
# Handle response as dictionary
dimensions_data = response.get("dimensions", {})
dimension_data = dimensions_data.get(dimension_name)
if dimension_data:
# Validate and create dimension score
score = self._create_dimension_score(
dimension_name=dimension_name,
dimension_data=dimension_data
)
dimension_scores[dimension_name] = score
else:
# Create default score if missing
logger.warning(f"Missing dimension data for {dimension_name}, using default")
dimension_scores[dimension_name] = self._create_default_dimension_score(dimension_name)
# Create dimension portrait
portrait = DimensionPortrait(
user_id=user_id,
creativity=dimension_scores["creativity"],
aesthetic=dimension_scores["aesthetic"],
technology=dimension_scores["technology"],
literature=dimension_scores["literature"],
analysis_timestamp=current_time,
total_summaries_analyzed=len(user_summaries),
historical_trends=self._calculate_historical_trends(existing_portrait) if existing_portrait else None
)
logger.info(f"Created dimension portrait for user {user_id}")
return portrait
except LLMClientException:
raise
except Exception as e:
logger.error(f"Dimension analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Dimension analysis failed: {e}") from e
def _create_dimension_score(
self,
dimension_name: str,
dimension_data: dict
) -> DimensionScore:
"""Create a dimension score from analysis data.
Args:
dimension_name: Name of the dimension
dimension_data: Analysis data dictionary for the dimension
Returns:
DimensionScore object
"""
# Validate percentage - handle dict access
percentage = dimension_data.get("percentage", 0.0)
percentage = max(0.0, min(100.0, float(percentage)))
# Validate confidence level
confidence_level = self._validate_confidence_level(dimension_data.get("confidence_level", 50))
# Ensure evidence is not empty
evidence = dimension_data.get("evidence", [])
if not evidence:
evidence = ["No specific evidence found"]
# Ensure reasoning is not empty
reasoning = dimension_data.get("reasoning", "")
if not reasoning:
reasoning = f"Analysis for {dimension_name} dimension"
return DimensionScore(
dimension_name=dimension_name,
percentage=percentage,
evidence=evidence,
reasoning=reasoning,
confidence_level=confidence_level
)
def _create_default_dimension_score(self, dimension_name: str) -> DimensionScore:
"""Create a default dimension score when analysis fails.
Args:
dimension_name: Name of the dimension
Returns:
Default DimensionScore object
"""
return DimensionScore(
dimension_name=dimension_name,
percentage=0.0,
evidence=["Insufficient data for analysis"],
reasoning=f"No clear evidence found for {dimension_name} dimension",
confidence_level=20 # Low confidence as numerical value
)
def _validate_confidence_level(self, confidence_level) -> int:
"""Return confidence level as integer, handling both string and numeric inputs.
Args:
confidence_level: Confidence level (string or numeric)
Returns:
Confidence level as integer (0-100)
"""
# If it's already a number, return it as int
if isinstance(confidence_level, (int, float)):
return int(confidence_level)
# If it's a string, convert common values to numbers
if isinstance(confidence_level, str):
confidence_str = confidence_level.lower().strip()
if confidence_str in ["high", "높음"]:
return 85
elif confidence_str in ["medium", "중간"]:
return 50
elif confidence_str in ["low", "낮음"]:
return 20
else:
# Try to parse as number
try:
return int(float(confidence_str))
except ValueError:
logger.warning(f"Unknown confidence level: {confidence_level}, defaulting to medium")
return 50
# Default fallback
return 50
def _create_empty_portrait(self, user_id: str) -> DimensionPortrait:
"""Create an empty dimension portrait when no data is available.
Args:
user_id: Target user ID
Returns:
Empty DimensionPortrait
"""
current_time = datetime.now()
return DimensionPortrait(
user_id=user_id,
creativity=self._create_default_dimension_score("creativity"),
aesthetic=self._create_default_dimension_score("aesthetic"),
technology=self._create_default_dimension_score("technology"),
literature=self._create_default_dimension_score("literature"),
analysis_timestamp=current_time,
total_summaries_analyzed=0,
historical_trends=None
)
def _calculate_historical_trends(
self,
existing_portrait: DimensionPortrait
) -> List[Dict[str, Any]]:
"""Calculate historical trends from existing portrait.
Args:
existing_portrait: Previous dimension portrait
Returns:
List of historical trend data
"""
if not existing_portrait:
return []
# Create trend entry from existing portrait
trend_entry = {
"timestamp": existing_portrait.analysis_timestamp.isoformat(),
"creativity": existing_portrait.creativity.percentage,
"aesthetic": existing_portrait.aesthetic.percentage,
"technology": existing_portrait.technology.percentage,
"literature": existing_portrait.literature.percentage,
"total_summaries": existing_portrait.total_summaries_analyzed
}
# Combine with existing trends
existing_trends = existing_portrait.historical_trends or []
# Keep only recent trends (last 10 entries)
all_trends = existing_trends + [trend_entry]
return all_trends[-10:]
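
The confidence handling above maps labels such as "high" or numeric strings onto a 0-100 integer. A minimal standalone sketch of that mapping follows. The name `normalize_confidence`, the `_LABELS` table and the `default` parameter are illustrative assumptions and not part of the deleted module; the Korean labels are left out for brevity, and out-of-range or non-finite numbers are clamped or defaulted here.

```python
import math
from typing import Union

# Hypothetical label table; the values mirror the ones used in the module above.
_LABELS = {"high": 85, "medium": 50, "low": 20}


def normalize_confidence(value: Union[int, float, str, None], default: int = 50) -> int:
    """Map a label or number onto an integer in the 0-100 range."""
    if isinstance(value, bool):
        # bool is a subclass of int; treat it as an unknown value
        return default
    if isinstance(value, (int, float)):
        number = float(value)
    elif isinstance(value, str):
        text = value.lower().strip()
        if text in _LABELS:
            return _LABELS[text]
        try:
            number = float(text)
        except ValueError:
            return default
    else:
        return default
    if not math.isfinite(number):
        # NaN and infinities cannot be converted to int
        return default
    # Truncate toward zero, then clamp to the documented range
    return max(0, min(100, int(number)))
```

Clamping keeps a stray value such as `150` from leaking into later sorting or display code that assumes a percentage-like scale.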


@@ -1,459 +0,0 @@
"""Habit Analyzer for Implicit Memory System
This module implements LLM-based behavioral habit analysis from user memory summaries.
It identifies recurring behavioral patterns, temporal patterns, and consolidates
similar habits with confidence scoring.
"""
import logging
from datetime import datetime
from typing import List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
BehaviorHabit,
FrequencyPattern,
UserMemorySummary,
)
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class HabitData(BaseModel):
"""Individual habit analysis data."""
habit_description: str
frequency_pattern: str
time_context: str
confidence_level: int = 50 # Default to medium confidence
supporting_summaries: List[str] = Field(default_factory=list)
specific_examples: List[str] = Field(default_factory=list)
is_current: bool = True
class HabitAnalysisResponse(BaseModel):
"""Response model for habit analysis."""
habits: List[HabitData] = Field(default_factory=list)
class HabitAnalyzer:
"""Analyzes user memory summaries to extract behavioral habits."""
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
"""Initialize the habit analyzer.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
async def analyze_habits(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_habits: Optional[List[BehaviorHabit]] = None
) -> List[BehaviorHabit]:
"""Analyze user summaries to extract behavioral habits.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_habits: Optional existing habits for consolidation
Returns:
List of extracted behavioral habits
Raises:
LLMClientException: If LLM analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return existing_habits or []
try:
logger.info(f"Analyzing habits for user {user_id} with {len(user_summaries)} summaries")
# Use the LLM client wrapper for analysis
response = await self._llm_client.analyze_habits(
user_summaries=user_summaries,
user_id=user_id,
model_id=self.llm_model_id
)
# Convert to BehaviorHabit objects
behavior_habits = []
for habit_data in response.get("habits", []):
try:
# Handle habit_data as dictionary
supporting_summaries = habit_data.get("supporting_summaries", [])
specific_examples = habit_data.get("specific_examples", [])
# Determine observation dates from summaries
first_observed, last_observed = self._determine_observation_dates(
user_summaries, supporting_summaries
)
behavior_habit = BehaviorHabit(
habit_description=habit_data.get("habit_description", ""),
frequency_pattern=self._validate_frequency_pattern(habit_data.get("frequency_pattern", "occasional")),
time_context=habit_data.get("time_context", ""),
confidence_level=self._validate_confidence_level(habit_data.get("confidence_level", 50)),
specific_examples=specific_examples,
first_observed=first_observed,
last_observed=last_observed,
is_current=habit_data.get("is_current", True)
)
# Validate habit
if self._is_valid_habit(behavior_habit):
behavior_habits.append(behavior_habit)
else:
logger.warning(f"Invalid habit skipped: {behavior_habit.habit_description}")
except Exception as e:
logger.error(f"Error creating behavior habit: {e}")
continue
# Consolidate with existing habits if provided
if existing_habits:
behavior_habits = self._consolidate_habits(
new_habits=behavior_habits,
existing_habits=existing_habits
)
# Sort habits by confidence and recency
behavior_habits = self._sort_habits_by_priority(behavior_habits)
logger.info(f"Extracted {len(behavior_habits)} habits for user {user_id}")
return behavior_habits
except LLMClientException:
raise
except Exception as e:
logger.error(f"Habit analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Habit analysis failed: {e}") from e
def _validate_frequency_pattern(self, frequency_str: str) -> FrequencyPattern:
"""Validate and convert frequency pattern string.
Args:
frequency_str: Frequency pattern as string
Returns:
FrequencyPattern enum value
"""
frequency_str = frequency_str.lower().strip()
frequency_mapping = {
"daily": FrequencyPattern.DAILY,
"weekly": FrequencyPattern.WEEKLY,
"monthly": FrequencyPattern.MONTHLY,
"seasonal": FrequencyPattern.SEASONAL,
"occasional": FrequencyPattern.OCCASIONAL,
"event_triggered": FrequencyPattern.EVENT_TRIGGERED,
"event-triggered": FrequencyPattern.EVENT_TRIGGERED,
}
return frequency_mapping.get(frequency_str, FrequencyPattern.OCCASIONAL)
def _validate_confidence_level(self, confidence_level) -> int:
"""Return confidence level as integer, handling both string and numeric inputs.
Args:
confidence_level: Confidence level (string or numeric)
Returns:
Confidence level as integer (0-100)
"""
# If it's already a number, clamp it to the documented 0-100 range
if isinstance(confidence_level, (int, float)):
return max(0, min(100, int(confidence_level)))
# If it's a string, convert common values to numbers
if isinstance(confidence_level, str):
confidence_str = confidence_level.lower().strip()
if confidence_str in ["high", "높음"]:
return 85
elif confidence_str in ["medium", "중간"]:
return 50
elif confidence_str in ["low", "낮음"]:
return 20
else:
# Try to parse as number, clamped to the documented 0-100 range
try:
return max(0, min(100, int(float(confidence_str))))
except (ValueError, OverflowError):
logger.warning(f"Unknown confidence level: {confidence_level}, defaulting to medium")
return 50
# Default fallback
return 50
def _determine_observation_dates(
self,
user_summaries: List[UserMemorySummary],
supporting_summary_ids: List[str]
) -> tuple[datetime, datetime]:
"""Determine first and last observation dates for a habit.
Args:
user_summaries: List of user summaries
supporting_summary_ids: IDs of summaries supporting the habit
Returns:
Tuple of (first_observed, last_observed) dates
"""
from datetime import timezone
# Find summaries that support this habit
supporting_summaries = [
summary for summary in user_summaries
if summary.summary_id in supporting_summary_ids
]
if not supporting_summaries:
# Use all summaries if no specific supporting summaries found
supporting_summaries = user_summaries
if not supporting_summaries:
current_time = datetime.now(timezone.utc).replace(tzinfo=None)
return current_time, current_time
# Get date range from supporting summaries - normalize to naive UTC datetimes
timestamps = []
for summary in supporting_summaries:
ts = summary.timestamp
# Convert timezone-aware values to UTC before dropping tzinfo,
# so timestamps with different offsets still compare correctly
if ts.tzinfo is not None:
ts = ts.astimezone(timezone.utc).replace(tzinfo=None)
timestamps.append(ts)
first_observed = min(timestamps)
last_observed = max(timestamps)
return first_observed, last_observed
def _is_valid_habit(self, habit: BehaviorHabit) -> bool:
"""Validate a behavioral habit.
Args:
habit: Behavioral habit to validate
Returns:
True if valid, False otherwise
"""
try:
# Check required fields
if not habit.habit_description or not habit.habit_description.strip():
return False
# Check time context
if not habit.time_context or not habit.time_context.strip():
return False
# Check specific examples (these stand in for supporting summaries)
if not habit.specific_examples:
return False
# Check observation dates
if habit.first_observed > habit.last_observed:
return False
return True
except Exception as e:
logger.error(f"Error validating habit: {e}")
return False
def _consolidate_habits(
self,
new_habits: List[BehaviorHabit],
existing_habits: List[BehaviorHabit],
similarity_threshold: float = 0.7
) -> List[BehaviorHabit]:
"""Consolidate new habits with existing ones.
Args:
new_habits: Newly extracted habits
existing_habits: Existing habits
similarity_threshold: Threshold for considering habits similar
Returns:
Consolidated list of habits
"""
consolidated = existing_habits.copy()
current_time = datetime.now()
for new_habit in new_habits:
# Find similar existing habit
similar_habit = self._find_similar_habit(
new_habit, existing_habits, similarity_threshold
)
if similar_habit:
# Update existing habit
updated_habit = self._merge_habits(similar_habit, new_habit, current_time)
# Replace in consolidated list
for i, habit in enumerate(consolidated):
if habit.habit_description == similar_habit.habit_description:
consolidated[i] = updated_habit
break
else:
# Add as new habit
consolidated.append(new_habit)
return consolidated
def _find_similar_habit(
self,
target_habit: BehaviorHabit,
existing_habits: List[BehaviorHabit],
threshold: float
) -> Optional[BehaviorHabit]:
"""Find similar habit in existing list.
Args:
target_habit: Habit to find similarity for
existing_habits: List of existing habits
threshold: Similarity threshold
Returns:
Similar habit if found, None otherwise
"""
target_desc = target_habit.habit_description.lower().strip()
for existing_habit in existing_habits:
existing_desc = existing_habit.habit_description.lower().strip()
# Check description similarity
desc_similarity = self._calculate_text_similarity(target_desc, existing_desc)
# Check frequency pattern match
frequency_match = (target_habit.frequency_pattern == existing_habit.frequency_pattern)
# Check time context similarity
time_similarity = self._calculate_text_similarity(
target_habit.time_context.lower(),
existing_habit.time_context.lower()
)
# Combined similarity score
combined_similarity = (desc_similarity * 0.6 + time_similarity * 0.4)
if frequency_match:
combined_similarity += 0.1 # Bonus for frequency match
if combined_similarity >= threshold:
return existing_habit
return None
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
"""Calculate simple text similarity based on common words.
Args:
text1: First text
text2: Second text
Returns:
Similarity score between 0.0 and 1.0
"""
if not text1 or not text2:
return 0.0
# Simple word-based similarity
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return 0.0
intersection = words1.intersection(words2)
union = words1.union(words2)
return len(intersection) / len(union) if union else 0.0
def _merge_habits(
self,
existing_habit: BehaviorHabit,
new_habit: BehaviorHabit,
current_time: datetime
) -> BehaviorHabit:
"""Merge two similar habits.
Args:
existing_habit: Existing habit
new_habit: New habit to merge
current_time: Current timestamp
Returns:
Merged behavioral habit
"""
# Combine specific examples (these stand in for supporting summaries),
# de-duplicating while preserving first-seen order
combined_examples = list(dict.fromkeys(
existing_habit.specific_examples + new_habit.specific_examples
))
# Update confidence level (take higher confidence)
new_confidence = max(existing_habit.confidence_level, new_habit.confidence_level)
# Update observation dates
first_observed = min(existing_habit.first_observed, new_habit.first_observed)
last_observed = max(existing_habit.last_observed, new_habit.last_observed)
# Determine if habit is current (observed within last 30 days)
is_current = (current_time - last_observed).days <= 30
# Combine time context
combined_time_context = existing_habit.time_context
if new_habit.time_context and new_habit.time_context not in combined_time_context:
combined_time_context += f"; {new_habit.time_context}"
return BehaviorHabit(
habit_description=existing_habit.habit_description, # Keep original description
frequency_pattern=existing_habit.frequency_pattern, # Keep original frequency
time_context=combined_time_context,
confidence_level=new_confidence,
specific_examples=combined_examples,
first_observed=first_observed,
last_observed=last_observed,
is_current=is_current
)
def _sort_habits_by_priority(self, habits: List[BehaviorHabit]) -> List[BehaviorHabit]:
"""Sort habits by confidence level and recency.
Args:
habits: List of habits to sort
Returns:
Sorted list of habits
"""
def priority_score(habit: BehaviorHabit) -> tuple:
# Confidence level score (0-100 scale)
confidence_score = habit.confidence_level
# Recency score (more recent = higher score)
days_since_last = (datetime.now() - habit.last_observed).days
recency_score = max(0, 365 - days_since_last) # Max 365 days
# Current habit bonus
current_bonus = 100 if habit.is_current else 0
return (confidence_score, recency_score + current_bonus, habit.last_observed)
return sorted(habits, key=priority_score, reverse=True)
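
The habit matching above combines a word-overlap (Jaccard) score on the description and on the time context with a small bonus when the frequency pattern matches. The sketch below restates that scoring on plain strings so the arithmetic is easy to check. The function names `jaccard` and `combined_similarity` are illustrative and not taken from the module; the weights 0.6, 0.4 and 0.1 are the ones shown in the code above.

```python
def jaccard(text1: str, text2: str) -> float:
    """Word-level Jaccard similarity: |A intersect B| / |A union B|."""
    words1 = set(text1.lower().split())
    words2 = set(text2.lower().split())
    if not words1 or not words2:
        return 0.0
    return len(words1 & words2) / len(words1 | words2)


def combined_similarity(
    desc_a: str,
    desc_b: str,
    time_a: str,
    time_b: str,
    same_frequency: bool,
) -> float:
    """Weighted description/time-context similarity plus a frequency bonus."""
    score = 0.6 * jaccard(desc_a, desc_b) + 0.4 * jaccard(time_a, time_b)
    if same_frequency:
        score += 0.1  # bonus for a matching frequency pattern
    return score


# "drinks coffee every morning" vs "drinks coffee in the morning":
# 3 shared words out of 6 distinct words, so the description score is 0.5.
example = combined_similarity(
    "drinks coffee every morning",
    "drinks coffee in the morning",
    "morning",
    "morning",
    same_frequency=True,
)
```

Because the bonus is added after the weighted sum, the combined score can reach 1.1, and a pair that lands exactly on the 0.7 threshold is subject to floating-point rounding, so a tolerance may be worth considering when comparing against it.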


@@ -1,277 +0,0 @@
"""Interest Analyzer for Implicit Memory System
This module implements LLM-based interest area analysis from user memory summaries.
It categorizes user interests into four areas: tech, lifestyle, music, and art,
providing percentage distribution that totals 100%.
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
InterestAreaDistribution,
InterestCategory,
UserMemorySummary,
)
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class InterestData(BaseModel):
"""Individual interest category analysis data."""
percentage: float = Field(ge=0.0, le=100.0)
evidence: List[str] = Field(default_factory=list)
trending_direction: Optional[str] = None
class InterestAnalysisResponse(BaseModel):
"""Response model for interest analysis."""
interest_distribution: Dict[str, InterestData] = Field(default_factory=dict)
class InterestAnalyzer:
"""Analyzes user memory summaries to extract interest area distribution."""
# Define the four interest categories we analyze
INTEREST_CATEGORIES = ["tech", "lifestyle", "music", "art"]
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
"""Initialize the interest analyzer.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
async def analyze_interests(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_distribution: Optional[InterestAreaDistribution] = None
) -> InterestAreaDistribution:
"""Analyze user summaries to extract interest area distribution.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_distribution: Optional existing distribution for trend tracking
Returns:
Interest area distribution across four categories
Raises:
LLMClientException: If LLM analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return self._create_empty_distribution(user_id)
try:
logger.info(f"Analyzing interests for user {user_id} with {len(user_summaries)} summaries")
# Use the LLM client wrapper for analysis
response = await self._llm_client.analyze_interests(
user_summaries=user_summaries,
user_id=user_id,
model_id=self.llm_model_id
)
# Create interest categories
interest_categories = {}
current_time = datetime.now()
# Extract interest_distribution from response dict
interest_distribution = response.get("interest_distribution", {})
# Extract and validate interest data
raw_interests = {}
for category_name in self.INTEREST_CATEGORIES:
interest_data_dict = interest_distribution.get(category_name)
if interest_data_dict:
raw_interests[category_name] = InterestData(
percentage=interest_data_dict.get("percentage", 0.0),
evidence=interest_data_dict.get("evidence", []),
trending_direction=interest_data_dict.get("trending_direction")
)
else:
# Create default if missing
logger.warning(f"Missing interest data for {category_name}, using default")
raw_interests[category_name] = InterestData(
percentage=0.0,
evidence=["No specific evidence found"],
trending_direction=None
)
# Normalize percentages to ensure they sum to 100%
normalized_interests = self._normalize_percentages(raw_interests)
# Create interest category objects
for category_name in self.INTEREST_CATEGORIES:
interest_data = normalized_interests[category_name]
# Calculate trending direction if we have existing data
trending_direction = self._calculate_trending_direction(
category_name=category_name,
current_percentage=interest_data.percentage,
existing_distribution=existing_distribution
) if existing_distribution else interest_data.trending_direction
interest_categories[category_name] = InterestCategory(
category_name=category_name,
percentage=interest_data.percentage,
evidence=interest_data.evidence if interest_data.evidence else ["No specific evidence found"],
trending_direction=trending_direction
)
# Create interest area distribution
distribution = InterestAreaDistribution(
user_id=user_id,
tech=interest_categories["tech"],
lifestyle=interest_categories["lifestyle"],
music=interest_categories["music"],
art=interest_categories["art"],
analysis_timestamp=current_time,
total_summaries_analyzed=len(user_summaries)
)
# Validate that percentages sum to 100%
total_percentage = distribution.total_percentage
if not (99.9 <= total_percentage <= 100.1):
logger.warning(f"Interest percentages sum to {total_percentage}, expected ~100%")
logger.info(f"Created interest distribution for user {user_id}")
return distribution
except LLMClientException:
raise
except Exception as e:
logger.error(f"Interest analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Interest analysis failed: {e}") from e
def _normalize_percentages(self, raw_interests: Dict[str, InterestData]) -> Dict[str, InterestData]:
"""Normalize percentages to ensure they sum to 100%.
Args:
raw_interests: Raw interest data with potentially unnormalized percentages
Returns:
Normalized interest data
"""
# Calculate current total
total = sum(interest.percentage for interest in raw_interests.values())
if total == 0:
# If all percentages are 0, distribute equally
equal_percentage = 100.0 / len(self.INTEREST_CATEGORIES)
normalized = {}
for category_name, interest_data in raw_interests.items():
normalized[category_name] = InterestData(
percentage=equal_percentage,
evidence=interest_data.evidence,
trending_direction=interest_data.trending_direction
)
return normalized
# Normalize to sum to 100%
normalization_factor = 100.0 / total
normalized = {}
for category_name, interest_data in raw_interests.items():
normalized_percentage = interest_data.percentage * normalization_factor
normalized[category_name] = InterestData(
percentage=round(normalized_percentage, 1),
evidence=interest_data.evidence,
trending_direction=interest_data.trending_direction
)
# Handle rounding errors by adjusting the largest category
current_total = sum(interest.percentage for interest in normalized.values())
if abs(current_total - 100.0) > 0.1:
# Find category with largest percentage and adjust
largest_category = max(normalized.keys(), key=lambda k: normalized[k].percentage)
adjustment = 100.0 - current_total
adjusted_percentage = normalized[largest_category].percentage + adjustment
normalized[largest_category] = InterestData(
percentage=round(max(0.0, adjusted_percentage), 1),
evidence=normalized[largest_category].evidence,
trending_direction=normalized[largest_category].trending_direction
)
return normalized
def _calculate_trending_direction(
self,
category_name: str,
current_percentage: float,
existing_distribution: InterestAreaDistribution,
threshold: float = 5.0
) -> Optional[str]:
"""Calculate trending direction for an interest category.
Args:
category_name: Name of the interest category
current_percentage: Current percentage for the category
existing_distribution: Previous distribution for comparison
threshold: Minimum percentage change to consider a trend
Returns:
Trending direction: "increasing", "decreasing", "stable", or None
"""
try:
# Get previous percentage
previous_category = getattr(existing_distribution, category_name, None)
if not previous_category:
return None
previous_percentage = previous_category.percentage
change = current_percentage - previous_percentage
if abs(change) < threshold:
return "stable"
elif change > 0:
return "increasing"
else:
return "decreasing"
except Exception as e:
logger.error(f"Error calculating trending direction for {category_name}: {e}")
return None
def _create_empty_distribution(self, user_id: str) -> InterestAreaDistribution:
"""Create an empty interest distribution when no data is available.
Args:
user_id: Target user ID
Returns:
Empty InterestAreaDistribution with equal percentages
"""
current_time = datetime.now()
equal_percentage = 25.0 # 100% / 4 categories
default_category = lambda name: InterestCategory(
category_name=name,
percentage=equal_percentage,
evidence=["Insufficient data for analysis"],
trending_direction=None
)
return InterestAreaDistribution(
user_id=user_id,
tech=default_category("tech"),
lifestyle=default_category("lifestyle"),
music=default_category("music"),
art=default_category("art"),
analysis_timestamp=current_time,
total_summaries_analyzed=0
)
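
The normalization step above scales raw percentages so they total 100 and then pushes any rounding drift onto the largest category. The sketch below shows one way to make the total exact by working in integer tenths of a percent. This is a variation on the module's approach and not a copy of it: the module only adjusts when the drift exceeds 0.1, while this version always assigns the remainder. The function name `normalize_to_100` is an assumption, and inputs are assumed to be non-negative.

```python
from typing import Dict


def normalize_to_100(raw: Dict[str, float]) -> Dict[str, float]:
    """Scale non-negative weights to one-decimal percentages summing to 100."""
    if not raw:
        return {}
    total = sum(raw.values())
    if total <= 0:
        # No signal at all: fall back to an equal split
        weights = {name: 1.0 for name in raw}
        total = float(len(raw))
    else:
        weights = dict(raw)
    # Work in integer tenths of a percent so the arithmetic is exact
    tenths = {name: round(value / total * 1000) for name, value in weights.items()}
    # Give the rounding remainder to the largest category
    largest = max(tenths, key=lambda name: tenths[name])
    tenths[largest] += 1000 - sum(tenths.values())
    return {name: count / 10 for name, count in tenths.items()}


shares = normalize_to_100({"tech": 1, "lifestyle": 1, "music": 1, "art": 0})
```

Doing the correction on integers avoids the case where four rounded floats sum to 99.9 or 100.1 and only trip the warning later.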


@@ -1,302 +0,0 @@
"""Preference Analyzer for Implicit Memory System
This module implements LLM-based preference extraction from user memory summaries.
It identifies implicit preferences, consolidates similar preferences, and calculates
confidence scores based on evidence strength.
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
PreferenceTag,
UserMemorySummary,
)
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class PreferenceAnalysisResponse(BaseModel):
"""Response model for preference analysis."""
preferences: List[Dict[str, Any]] = Field(default_factory=list)
class PreferenceAnalyzer:
"""Analyzes user memory summaries to extract implicit preferences."""
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
"""Initialize the preference analyzer.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
async def analyze_preferences(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_preferences: Optional[List[PreferenceTag]] = None
) -> List[PreferenceTag]:
"""Analyze user summaries to extract preferences.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_preferences: Optional existing preferences for consolidation
Returns:
List of extracted preference tags
Raises:
LLMClientException: If LLM analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return []
try:
logger.info(f"Analyzing preferences for user {user_id} with {len(user_summaries)} summaries")
# Use the LLM client wrapper for analysis
response = await self._llm_client.analyze_preferences(
user_summaries=user_summaries,
user_id=user_id,
model_id=self.llm_model_id
)
# Convert to PreferenceTag objects
preference_tags = []
current_time = datetime.now()
for pref_data in response.get("preferences", []):
try:
# Extract conversation references from summaries
conversation_refs = [s.summary_id for s in user_summaries]
preference_tag = PreferenceTag(
tag_name=pref_data.get("tag_name", ""),
confidence_score=float(pref_data.get("confidence_score", 0.0)),
supporting_evidence=pref_data.get("supporting_evidence", []),
context_details=pref_data.get("context_details", ""),
category=pref_data.get("category"),
conversation_references=conversation_refs,
created_at=current_time,
updated_at=current_time
)
# Validate preference tag
if self._is_valid_preference(preference_tag):
preference_tags.append(preference_tag)
else:
logger.warning(f"Invalid preference tag skipped: {preference_tag.tag_name}")
except Exception as e:
logger.error(f"Error creating preference tag: {e}")
continue
# Consolidate with existing preferences if provided
if existing_preferences:
preference_tags = self._consolidate_preferences(
new_preferences=preference_tags,
existing_preferences=existing_preferences
)
logger.info(f"Extracted {len(preference_tags)} preferences for user {user_id}")
return preference_tags
except LLMClientException:
raise
except Exception as e:
logger.error(f"Preference analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Preference analysis failed: {e}") from e
def _is_valid_preference(self, preference: PreferenceTag) -> bool:
"""Validate a preference tag.
Args:
preference: Preference tag to validate
Returns:
True if valid, False otherwise
"""
try:
# Check required fields
if not preference.tag_name or not preference.tag_name.strip():
return False
# Check confidence score range
if not (0.0 <= preference.confidence_score <= 1.0):
return False
# Check supporting evidence
if not preference.supporting_evidence or len(preference.supporting_evidence) == 0:
return False
# Check context details
if not preference.context_details or not preference.context_details.strip():
return False
return True
except Exception as e:
logger.error(f"Error validating preference: {e}")
return False
def _consolidate_preferences(
self,
new_preferences: List[PreferenceTag],
existing_preferences: List[PreferenceTag],
similarity_threshold: float = 0.8
) -> List[PreferenceTag]:
"""Consolidate new preferences with existing ones.
Args:
new_preferences: Newly extracted preferences
existing_preferences: Existing preferences
similarity_threshold: Threshold for considering preferences similar
Returns:
Consolidated list of preferences
"""
consolidated = existing_preferences.copy()
current_time = datetime.now()
for new_pref in new_preferences:
# Find similar existing preference
similar_pref = self._find_similar_preference(
new_pref, existing_preferences, similarity_threshold
)
if similar_pref:
# Update existing preference
updated_pref = self._merge_preferences(similar_pref, new_pref, current_time)
# Replace in consolidated list
for i, pref in enumerate(consolidated):
if pref.tag_name == similar_pref.tag_name:
consolidated[i] = updated_pref
break
else:
# Add as new preference
consolidated.append(new_pref)
return consolidated
def _find_similar_preference(
self,
target_preference: PreferenceTag,
existing_preferences: List[PreferenceTag],
threshold: float
) -> Optional[PreferenceTag]:
"""Find similar preference in existing list.
Args:
target_preference: Preference to find similarity for
existing_preferences: List of existing preferences
threshold: Similarity threshold
Returns:
Similar preference if found, None otherwise
"""
target_name = target_preference.tag_name.lower().strip()
for existing_pref in existing_preferences:
existing_name = existing_pref.tag_name.lower().strip()
# Simple similarity check based on common words
similarity = self._calculate_text_similarity(target_name, existing_name)
if similarity >= threshold:
return existing_pref
return None
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
"""Calculate simple text similarity based on common words.
Args:
text1: First text
text2: Second text
Returns:
Similarity score between 0.0 and 1.0
"""
if not text1 or not text2:
return 0.0
# Simple word-based similarity
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return 0.0
intersection = words1.intersection(words2)
union = words1.union(words2)
return len(intersection) / len(union) if union else 0.0
def _merge_preferences(
self,
existing_pref: PreferenceTag,
new_pref: PreferenceTag,
current_time: datetime
) -> PreferenceTag:
"""Merge two similar preferences.
Args:
existing_pref: Existing preference
new_pref: New preference to merge
current_time: Current timestamp
Returns:
Merged preference tag
"""
# Combine supporting evidence
combined_evidence = list(set(
existing_pref.supporting_evidence + new_pref.supporting_evidence
))
# Combine conversation references
combined_refs = list(set(
existing_pref.conversation_references + new_pref.conversation_references
))
# Calculate new confidence score (weighted average)
evidence_weight = len(new_pref.supporting_evidence)
total_weight = len(existing_pref.supporting_evidence) + evidence_weight
if total_weight > 0:
new_confidence = (
(existing_pref.confidence_score * len(existing_pref.supporting_evidence) +
new_pref.confidence_score * evidence_weight) / total_weight
)
else:
new_confidence = max(existing_pref.confidence_score, new_pref.confidence_score)
# Ensure confidence doesn't exceed 1.0
new_confidence = min(new_confidence, 1.0)
# Combine context details
combined_context = existing_pref.context_details
if new_pref.context_details and new_pref.context_details not in combined_context:
combined_context += f"; {new_pref.context_details}"
return PreferenceTag(
tag_name=existing_pref.tag_name, # Keep original name
confidence_score=new_confidence,
supporting_evidence=combined_evidence,
context_details=combined_context,
category=existing_pref.category or new_pref.category,
conversation_references=combined_refs,
created_at=existing_pref.created_at,
updated_at=current_time
)
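
The merge above recomputes confidence as an average weighted by how many pieces of evidence each side carries, falling back to the higher score when neither side has evidence. A minimal sketch of just that calculation is shown here. The function name `merge_confidence` and its flat numeric parameters are illustrative assumptions; the module itself works on `PreferenceTag` objects.

```python
def merge_confidence(
    existing_score: float,
    existing_evidence: int,
    new_score: float,
    new_evidence: int,
) -> float:
    """Evidence-weighted average of two confidence scores, capped at 1.0."""
    total_weight = existing_evidence + new_evidence
    if total_weight == 0:
        # No evidence on either side: keep the more confident score
        merged = max(existing_score, new_score)
    else:
        merged = (
            existing_score * existing_evidence + new_score * new_evidence
        ) / total_weight
    return min(merged, 1.0)


# Three pieces of evidence at 0.9 and one at 0.5: (2.7 + 0.5) / 4, about 0.8
merged_score = merge_confidence(0.9, 3, 0.5, 1)
```

One consequence of this weighting is that a single low-confidence observation cannot pull a well-supported preference down very far.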


@@ -1,97 +0,0 @@
"""
Memory Data Source

Handles retrieval and processing of memory data from Neo4j using direct Cypher queries.
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional

from app.repositories.neo4j.memory_summary_repository import MemorySummaryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.implicit_memory_schema import TimeRange, UserMemorySummary
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)


class MemoryDataSource:
    """Retrieves processed memory data from Neo4j using direct Cypher queries."""

    def __init__(
        self,
        db: Session,
        neo4j_connector: Optional[Neo4jConnector] = None
    ):
        self.db = db
        self.neo4j_connector = neo4j_connector or Neo4jConnector()
        self.memory_summary_repo = MemorySummaryRepository(self.neo4j_connector)

    def _parse_timestamp(self, timestamp: Any) -> datetime:
        """Parse timestamp from various formats."""
        if isinstance(timestamp, str):
            return datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
        elif timestamp is None:
            return datetime.now()
        return timestamp

    def _dict_to_user_summary(self, summary_dict: Dict, user_id: str) -> Optional[UserMemorySummary]:
        """Convert a Neo4j dict directly to UserMemorySummary."""
        try:
            content = summary_dict.get("content", summary_dict.get("summary", ""))
            if not content or not content.strip():
                return None
            return UserMemorySummary(
                summary_id=summary_dict.get("id", summary_dict.get("uuid", "")),
                user_id=user_id,
                user_content=content,
                timestamp=self._parse_timestamp(summary_dict.get("created_at")),
                confidence_score=1.0,
                summary_type="memory_summary"
            )
        except Exception as e:
            logger.warning(f"Failed to parse summary {summary_dict.get('id', 'unknown')}: {e}")
            return None

    async def get_user_summaries(
        self,
        user_id: str,
        time_range: Optional[TimeRange] = None,
        limit: int = 1000
    ) -> List[UserMemorySummary]:
        """Retrieve user memory summaries from Neo4j.

        Args:
            user_id: Target user ID
            time_range: Optional time range filter
            limit: Maximum number of summaries

        Returns:
            List of user memory summaries
        """
        try:
            start_date = time_range.start_date if time_range else None
            end_date = time_range.end_date if time_range else None
            summary_dicts = await self.memory_summary_repo.find_by_group_id(
                group_id=user_id,
                limit=limit,
                start_date=start_date,
                end_date=end_date
            )
            summaries = []
            for summary_dict in summary_dicts:
                summary = self._dict_to_user_summary(summary_dict, user_id)
                if summary:
                    summaries.append(summary)
            logger.info(f"Retrieved {len(summaries)} summaries for user {user_id}")
            return summaries
        except Exception as e:
            logger.error(f"Failed to retrieve summaries for user {user_id}: {e}")
            raise
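
One detail of `_parse_timestamp` above may be worth a sketch: ISO strings ending in `Z` come back timezone-aware, while the `datetime.now()` fallback is naive, and Python refuses to subtract or compare the two kinds. The standalone helper below is a hypothetical variant, not part of the removed module; the name `parse_timestamp_utc` and the decision to treat naive values as UTC are assumptions made for illustration.

```python
from datetime import datetime, timezone
from typing import Any


def parse_timestamp_utc(timestamp: Any) -> datetime:
    """Return a timezone-aware UTC datetime for str, datetime, or None input."""
    if timestamp is None:
        return datetime.now(timezone.utc)
    if isinstance(timestamp, str):
        # fromisoformat() before Python 3.11 rejects a trailing "Z", hence the replace.
        parsed = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
    else:
        # Assumption: any other input is already a datetime instance.
        parsed = timestamp
    if parsed.tzinfo is None:
        # Assumption: naive values are interpreted as UTC.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)
```

With every value normalised this way, arithmetic such as `current_time - last_observed` no longer depends on which branch produced the timestamp.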


@@ -1,226 +0,0 @@
"""Habit Detector for Implicit Memory System

This module implements the HabitDetector class that specializes in identifying
and ranking behavioral habits from user memory summaries. It provides advanced
habit analysis with confidence scoring, recency weighting, and current vs past
habit distinction.
"""
import logging
from datetime import datetime, timedelta
from typing import List, Optional

from app.core.memory.analytics.implicit_memory.analyzers.habit_analyzer import (
    HabitAnalyzer,
)
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
    BehaviorHabit,
    FrequencyPattern,
    UserMemorySummary,
)
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)


class HabitDetector:
    """Detects and ranks behavioral habits from user memory summaries."""

    def __init__(
        self,
        db: Session,
        llm_model_id: Optional[str] = None
    ):
        """Initialize the habit detector.

        Args:
            db: Database session
            llm_model_id: Optional LLM model ID to use for analysis
        """
        self.db = db
        self.llm_model_id = llm_model_id
        self.habit_analyzer = HabitAnalyzer(db, llm_model_id)

    async def detect_habits(
        self,
        user_id: str,
        user_summaries: List[UserMemorySummary],
        existing_habits: Optional[List[BehaviorHabit]] = None
    ) -> List[BehaviorHabit]:
        """Detect behavioral habits from user summaries.

        Args:
            user_id: Target user ID
            user_summaries: List of user-specific memory summaries
            existing_habits: Optional existing habits for consolidation

        Returns:
            List of detected and ranked behavioral habits

        Raises:
            LLMClientException: If habit analysis fails
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for user {user_id}")
            return existing_habits or []
        logger.info(f"Detecting habits for user {user_id} with {len(user_summaries)} summaries")
        try:
            # Use the habit analyzer to extract habits
            detected_habits = await self.habit_analyzer.analyze_habits(
                user_id=user_id,
                user_summaries=user_summaries,
                existing_habits=existing_habits
            )
            # Apply advanced ranking and filtering
            ranked_habits = self.rank_habits_by_confidence_and_recency(detected_habits)
            # Distinguish current vs past habits
            categorized_habits = self.distinguish_current_vs_past_habits(ranked_habits)
            logger.info(f"Detected {len(categorized_habits)} habits for user {user_id}")
            return categorized_habits
        except LLMClientException:
            logger.error(f"Habit detection failed for user {user_id}")
            raise
        except Exception as e:
            logger.error(f"Habit detection failed for user {user_id}: {e}")
            raise LLMClientException(f"Habit detection failed: {e}") from e

    def rank_habits_by_confidence_and_recency(
        self,
        habits: List[BehaviorHabit],
        confidence_weight: float = 0.6,
        recency_weight: float = 0.4
    ) -> List[BehaviorHabit]:
        """Rank habits by confidence level and recency.

        Args:
            habits: List of habits to rank
            confidence_weight: Weight for confidence score (0.0-1.0)
            recency_weight: Weight for recency score (0.0-1.0)

        Returns:
            List of habits ranked by combined score
        """
        if not habits:
            return []
        logger.info(f"Ranking {len(habits)} habits by confidence and recency")

        def calculate_ranking_score(habit: BehaviorHabit) -> float:
            """Calculate combined ranking score for a habit."""
            # Confidence score (0.0-1.0) - convert from 0-100 scale
            confidence_score = habit.confidence_level / 100.0
            # Recency score (0.0-1.0)
            current_time = datetime.now()
            days_since_last = (current_time - habit.last_observed).days
            # Stepwise decay for recency (habits lose relevance over time)
            if days_since_last <= 7:
                recency_score = 1.0  # Very recent
            elif days_since_last <= 30:
                recency_score = 0.8  # Recent
            elif days_since_last <= 90:
                recency_score = 0.5  # Somewhat recent
            elif days_since_last <= 180:
                recency_score = 0.3  # Old
            else:
                recency_score = 0.1  # Very old
            # Frequency pattern bonus
            frequency_bonuses = {
                FrequencyPattern.DAILY: 0.2,
                FrequencyPattern.WEEKLY: 0.15,
                FrequencyPattern.MONTHLY: 0.1,
                FrequencyPattern.SEASONAL: 0.05,
                FrequencyPattern.OCCASIONAL: 0.0,
                FrequencyPattern.EVENT_TRIGGERED: 0.05
            }
            frequency_bonus = frequency_bonuses.get(habit.frequency_pattern, 0.0)
            # Evidence quality bonus
            evidence_bonus = min(len(habit.specific_examples) / 10.0, 0.1)  # Max 0.1 bonus
            # Current habit bonus
            current_bonus = 0.1 if habit.is_current else 0.0
            # Calculate final score
            base_score = (confidence_score * confidence_weight +
                          recency_score * recency_weight)
            final_score = base_score + frequency_bonus + evidence_bonus + current_bonus
            return min(final_score, 1.0)  # Cap at 1.0

        # Sort habits by ranking score (descending)
        ranked_habits = sorted(habits, key=calculate_ranking_score, reverse=True)
        logger.info(f"Ranked habits with scores: {[calculate_ranking_score(h) for h in ranked_habits[:5]]}")
        return ranked_habits

    def distinguish_current_vs_past_habits(
        self,
        habits: List[BehaviorHabit],
        current_threshold_days: int = 30
    ) -> List[BehaviorHabit]:
        """Distinguish between current and past habits based on recency.

        Args:
            habits: List of habits to categorize
            current_threshold_days: Days threshold for considering a habit current

        Returns:
            List of habits with updated is_current status
        """
        if not habits:
            return []
        current_time = datetime.now()
        cutoff_date = current_time - timedelta(days=current_threshold_days)
        current_habits = []
        past_habits = []
        for habit in habits:
            # Update is_current status based on last observation
            if habit.last_observed >= cutoff_date:
                # Create updated habit with is_current = True
                updated_habit = BehaviorHabit(
                    habit_description=habit.habit_description,
                    frequency_pattern=habit.frequency_pattern,
                    time_context=habit.time_context,
                    confidence_level=habit.confidence_level,
                    specific_examples=habit.specific_examples,
                    first_observed=habit.first_observed,
                    last_observed=habit.last_observed,
                    is_current=True
                )
                current_habits.append(updated_habit)
            else:
                # Create updated habit with is_current = False
                updated_habit = BehaviorHabit(
                    habit_description=habit.habit_description,
                    frequency_pattern=habit.frequency_pattern,
                    time_context=habit.time_context,
                    confidence_level=habit.confidence_level,
                    specific_examples=habit.specific_examples,
                    first_observed=habit.first_observed,
                    last_observed=habit.last_observed,
                    is_current=False
                )
                past_habits.append(updated_habit)
        # Return current habits first, then past habits
        categorized_habits = current_habits + past_habits
        logger.info(f"Categorized habits: {len(current_habits)} current, {len(past_habits)} past")
        return categorized_habits
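
The ranking arithmetic in `rank_habits_by_confidence_and_recency` can be reproduced without the project's schema classes. The sketch below is a minimal standalone version: `HabitStub` is a hypothetical stand-in for `BehaviorHabit`, and the string keys in `FREQUENCY_BONUS` assume the enum values match the labels used in the habit prompt template. The weights, thresholds, and bonuses are the ones shown above.

```python
from dataclasses import dataclass

# (max days since last observation, recency score); anything older scores 0.1.
RECENCY_STEPS = [(7, 1.0), (30, 0.8), (90, 0.5), (180, 0.3)]
FREQUENCY_BONUS = {
    "daily": 0.2, "weekly": 0.15, "monthly": 0.1,
    "seasonal": 0.05, "occasional": 0.0, "event_triggered": 0.05,
}


@dataclass
class HabitStub:
    """Hypothetical stand-in for BehaviorHabit with only the fields scoring needs."""
    confidence_level: float  # 0-100 scale
    days_since_last: int
    frequency_pattern: str
    example_count: int
    is_current: bool


def recency_score(days_since_last: int) -> float:
    """Stepwise recency score: the first threshold that fits wins."""
    for max_days, score in RECENCY_STEPS:
        if days_since_last <= max_days:
            return score
    return 0.1


def ranking_score(h: HabitStub, confidence_weight: float = 0.6,
                  recency_weight: float = 0.4) -> float:
    """Weighted base score plus bonuses, capped at 1.0."""
    base = ((h.confidence_level / 100.0) * confidence_weight
            + recency_score(h.days_since_last) * recency_weight)
    bonus = (FREQUENCY_BONUS.get(h.frequency_pattern, 0.0)
             + min(h.example_count / 10.0, 0.1)
             + (0.1 if h.is_current else 0.0))
    return min(base + bonus, 1.0)
```

Because the bonuses can add up to 0.4 on top of a base score that already reaches 1.0, strong habits frequently hit the cap and tie. Python's `sorted` is stable, so tied habits keep their input order.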


@@ -1,321 +0,0 @@
"""LLM Client Wrapper for Implicit Memory Analysis

This module provides a specialized LLM client wrapper that integrates with the
MemoryClientFactory to perform implicit memory analysis tasks including preference
extraction, personality dimension analysis, interest categorization, and habit detection.
"""
import logging
from typing import Any, Dict, List, Optional

from app.core.memory.analytics.implicit_memory.prompts import (
    get_dimension_analysis_prompt,
    get_habit_analysis_prompt,
    get_interest_analysis_prompt,
    get_preference_analysis_prompt,
)
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.schemas.implicit_memory_schema import UserMemorySummary
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)


# Response Models for LLM Analysis
class PreferenceAnalysisResponse(BaseModel):
    """Response model for preference analysis."""
    preferences: List[Dict[str, Any]] = Field(default_factory=list)


class DimensionAnalysisResponse(BaseModel):
    """Response model for dimension analysis."""
    dimensions: Dict[str, Dict[str, Any]] = Field(default_factory=dict)


class InterestAnalysisResponse(BaseModel):
    """Response model for interest analysis."""
    interest_distribution: Dict[str, Dict[str, Any]] = Field(default_factory=dict)


class HabitAnalysisResponse(BaseModel):
    """Response model for habit analysis."""
    habits: List[Dict[str, Any]] = Field(default_factory=list)


class ImplicitMemoryLLMClient:
    """LLM client wrapper for implicit memory analysis.

    This class provides a high-level interface for performing LLM-based analysis
    of user memory summaries to extract preferences, personality dimensions,
    interests, and behavioral habits.
    """

    def __init__(self, db: Session, default_model_id: Optional[str] = None):
        """Initialize the LLM client wrapper.

        Args:
            db: Database session for accessing model configurations
            default_model_id: Default LLM model ID to use if none specified
        """
        self.db = db
        self.default_model_id = default_model_id
        self._client_factory = MemoryClientFactory(db)
        logger.info("ImplicitMemoryLLMClient initialized")

    def _get_llm_client(self, model_id: Optional[str] = None):
        """Get LLM client instance.

        Args:
            model_id: LLM model ID to use, defaults to default_model_id

        Returns:
            LLM client instance

        Raises:
            ValueError: If no model ID is provided and no default is set
            LLMClientException: If client creation fails
        """
        effective_model_id = model_id or self.default_model_id
        if not effective_model_id:
            raise ValueError("No LLM model ID provided and no default model ID set")
        try:
            client = self._client_factory.get_llm_client(effective_model_id)
            logger.debug(f"Created LLM client for model: {effective_model_id}")
            return client
        except Exception as e:
            logger.error(f"Failed to create LLM client for model {effective_model_id}: {e}")
            raise LLMClientException(f"Failed to create LLM client: {e}") from e

    def _prepare_summaries_for_analysis(self, user_summaries: List[UserMemorySummary]) -> List[Dict[str, Any]]:
        """Prepare user memory summaries for LLM analysis.

        Args:
            user_summaries: List of user memory summaries

        Returns:
            List of formatted summary dictionaries
        """
        formatted_summaries = []
        for summary in user_summaries:
            formatted_summary = {
                'summary_id': summary.summary_id,
                'user_content': summary.user_content,
                'timestamp': summary.timestamp.isoformat(),
                'summary_type': summary.summary_type,
                'confidence_score': summary.confidence_score
            }
            formatted_summaries.append(formatted_summary)
        logger.debug(f"Prepared {len(formatted_summaries)} summaries for analysis")
        return formatted_summaries

    async def analyze_preferences(
        self,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user preferences from memory summaries.

        Args:
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            Dictionary containing extracted preferences

        Raises:
            LLMClientException: If LLM analysis fails
            ValueError: If input validation fails
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for preference analysis of user {user_id}")
            return {"preferences": []}
        if not user_id:
            raise ValueError("User ID is required for preference analysis")
        try:
            # Prepare summaries and get prompt
            formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
            prompt = get_preference_analysis_prompt(formatted_summaries, user_id)
            # Get LLM client and perform analysis
            llm_client = self._get_llm_client(model_id)
            messages = [{"role": "user", "content": prompt}]
            # Use structured output for reliable parsing
            response = await llm_client.response_structured(
                messages=messages,
                response_model=PreferenceAnalysisResponse
            )
            result = response.model_dump()
            logger.info(f"Analyzed preferences for user {user_id}: found {len(result.get('preferences', []))} preferences")
            return result
        except Exception as e:
            logger.error(f"Preference analysis failed for user {user_id}: {e}")
            raise LLMClientException(f"Preference analysis failed: {e}") from e

    async def analyze_dimensions(
        self,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user personality dimensions from memory summaries.

        Args:
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            Dictionary containing dimension scores and analysis

        Raises:
            LLMClientException: If LLM analysis fails
            ValueError: If input validation fails
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for dimension analysis of user {user_id}")
            return {"dimensions": {}}
        if not user_id:
            raise ValueError("User ID is required for dimension analysis")
        try:
            # Prepare summaries and get prompt
            formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
            prompt = get_dimension_analysis_prompt(formatted_summaries, user_id)
            # Get LLM client and perform analysis
            llm_client = self._get_llm_client(model_id)
            messages = [{"role": "user", "content": prompt}]
            # Use structured output for reliable parsing
            response = await llm_client.response_structured(
                messages=messages,
                response_model=DimensionAnalysisResponse
            )
            result = response.model_dump()
            dimensions = result.get('dimensions', {})
            logger.info(f"Analyzed dimensions for user {user_id}: {list(dimensions.keys())}")
            return result
        except Exception as e:
            logger.error(f"Dimension analysis failed for user {user_id}: {e}")
            raise LLMClientException(f"Dimension analysis failed: {e}") from e

    async def analyze_interests(
        self,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user interest distribution from memory summaries.

        Args:
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            Dictionary containing interest area distribution

        Raises:
            LLMClientException: If LLM analysis fails
            ValueError: If input validation fails
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for interest analysis of user {user_id}")
            return {"interest_distribution": {}}
        if not user_id:
            raise ValueError("User ID is required for interest analysis")
        try:
            # Prepare summaries and get prompt
            formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
            prompt = get_interest_analysis_prompt(formatted_summaries, user_id)
            # Get LLM client and perform analysis
            llm_client = self._get_llm_client(model_id)
            messages = [{"role": "user", "content": prompt}]
            # Use structured output for reliable parsing
            response = await llm_client.response_structured(
                messages=messages,
                response_model=InterestAnalysisResponse
            )
            result = response.model_dump()
            interest_dist = result.get('interest_distribution', {})
            logger.info(f"Analyzed interests for user {user_id}: {list(interest_dist.keys())}")
            return result
        except Exception as e:
            logger.error(f"Interest analysis failed for user {user_id}: {e}")
            raise LLMClientException(f"Interest analysis failed: {e}") from e

    async def analyze_habits(
        self,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user behavioral habits from memory summaries.

        Args:
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            Dictionary containing identified behavioral habits

        Raises:
            LLMClientException: If LLM analysis fails
            ValueError: If input validation fails
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for habit analysis of user {user_id}")
            return {"habits": []}
        if not user_id:
            raise ValueError("User ID is required for habit analysis")
        try:
            # Prepare summaries and get prompt
            formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
            prompt = get_habit_analysis_prompt(formatted_summaries, user_id)
            # Get LLM client and perform analysis
            llm_client = self._get_llm_client(model_id)
            messages = [{"role": "user", "content": prompt}]
            # Use structured output for reliable parsing
            response = await llm_client.response_structured(
                messages=messages,
                response_model=HabitAnalysisResponse
            )
            result = response.model_dump()
            logger.info(f"Analyzed habits for user {user_id}: found {len(result.get('habits', []))} habits")
            return result
        except Exception as e:
            logger.error(f"Habit analysis failed for user {user_id}: {e}")
            raise LLMClientException(f"Habit analysis failed: {e}") from e
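
The four `analyze_*` methods above share one control flow and differ only in the prompt function, the response model, and the empty result. As a rough illustration of that shared flow, the sketch below factors it into a single coroutine. Everything in it is hypothetical: `run_analysis`, `AnalysisError` (standing in for `LLMClientException`), and `fake_llm` (standing in for `llm_client.response_structured(...).model_dump()`) are illustrative names, not part of the removed module.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List


class AnalysisError(Exception):
    """Hypothetical stand-in for LLMClientException."""


async def run_analysis(
    kind: str,
    summaries: List[Dict[str, Any]],
    user_id: str,
    empty_result: Dict[str, Any],
    build_prompt: Callable[[List[Dict[str, Any]], str], str],
    call_llm: Callable[[List[Dict[str, str]]], Awaitable[Dict[str, Any]]],
) -> Dict[str, Any]:
    """Shared flow: validate input, render the prompt, call the model, wrap failures."""
    if not summaries:
        return empty_result
    if not user_id:
        # Raised outside the try block, so it is not wrapped in AnalysisError.
        raise ValueError(f"User ID is required for {kind} analysis")
    try:
        prompt = build_prompt(summaries, user_id)
        return await call_llm([{"role": "user", "content": prompt}])
    except Exception as e:
        raise AnalysisError(f"{kind.capitalize()} analysis failed: {e}") from e


async def fake_llm(messages: List[Dict[str, str]]) -> Dict[str, Any]:
    """Echo the prompt back as a single habit; replaces a real model call."""
    return {"habits": [{"habit_description": messages[0]["content"]}]}


def demo() -> Dict[str, Any]:
    return asyncio.run(
        run_analysis(
            "habit",
            [{"user_content": "drinks coffee every morning"}],
            "user-1",
            {"habits": []},
            lambda s, uid: f"{uid}: {s[0]['user_content']}",
            fake_llm,
        )
    )
```

As in the original methods, the empty-summaries check runs before the user ID check, so an empty input returns the empty result even when `user_id` is missing.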


@@ -1,69 +0,0 @@
"""LLM Prompt Templates for Implicit Memory Analysis

This module contains prompt rendering functions for analyzing user memory summaries
to extract preferences, personality dimensions, interests, and behavioral habits.
"""
import os
from typing import Any, Dict, List

from jinja2 import Environment, FileSystemLoader

# Setup Jinja2 environment
current_dir = os.path.dirname(os.path.abspath(__file__))
prompt_dir = os.path.join(current_dir, "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))


def _render_template(template_name: str, **kwargs) -> str:
    """Helper function to render Jinja2 templates."""
    template = prompt_env.get_template(template_name)
    return template.render(**kwargs)


def get_preference_analysis_prompt(
    memory_summaries: List[Dict[str, Any]],
    user_id: str
) -> str:
    """Get formatted preference analysis prompt using Jinja2 template."""
    return _render_template(
        "preference_analysis.jinja2",
        memory_summaries=memory_summaries,
        user_id=user_id
    )


def get_dimension_analysis_prompt(
    memory_summaries: List[Dict[str, Any]],
    user_id: str
) -> str:
    """Get formatted dimension analysis prompt using Jinja2 template."""
    return _render_template(
        "dimension_analysis.jinja2",
        memory_summaries=memory_summaries,
        user_id=user_id
    )


def get_interest_analysis_prompt(
    memory_summaries: List[Dict[str, Any]],
    user_id: str
) -> str:
    """Get formatted interest analysis prompt using Jinja2 template."""
    return _render_template(
        "interest_analysis.jinja2",
        memory_summaries=memory_summaries,
        user_id=user_id
    )


def get_habit_analysis_prompt(
    memory_summaries: List[Dict[str, Any]],
    user_id: str
) -> str:
    """Get formatted habit analysis prompt using Jinja2 template."""
    return _render_template(
        "habit_analysis.jinja2",
        memory_summaries=memory_summaries,
        user_id=user_id
    )


@@ -1,41 +0,0 @@
You are an expert personality analyst. Analyze memory summaries to assess the user's personality across four dimensions.
## Memory Summaries
{% for summary in memory_summaries %}
Summary {{ loop.index }}:
{{ summary.content or summary.user_content or '' }}
---
{% endfor %}
## Target User ID
{{ user_id }}
## Dimensions to Analyze
1. **Creativity** (0-100%): Creative thinking, artistic interests, innovative ideas
2. **Aesthetic** (0-100%): Design preferences, visual interests, artistic appreciation
3. **Technology** (0-100%): Technical discussions, tool usage, programming interests
4. **Literature** (0-100%): Reading habits, writing style, literary references
## Instructions
1. Analyze the user's content for each dimension
2. Calculate percentage scores (0-100%)
## Output Format
{
"dimensions": {
"creativity": {"percentage": 0-100},
"aesthetic": {"percentage": 0-100},
"technology": {"percentage": 0-100},
"literature": {"percentage": 0-100}
}
}
## Example
{
"dimensions": {
"creativity": {"percentage": 75},
"aesthetic": {"percentage": 45},
"technology": {"percentage": 60},
"literature": {"percentage": 30}
}
}


@@ -1,70 +0,0 @@
You are an expert at identifying behavioral patterns and habits from memory summaries.
## Memory Summaries
{% for summary in memory_summaries %}
Summary {{ loop.index }}:
{{ summary.content or summary.user_content or '' }}
---
{% endfor %}
## Target User ID
{{ user_id }}
## Instructions
1. Identify recurring behavioral patterns mentioned by the SPECIFIED USER
2. Focus on specific, concrete habits with temporal patterns
3. For each habit, provide:
- habit_description: Clear, specific description
- frequency_pattern: "daily", "weekly", "monthly", "seasonal", "occasional", "event_triggered"
- time_context: When it typically happens
- confidence_level: "high", "medium", "low"
- supporting_summaries: References to evidence
- specific_examples: Concrete examples from summaries
- is_current: true if current habit, false if past habit
4. Only include habits with medium or high confidence
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
## Output Format
{
"habits": [
{
"habit_description": "string",
"frequency_pattern": "daily|weekly|monthly|seasonal|occasional|event_triggered",
"time_context": "string",
"confidence_level": "high|medium|low",
"supporting_summaries": ["id1", "id2"],
"specific_examples": ["example1", "example2"],
"is_current": true|false
}
]
}
## Example (English input → English output)
{
"habits": [
{
"habit_description": "drinks coffee every morning",
"frequency_pattern": "daily",
"time_context": "morning routine",
"confidence_level": "high",
"supporting_summaries": ["s1", "s2"],
"specific_examples": ["needs coffee to start the day"],
"is_current": true
}
]
}
## Example (Chinese input → Chinese output)
{
"habits": [
{
"habit_description": "每天早上喝咖啡",
"frequency_pattern": "daily",
"time_context": "早晨日常",
"confidence_level": "high",
"supporting_summaries": ["s1", "s2"],
"specific_examples": ["需要咖啡来开始一天"],
"is_current": true
}
]
}
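
The habit template asks the model to return only medium- or high-confidence habits and to pick `frequency_pattern` from a fixed list, but a model may not always comply. The sketch below shows one way a caller could enforce both rules on the parsed output. The function name `filter_habits` is hypothetical, and nothing in the source indicates that such a filter existed.

```python
from typing import Any, Dict, List

# Labels taken from the template's instructions.
ALLOWED_FREQUENCIES = {
    "daily", "weekly", "monthly", "seasonal", "occasional", "event_triggered",
}
ACCEPTED_CONFIDENCE = {"high", "medium"}


def filter_habits(payload: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Keep habits whose confidence and frequency labels follow the template."""
    kept = []
    for habit in payload.get("habits", []):
        if habit.get("confidence_level") not in ACCEPTED_CONFIDENCE:
            continue  # instruction 4: drop low or missing confidence
        if habit.get("frequency_pattern") not in ALLOWED_FREQUENCIES:
            continue  # unknown label, not one of the six listed patterns
        kept.append(habit)
    return kept
```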


@@ -1,54 +0,0 @@
You are an expert at analyzing user interests from memory summaries.
## Memory Summaries
{% for summary in memory_summaries %}
Summary {{ loop.index }}:
{{ summary.content or summary.user_content or '' }}
---
{% endfor %}
## Target User ID
{{ user_id }}
## Interest Categories
1. **Tech**: Programming, technology, software tools, hardware
2. **Lifestyle**: Daily routines, health, hobbies, social activities
3. **Music**: Music preferences, instruments, concerts
4. **Art**: Visual arts, creative projects, design, aesthetics
## Instructions
1. Categorize the user's interests into the four areas
2. Calculate percentage distribution (must total 100%)
3. Provide specific evidence for each interest area
4. Use "increasing", "decreasing", or "stable" for trending direction
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
## Output Format
{
"interest_distribution": {
"tech": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
"lifestyle": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
"music": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
"art": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"}
}
}
## Example (English input → English output)
{
"interest_distribution": {
"tech": {"percentage": 40, "evidence": ["discusses programming frequently"], "trending_direction": "increasing"},
"lifestyle": {"percentage": 35, "evidence": ["talks about fitness routine"], "trending_direction": "stable"},
"music": {"percentage": 15, "evidence": ["mentioned favorite bands"], "trending_direction": "stable"},
"art": {"percentage": 10, "evidence": ["visited art museum"], "trending_direction": "stable"}
}
}
## Example (Chinese input → Chinese output)
{
"interest_distribution": {
"tech": {"percentage": 40, "evidence": ["经常讨论编程"], "trending_direction": "increasing"},
"lifestyle": {"percentage": 35, "evidence": ["谈论健身日常"], "trending_direction": "stable"},
"music": {"percentage": 15, "evidence": ["提到喜欢的乐队"], "trending_direction": "stable"},
"art": {"percentage": 10, "evidence": ["参观了艺术博物馆"], "trending_direction": "stable"}
}
}
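
The interest template requires the four percentages to total 100%, which is easy to verify after parsing. The helper below is a minimal sketch of such a check. The name `check_interest_distribution` and the `tolerance` parameter are assumptions for illustration; the source does not show any validation step.

```python
from typing import Any, Dict

# The four areas named in the template's output format.
EXPECTED_AREAS = ("tech", "lifestyle", "music", "art")


def check_interest_distribution(payload: Dict[str, Any], tolerance: float = 0.5) -> bool:
    """Return True if exactly the four areas are present and total about 100."""
    dist = payload.get("interest_distribution", {})
    if set(dist) != set(EXPECTED_AREAS):
        return False
    total = sum(float(dist[area].get("percentage", 0)) for area in EXPECTED_AREAS)
    return abs(total - 100.0) <= tolerance
```

Applied to the English example in the template (40, 35, 15, 10), the check passes; a response that omits an area or totals 90 would be rejected.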


@@ -1,47 +0,0 @@
You are an expert at analyzing user memory summaries to identify implicit preferences.
## Memory Summaries
{% for summary in memory_summaries %}
Summary {{ loop.index }}:
{{ summary.content or summary.user_content or '' }}
---
{% endfor %}
## Target User ID
{{ user_id }}
## Instructions
1. Focus ONLY on the specified user's preferences
2. Extract SHORT preference tags (1-3 words max), like: "音乐", "咖啡", "科幻", "设计", "古典", "吉他"
3. DO NOT use long phrases - use short nouns or noun phrases
4. Only include preferences with confidence_score >= 0.3
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
## Output Format
{
"preferences": [
{
"tag_name": "short tag",
"confidence_score": 0.0-1.0,
"supporting_evidence": ["evidence1", "evidence2"],
"context_details": "brief context",
"category": "category or null"
}
]
}
## Example (Chinese input → Chinese output)
{
"preferences": [
{"tag_name": "咖啡", "confidence_score": 0.8, "supporting_evidence": ["每天早上喝咖啡"], "context_details": "日常习惯", "category": "lifestyle"},
{"tag_name": "古典音乐", "confidence_score": 0.7, "supporting_evidence": ["喜欢听古典"], "context_details": "音乐偏好", "category": "music"}
]
}
## Example (English input → English output)
{
"preferences": [
{"tag_name": "coffee", "confidence_score": 0.8, "supporting_evidence": ["drinks coffee every morning"], "context_details": "daily routine", "category": "lifestyle"},
{"tag_name": "classical music", "confidence_score": 0.7, "supporting_evidence": ["enjoys classical"], "context_details": "music preference", "category": "music"}
]
}


@@ -0,0 +1,391 @@
"""
This module provides the MemoryInsight class for analyzing user memory data.

This script can be executed directly to generate a memory insight report for a test user.
"""
import asyncio
import json
import os
import sys
from collections import Counter
from datetime import datetime

# To run this script directly, we need to add the src directory to the Python path
# to resolve the inconsistent imports in other modules.
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if src_path not in sys.path:
    sys.path.insert(0, src_path)

from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from pydantic import BaseModel, Field

#TODO: Fix this
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")


# Pydantic models used for structured LLM output
class TagClassification(BaseModel):
    """
    Represents the classification of a tag into a specific domain.
    """
    domain: str = Field(
        ...,
        description="The domain the tag belongs to, chosen from the predefined list.",
        examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"],
    )


class InsightReport(BaseModel):
    """
    Represents the final insight report generated by the LLM.
    """
    report: str = Field(
        ...,
        description="A comprehensive insight report in Chinese, summarizing the user's memory patterns.",
    )


class MemoryInsight:
    """
    Provides insights into user memories by analyzing various aspects of their data.
    """

    def __init__(self, user_id: str):
        self.user_id = user_id
        self.neo4j_connector = Neo4jConnector()
        # Get config_id using get_end_user_connected_config
        with get_db_context() as db:
            try:
                from app.services.memory_agent_service import (
                    get_end_user_connected_config,
                )
                connected_config = get_end_user_connected_config(user_id, db)
                config_id = connected_config.get("memory_config_id")
                if config_id:
                    # Use the config_id to get the proper LLM client
                    config_service = MemoryConfigService(db)
                    memory_config = config_service.load_memory_config(config_id)
                    factory = MemoryClientFactory(db)
                    self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
                else:
                    # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
                    # Fallback to default LLM if no config found
                    factory = MemoryClientFactory(db)
                    self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
            except Exception as e:
                print(f"Failed to get user connected config, using default LLM: {e}")
                # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
                # Fallback to default LLM
                factory = MemoryClientFactory(db)
                self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)

    async def close(self):
        """Close the database connection."""
        await self.neo4j_connector.close()

    async def get_domain_distribution(self) -> dict[str, float]:
        """
        Calculates the distribution of memory domains based on hot tags.
        """
        hot_tags = await get_hot_memory_tags(self.user_id)
        if not hot_tags:
            return {}
        domain_counts = Counter()
        for tag, _ in hot_tags:
            prompt = f"""请将以下标签归类到最合适的领域中。
可选领域及其关键词:
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
- 其他:确实无法归入以上任何类别的内容
标签: {tag}
分析步骤:
1. 仔细理解标签的核心含义和使用场景
2. 对比各个领域的关键词,找到最匹配的领域
3. 特别注意:
- 历史、科学、文化等知识性内容应归类为"学习"
- 学校、课程、考试等正式教育场景应归类为"教育"
- 只有在标签完全不属于上述9个具体领域时才选择"其他"
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
请直接返回最合适的领域名称。"""
            messages = [
                {"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
                {"role": "user", "content": prompt}
            ]
            # Call directly and await the result
            classification = await self.llm_client.response_structured(
                messages=messages,
                response_model=TagClassification,
            )
            if classification and hasattr(classification, 'domain') and classification.domain:
                domain_counts[classification.domain] += 1
        total_tags = sum(domain_counts.values())
        if total_tags == 0:
            return {}
        domain_distribution = {
            domain: count / total_tags for domain, count in domain_counts.items()
        }
        return dict(
            sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True)
        )

    async def get_active_periods(self) -> list[int]:
        """
        Identifies the top 2 most active months for the user.
        Only returns months if there is valid and diverse time data.

        This method checks if the time data represents real user memory timestamps
        rather than auto-generated system timestamps by verifying:
        1. Time data exists and is parseable
        2. Time data is distributed across multiple months (not concentrated in 1-2 months)
        """
        query = f"""
        MATCH (d:Dialogue)
        WHERE d.group_id = '{self.user_id}' AND d.created_at IS NOT NULL AND d.created_at <> ''
        RETURN d.created_at AS creation_time
        """
        records = await self.neo4j_connector.execute_query(query)
        if not records:
            return []
        month_counts = Counter()
        valid_dates_count = 0
        for record in records:
            creation_time_str = record.get("creation_time")
            if not creation_time_str:
                continue
            try:
                # Try to parse the timestamp string
                dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
                month_counts[dt_object.month] += 1
                valid_dates_count += 1
            except (ValueError, TypeError, AttributeError):
                # Skip records that fail to parse
                continue
        # No valid timestamps: return an empty list
        if not month_counts or valid_dates_count == 0:
            return []
        # Check whether the timestamps are overly concentrated (possibly bulk-imported data).
        # If more than 80% of the data falls in 1-2 months, treat it as system timestamps rather than real ones.
        unique_months = len(month_counts)
        if unique_months <= 2:
            # Only 1-2 months have data, which likely indicates a bulk import
            most_common_count = month_counts.most_common(1)[0][1]
            if most_common_count / valid_dates_count > 0.8:
                # More than 80% concentrated in a single month: treat as system timestamps
                return []
        # Timestamps spread over 3 or more months are treated as real time data
        if unique_months >= 3:
            most_common_months = month_counts.most_common(2)
            return [month for month, _ in most_common_months]
        # With exactly 2 months, check whether the distribution is balanced
        if unique_months == 2:
            counts = list(month_counts.values())
            # If the two months' counts are comparable (smaller/larger ratio above 0.3), treat as real data
            ratio = min(counts) / max(counts)
            if ratio > 0.3:
                most_common_months = month_counts.most_common(2)
                return [month for month, _ in most_common_months]
        # Otherwise return an empty list
        return []
async def get_social_connections(self) -> dict | None:
"""
Finds the user with whom the most memories are shared.
"""
# Pass the user id as a query parameter instead of interpolating it into the Cypher text
query = """
MATCH (d1:Dialogue {group_id: $group_id})<-[:MENTIONS]-(s:Statement)-[:MENTIONS]->(d2:Dialogue)
WHERE d1 <> d2
RETURN d2.group_id AS other_user_id, COUNT(s) AS common_statements
ORDER BY common_statements DESC
LIMIT 1
"""
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
if not records:
return None
most_connected_user = records[0]["other_user_id"]
common_memories_count = records[0]["common_statements"]
# Both group ids are passed as query parameters instead of being interpolated
time_range_query = """
MATCH (d:Dialogue)
WHERE d.group_id IN [$group_id, $other_group_id]
RETURN min(d.created_at) AS start_time, max(d.created_at) AS end_time
"""
time_records = await self.neo4j_connector.execute_query(
time_range_query, group_id=self.user_id, other_group_id=most_connected_user
)
start_year, end_year = "N/A", "N/A"
if time_records and time_records[0]["start_time"]:
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
return {
"user_id": most_connected_user,
"common_memories_count": common_memories_count,
"time_range": f"{start_year}-{end_year}",
}
async def generate_insight_report(self) -> str:
"""
Generates the final insight report in natural language.
"""
domain_dist, active_periods, social_conn = await asyncio.gather(
self.get_domain_distribution(),
self.get_active_periods(),
self.get_social_connections(),
)
prompt_parts = []
if domain_dist:
top_domains = ", ".join([f"{k}({v:.0%})" for k, v in list(domain_dist.items())[:3]])
prompt_parts.append(f"- 核心领域: 用户的记忆主要集中在 {top_domains}")
if active_periods:
months_str = " 和 ".join(map(str, active_periods))
prompt_parts.append(f"- 活跃时段: 用户在每年的 {months_str} 月最为活跃。")
if social_conn:
prompt_parts.append(
f"- 社交关联: 与用户\"{social_conn['user_id']}\"拥有最多共同记忆({social_conn['common_memories_count']}条),时间范围主要在 {social_conn['time_range']}"
)
if not prompt_parts:
return "暂无足够数据生成洞察报告。"
system_prompt = '''你是一位资深的个人记忆分析师。你的任务是根据我提供的要点,为用户生成一段简洁、自然、个性化的记忆洞察报告。
重要规则:
1. 报告需要将所有要点流畅地串联成一个段落
2. 语言风格要亲切、易于理解,就像和朋友聊天一样
3. 不要添加任何额外的解释或标题,直接输出报告内容
4. 只使用我提供的要点,不要编造或推测任何信息
5. 如果某个维度没有数据(如没有活跃时段信息),就不要在报告中提及该维度
例如,如果输入是:
- 核心领域: 用户的记忆主要集中在 旅行(38%), 工作(24%), 家庭(21%)。
- 活跃时段: 用户在每年的 4 和 10 月最为活跃。
- 社交关联: 与用户"张明"拥有最多共同记忆(47条),时间范围主要在 2017-2020。
你的输出应该是:
"您的记忆集中在旅行(38%)、工作(24%)和家庭(21%)三大领域。每年4月和10月是您最活跃的记录期可能与春秋季旅行计划相关。您与'张明'共同拥有最多记忆(47条)主要集中在2017-2020年间。"
如果输入只有:
- 核心领域: 用户的记忆主要集中在 教育(65%), 学习(25%)。
你的输出应该是:
"您的记忆主要集中在教育(65%)和学习(25%)两大领域,显示出您对知识和成长的持续关注。"'''
user_prompt = "\n".join(prompt_parts)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
response = await self.llm_client.chat(messages=messages)
# Make sure a string is returned
content = response.content
if isinstance(content, list):
# List format (e.g. [{'type': 'text', 'text': '...'}]): extract the text
if len(content) > 0:
if isinstance(content[0], dict):
# Try the 'text' field first
text = content[0].get('text', content[0].get('content', str(content[0])))
return str(text)
else:
return str(content[0])
return ""
elif isinstance(content, dict):
# Dict format: extract the 'text' field
return str(content.get('text', content.get('content', str(content))))
else:
# Already a string (or some other type): convert to str
return str(content) if content is not None else ""
async def close(self):
"""
Closes the database connection.
"""
await self.neo4j_connector.close()
async def main():
"""
Initializes and runs the memory insight analysis for a test user.
"""
# Read from the environment by default
test_user_id = DEFAULT_GROUP_ID
print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n")
insight = None
try:
insight = MemoryInsight(user_id=test_user_id)
report = await insight.generate_insight_report()
print("--- 记忆洞察报告 ---")
print(report)
print("---------------------")
# Write the result to the shared User-Dashboard.json, using the globally configured path
try:
from app.core.config import settings
settings.ensure_memory_output_dir()
output_dir = settings.MEMORY_OUTPUT_DIR
try:
os.makedirs(output_dir, exist_ok=True)
except Exception:
pass
dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
existing = {}
if os.path.exists(dashboard_path):
with open(dashboard_path, "r", encoding="utf-8") as rf:
existing = json.load(rf)
existing["memory_insight"] = {
"group_id": test_user_id,
"report": report
}
with open(dashboard_path, "w", encoding="utf-8") as wf:
json.dump(existing, wf, ensure_ascii=False, indent=2)
print(f"已写入 {dashboard_path} -> memory_insight")
except Exception as e:
print(f"写入 User-Dashboard.json 失败: {e}")
except Exception as e:
print(f"生成报告时出错: {e}")
finally:
if insight:
await insight.close()
if __name__ == "__main__":
# This setup allows running the async main function
if sys.platform.startswith('win') and sys.version_info >= (3, 8):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
asyncio.run(main())
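
The month heuristic in `get_active_periods` is easier to check in isolation. Below is a minimal sketch of the same decision rules as a pure function, assuming the timestamps arrive as ISO 8601 strings; the name `pick_active_months` is hypothetical and does not appear in the change.

```python
from collections import Counter
from datetime import datetime


def pick_active_months(timestamps: list[str]) -> list[int]:
    """Return up to two most active months, or [] if the data looks auto-generated."""
    month_counts = Counter()
    for raw in timestamps:
        try:
            parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
        except (ValueError, TypeError, AttributeError):
            continue  # skip values that cannot be parsed
        month_counts[parsed.month] += 1

    total = sum(month_counts.values())
    if total == 0:
        return []

    unique_months = len(month_counts)
    if unique_months >= 3:
        # Spread over three or more months: treat as real user timestamps
        return [month for month, _ in month_counts.most_common(2)]
    if unique_months == 2:
        counts = list(month_counts.values())
        top_share = max(counts) / total
        # Reject a >80% concentration, then require a reasonably even split
        if top_share <= 0.8 and min(counts) / max(counts) > 0.3:
            return [month for month, _ in month_counts.most_common(2)]
    # A single month (or a lopsided pair) is treated as system-generated
    return []
```

Keeping the rule free of database access makes the thresholds (80% concentration, 0.3 ratio) straightforward to unit test.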


@@ -0,0 +1,202 @@
"""
Generate a concise "关于我" style user summary using data from Neo4j
and the existing LLM configuration (mirrors hot_memory_tags.py setup).
Usage:
python -m analytics.user_summary --user_id <group_id>
"""
import asyncio
import json
import os
import sys
from dataclasses import dataclass
from typing import List, Tuple
# Ensure absolute imports work whether executed directly or via module
try:
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
src_path = os.path.join(project_root, 'src')
if src_path not in sys.path:
sys.path.insert(0, src_path)
if project_root not in sys.path:
sys.path.insert(0, project_root)
except Exception:
pass
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
#TODO: Fix this
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
@dataclass
class StatementRecord:
statement: str
created_at: str | None
class UserSummary:
"""Builds a textual user summary for a given user/group id."""
def __init__(self, user_id: str):
self.user_id = user_id
self.connector = Neo4jConnector()
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
async def close(self):
await self.connector.close()
async def _get_recent_statements(self, limit: int = 80) -> List[StatementRecord]:
"""Fetch recent statements authored by the user/group for context."""
query = (
"MATCH (s:Statement) "
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
"RETURN s.statement AS statement, s.created_at AS created_at "
"ORDER BY created_at DESC LIMIT $limit"
)
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
records: List[StatementRecord] = []
for r in rows:
try:
records.append(StatementRecord(statement=r.get("statement", ""), created_at=r.get("created_at")))
except Exception:
continue
return records
async def _get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
"""Reuse hot tag logic to get meaningful entities and their frequencies."""
# get_hot_memory_tags internally filters out non-meaningful nouns with LLM
return await get_hot_memory_tags(self.user_id, limit=limit)
async def generate(self) -> str:
"""Generate a Chinese '关于我' style summary using the LLM."""
# 1) Collect context
entities = await self._get_top_entities(limit=40)
statements = await self._get_recent_statements(limit=100)
entity_lines = [f"{name} ({freq})" for name, freq in entities][:20]
statement_samples = [s.statement.strip() for s in statements if (s.statement or '').strip()][:20]
# 2) Compose prompt
system_prompt = (
"你是一位中文信息压缩助手。请基于提供的实体与语句,"
"生成非常简洁的用户摘要,禁止臆测或虚构。要求:\n"
"- 3-4 句,总字数不超过 120;\n"
"- 先交代身份/城市,其次长期兴趣或习惯,最后给一两项代表性经历;\n"
"- 避免形容词堆砌与空话,不用项目符号,不分段;\n"
"- 使用客观的第三人称描述,语气克制、中立。"
)
user_content_parts = [
f"用户ID: {self.user_id}",
"核心实体与频次: " + (", ".join(entity_lines) if entity_lines else "(空)"),
"代表性语句样本: " + (" | ".join(statement_samples) if statement_samples else "(空)"),
]
user_prompt = "\n".join(user_content_parts)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
# 3) Call LLM
response = await self.llm.chat(messages=messages)
# Make sure a string is returned
content = response.content
if isinstance(content, list):
# List format (e.g. [{'type': 'text', 'text': '...'}]): extract the text
if len(content) > 0:
if isinstance(content[0], dict):
# Try the 'text' field first
text = content[0].get('text', content[0].get('content', str(content[0])))
return str(text)
else:
return str(content[0])
return ""
elif isinstance(content, dict):
# Dict format: extract the 'text' field
return str(content.get('text', content.get('content', str(content))))
else:
# Already a string (or some other type): convert to str
return str(content) if content is not None else ""
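
The response-to-string handling above is repeated verbatim in the memory insight report generator. One option is a shared helper; the sketch below assumes the same payload shapes the code already handles (string, dict, list, or `None`), and the name `content_to_text` is hypothetical.

```python
from typing import Any


def content_to_text(content: Any) -> str:
    """Coerce an LLM response payload (str, dict, list, or None) to plain text."""
    if content is None:
        return ""
    if isinstance(content, list):
        if not content:
            return ""
        # As in the original code, only the first part of a list payload is used
        content = content[0]
    if isinstance(content, dict):
        # Prefer 'text', then 'content', then the dict's string form
        return str(content.get("text", content.get("content", str(content))))
    return str(content)
```

With such a helper, both call sites could reduce to `return content_to_text(response.content)`.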
async def generate_user_summary(user_id: str | None = None) -> str:
# Read from the environment by default
effective_group_id = user_id or DEFAULT_GROUP_ID
svc = UserSummary(effective_group_id)
try:
return await svc.generate()
finally:
await svc.close()
if __name__ == "__main__":
print("开始生成用户摘要…")
try:
# No user_id is passed, so DEFAULT_GROUP_ID (from SELECTED_GROUP_ID) is used
summary = asyncio.run(generate_user_summary())
print("\n— 用户摘要 —\n")
print(summary)
# Write the result to the shared User-Dashboard.json
try:
from app.core.config import settings
settings.ensure_memory_output_dir()
output_dir = settings.MEMORY_OUTPUT_DIR
try:
os.makedirs(output_dir, exist_ok=True)
except Exception:
pass
dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
existing = {}
if os.path.exists(dashboard_path):
with open(dashboard_path, "r", encoding="utf-8") as rf:
existing = json.load(rf)
existing["user_summary"] = {
"group_id": DEFAULT_GROUP_ID,
"summary": summary
}
with open(dashboard_path, "w", encoding="utf-8") as wf:
json.dump(existing, wf, ensure_ascii=False, indent=2)
print(f"已写入 {dashboard_path} -> user_summary")
except Exception as e:
print(f"写入 User-Dashboard.json 失败: {e}")
except Exception as e:
print(f"生成摘要失败: {e}")
print("请检查: 1) Neo4j 是否可用;2) config.json 与 .env 的 LLM/Neo4j 配置是否正确;3) 数据是否包含该用户的内容。")
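
Both scripts in this change read `User-Dashboard.json`, replace one top-level section, and write the file back. A minimal sketch of that read-merge-write step as one function follows; `write_dashboard_section` and its parameters are hypothetical names, and the sketch takes the output directory as an argument rather than reading `settings.MEMORY_OUTPUT_DIR`.

```python
import json
import os


def write_dashboard_section(output_dir: str, section: str, payload: dict) -> str:
    """Merge one section into User-Dashboard.json and return the file path."""
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, "User-Dashboard.json")

    existing = {}
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as rf:
            existing = json.load(rf)

    # Only the named section is replaced; other sections are preserved
    existing[section] = payload

    with open(path, "w", encoding="utf-8") as wf:
        json.dump(existing, wf, ensure_ascii=False, indent=2)
    return path
```

A call such as `write_dashboard_section(output_dir, "user_summary", {"group_id": group_id, "summary": summary})` would then cover the block above.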


@@ -37,20 +37,12 @@ def parse_historical_datetime(v):
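This function manually parses ISO 8601 date strings and supports 1-4 digit years.
Args:
v: The date value (None, a datetime object, a Neo4j DateTime object, or a string)
v: The date value (None, a datetime object, or a string)
Returns:
A datetime object, or None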
"""
if v is None:
return v
# Handle Neo4j DateTime objects
if hasattr(v, 'to_native'):
return v.to_native()
# Handle Python datetime objects
if isinstance(v, datetime):
if v is None or isinstance(v, datetime):
return v
if isinstance(v, str):
@@ -236,13 +228,6 @@ class StatementNode(Node):
chunk_embedding: Optional embedding vector for the parent chunk
connect_strength: Classification of connection strength ('Strong' or 'Weak')
config_id: Configuration ID used to process this statement
# ACT-R Memory Activation Properties
importance_score: Importance score for memory activation (0.0-1.0), default 0.5
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0)
access_history: List of ISO timestamp strings recording each access
last_access_time: ISO timestamp of the most recent access
access_count: Total number of times this node has been accessed
"""
# Core fields (ordered as requested)
chunk_id: str = Field(..., description="ID of the parent chunk")
@@ -284,33 +269,6 @@ class StatementNode(Node):
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
# ACT-R Memory Activation Properties
importance_score: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Importance score for memory activation (0.0-1.0), default 0.5"
)
activation_value: Optional[float] = Field(
None,
ge=0.0,
le=1.0,
description="Current activation value calculated by ACT-R engine (0.0-1.0)"
)
access_history: List[str] = Field(
default_factory=list,
description="List of ISO timestamp strings recording each access"
)
last_access_time: Optional[str] = Field(
None,
description="ISO timestamp of the most recent access"
)
access_count: int = Field(
default=0,
ge=0,
description="Total number of times this node has been accessed"
)
@field_validator('valid_at', 'invalid_at', mode='before')
@classmethod
def validate_datetime(cls, v):
@@ -393,22 +351,11 @@ class ExtractedEntityNode(Node):
fact_summary: Summary of facts about this entity
connect_strength: Classification of connection strength ('Strong', 'Weak', or 'Both')
config_id: Configuration ID used to process this entity (integer or string)
# ACT-R Memory Activation Properties
importance_score: Importance score for memory activation (0.0-1.0), default 0.5
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0)
access_history: List of ISO timestamp strings recording each access
last_access_time: ISO timestamp of the most recent access
access_count: Total number of times this node has been accessed
"""
entity_idx: int = Field(..., description="Unique identifier for the entity")
statement_id: str = Field(..., description="Statement this entity was extracted from")
entity_type: str = Field(..., description="Type of the entity")
description: str = Field(..., description="Entity description")
example: str = Field(
default="",
description="A concise example (around 20 characters) to help understand the entity"
)
aliases: List[str] = Field(
default_factory=list,
description="Entity aliases - alternative names for this entity"
@@ -418,39 +365,6 @@ class ExtractedEntityNode(Node):
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
# ACT-R Memory Activation Properties
importance_score: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Importance score for memory activation (0.0-1.0), default 0.5"
)
activation_value: Optional[float] = Field(
None,
ge=0.0,
le=1.0,
description="Current activation value calculated by ACT-R engine (0.0-1.0)"
)
access_history: List[str] = Field(
default_factory=list,
description="List of ISO timestamp strings recording each access"
)
last_access_time: Optional[str] = Field(
None,
description="ISO timestamp of the most recent access"
)
access_count: int = Field(
default=0,
ge=0,
description="Total number of times this node has been accessed"
)
# Explicit Memory Classification
is_explicit_memory: bool = Field(
default=False,
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
)
@field_validator('aliases', mode='before')
@classmethod
def validate_aliases_field(cls, v): # Field validator: automatically cleans and validates the aliases field
@@ -484,68 +398,14 @@ class MemorySummaryNode(Node):
dialog_id: ID of the parent dialog
chunk_ids: List of chunk IDs used to generate this summary
content: Summary text content
name: Title/name of the memory summary (generated by LLM, used as title in API)
memory_type: Type/category of the episodic memory (e.g., Conversation, Project/Work, Learning, Decision, Important Event)
summary_embedding: Optional embedding vector for the summary
metadata: Additional metadata for the summary
config_id: Configuration ID used to process this summary
original_statement_id: ID of the original statement that was merged (for ACT-R forgetting)
original_entity_id: ID of the original entity that was merged (for ACT-R forgetting)
merged_at: Timestamp when the nodes were merged
# ACT-R Memory Activation Properties
importance_score: Importance score for memory activation (0.0-1.0), inherited from merged nodes
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0), inherited from merged nodes
access_history: List of ISO timestamp strings recording each access (reset on creation)
last_access_time: ISO timestamp of the most recent access (set to creation time)
access_count: Total number of times this node has been accessed (reset to 1 on creation)
"""
summary_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the summary")
dialog_id: str = Field(..., description="ID of the parent dialog")
chunk_ids: List[str] = Field(default_factory=list, description="List of chunk IDs used in the summary")
content: str = Field(..., description="Summary text content")
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
# ACT-R Forgetting Engine Properties
original_statement_id: Optional[str] = Field(
None,
description="ID of the original statement that was merged (for traceability)"
)
original_entity_id: Optional[str] = Field(
None,
description="ID of the original entity that was merged (for traceability)"
)
merged_at: Optional[datetime] = Field(
None,
description="Timestamp when the nodes were merged"
)
# ACT-R Memory Activation Properties
importance_score: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Importance score for memory activation (0.0-1.0), inherited from merged nodes"
)
activation_value: Optional[float] = Field(
None,
ge=0.0,
le=1.0,
description="Current activation value calculated by ACT-R engine (0.0-1.0), inherited from merged nodes"
)
access_history: List[str] = Field(
default_factory=list,
description="List of ISO timestamp strings recording each access (reset on creation)"
)
last_access_time: Optional[str] = Field(
None,
description="ISO timestamp of the most recent access (set to creation time)"
)
access_count: int = Field(
default=1,
ge=0,
description="Total number of times this node has been accessed (reset to 1 on creation)"
)


@@ -38,20 +38,10 @@ class Entity(BaseModel):
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
type: str = Field(..., description="Type/category of the entity")
description: str = Field(..., description="Description of the entity")
example: str = Field(
default="",
description="A concise example (around 20 characters) to help understand the entity"
)
aliases: List[str] = Field(
default_factory=list,
description="Alternative names for this entity (abbreviations, full names, translations, etc.)"
)
# Explicit Memory Classification
is_explicit_memory: bool = Field(
default=False,
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
)
class Triplet(BaseModel):


@@ -69,12 +69,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
for item in results:
if score_field in item:
score = item.get(score_field)
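# For activation_value, keep None as None instead of substituting a fallback value,
# so that nodes with and without an activation value can be told apart
if score_field == "activation_value" and score is None:
scores.append(None) # Keep None; handled specially later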
continue
if score is not None and isinstance(score, (int, float)):
scores.append(float(score))
else:
@@ -82,433 +76,205 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
if not scores:
return results
# Filter out None values; only valid scores are normalized
valid_scores = [s for s in scores if s is not None]
if not valid_scores:
# All scores are None, so skip normalization
if len(scores) == 1:
# Single score, set to 1.0
for item in results:
if score_field in item or score_field == "activation_value":
item[f"normalized_{score_field}"] = None
if score_field in item:
item[f"normalized_{score_field}"] = 1.0
return results
if len(valid_scores) == 1: # Single valid score, set to 1.0
for item, score in zip(results, scores):
if score_field in item or score_field == "activation_value":
if score is None:
item[f"normalized_{score_field}"] = None
else:
item[f"normalized_{score_field}"] = 1.0
return results
# Calculate mean and standard deviation (only for valid scores)
mean_score = sum(valid_scores) / len(valid_scores)
variance = sum((score - mean_score) ** 2 for score in valid_scores) / len(valid_scores)
# Calculate mean and standard deviation
mean_score = sum(scores) / len(scores)
variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
std_dev = math.sqrt(variance)
if std_dev == 0:
# All valid scores are the same, set them to 1.0
for item, score in zip(results, scores):
if score_field in item or score_field == "activation_value":
if score is None:
item[f"normalized_{score_field}"] = None
else:
item[f"normalized_{score_field}"] = 1.0
# All scores are the same, set them to 1.0
for item in results:
if score_field in item:
item[f"normalized_{score_field}"] = 1.0
else:
for item, score in zip(results, scores):
if score_field in item or score_field == "activation_value":
if score is None:
# Keep None; skip normalization
item[f"normalized_{score_field}"] = None
else:
# Calculate z-score
z_score = (score - mean_score) / std_dev
# Transform to positive range using sigmoid function
normalized = 1 / (1 + math.exp(-z_score))
item[f"normalized_{score_field}"] = normalized
for item in results:
if score_field in item:
score = item[score_field]
# Handle None or non-numeric scores
if score is None or not isinstance(score, (int, float)):
score = 0.0
# Calculate z-score
z_score = (score - mean_score) / std_dev
# Transform to positive range using sigmoid function
normalized = 1 / (1 + math.exp(-z_score))
item[f"normalized_{score_field}"] = normalized
return results
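
The normalization in `normalize_scores` is a z-score followed by a sigmoid. A minimal sketch on a plain list of floats, assuming every score is already numeric; `sigmoid_zscore` is a hypothetical name used only for illustration.

```python
import math


def sigmoid_zscore(scores: list[float]) -> list[float]:
    """Map scores to (0, 1) by applying a sigmoid to their z-scores."""
    if not scores:
        return []
    mean = sum(scores) / len(scores)
    # Population standard deviation, as in normalize_scores
    std = math.sqrt(sum((s - mean) ** 2 for s in scores) / len(scores))
    if std == 0:
        # A single score, or all scores equal: everything maps to 1.0
        return [1.0] * len(scores)
    return [1 / (1 + math.exp(-(s - mean) / std)) for s in scores]
```

A score equal to the mean maps to 0.5, and scores placed symmetrically around the mean map to values that sum to 1.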
# ============================================================================
# The functions below have been superseded by rerank_with_activation; kept for reference for now
# ============================================================================
def rerank_hybrid_results(
keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6,
limit: int = 10
) -> Dict[str, List[Dict[str, Any]]]:
"""
Rerank hybrid search results by combining BM25 and embedding scores.
# def rerank_hybrid_results(
# keyword_results: Dict[str, List[Dict[str, Any]]],
# embedding_results: Dict[str, List[Dict[str, Any]]],
# alpha: float = 0.6,
# limit: int = 10
# ) -> Dict[str, List[Dict[str, Any]]]:
# """
# Rerank hybrid search results by combining BM25 and embedding scores.
#
# Deprecated: this function has been fully superseded by rerank_with_activation
#
# Args:
# keyword_results: Results from keyword/BM25 search
# embedding_results: Results from embedding search
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
# limit: Maximum number of results to return per category
#
# Returns:
# Reranked results with combined scores
# """
# reranked = {}
#
# for category in ["statements", "chunks", "entities","summaries"]:
# keyword_items = keyword_results.get(category, [])
# embedding_items = embedding_results.get(category, [])
#
# # Normalize scores within each search type
# keyword_items = normalize_scores(keyword_items, "score")
# embedding_items = normalize_scores(embedding_items, "score")
#
# # Create a combined pool of unique items
# combined_items = {}
#
# # Add keyword results with BM25 scores
# for item in keyword_items:
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# if item_id:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# combined_items[item_id]["embedding_score"] = 0 # Default
#
# # Add or update with embedding results
# for item in embedding_items:
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# if item_id:
# if item_id in combined_items:
# # Update existing item with embedding score
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# else:
# # New item from embedding search only
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = 0 # Default
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
#
# # Calculate combined scores and rank
# for item_id, item in combined_items.items():
# bm25_score = item.get("bm25_score", 0)
# embedding_score = item.get("embedding_score", 0)
#
# # Combined score: weighted average of normalized scores
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# item["combined_score"] = combined_score
#
# # Keep original score for reference
# if "score" not in item and bm25_score > 0:
# item["score"] = bm25_score
# elif "score" not in item and embedding_score > 0:
# item["score"] = embedding_score
#
# # Sort by combined score and limit results
# sorted_items = sorted(
# combined_items.values(),
# key=lambda x: x.get("combined_score", 0),
# reverse=True
# )[:limit]
#
# reranked[category] = sorted_items
#
# return reranked
Args:
keyword_results: Results from keyword/BM25 search
embedding_results: Results from embedding search
alpha: Weight for BM25 scores (1-alpha for embedding scores)
limit: Maximum number of results to return per category
# def rerank_with_forgetting_curve(
# keyword_results: Dict[str, List[Dict[str, Any]]],
# embedding_results: Dict[str, List[Dict[str, Any]]],
# alpha: float = 0.6,
# limit: int = 10,
# forgetting_config: ForgettingEngineConfig | None = None,
# now: datetime | None = None,
# ) -> Dict[str, List[Dict[str, Any]]]:
# """
# Rerank hybrid results with a forgetting curve applied to combined scores.
#
# Deprecated: this function has been fully superseded by rerank_with_activation
# rerank_with_activation provides more complete forgetting-curve support (combined with activation)
#
# The forgetting curve reduces scores for older memories or weaker connections.
#
# Args:
# keyword_results: Results from keyword/BM25 search
# embedding_results: Results from embedding search
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
# limit: Maximum number of results to return per category
# forgetting_config: Configuration for the forgetting engine
# now: Optional current time override for testing
#
# Returns:
# Reranked results with combined and final scores (after forgetting)
# """
# engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
# now_dt = now or datetime.now()
#
# reranked: Dict[str, List[Dict[str, Any]]] = {}
#
# for category in ["statements", "chunks", "entities","summaries"]:
# keyword_items = keyword_results.get(category, [])
# embedding_items = embedding_results.get(category, [])
#
# # Normalize scores within each search type
# keyword_items = normalize_scores(keyword_items, "score")
# embedding_items = normalize_scores(embedding_items, "score")
#
# combined_items: Dict[str, Dict[str, Any]] = {}
#
# # Combine two result sets by ID
# for src_items, is_embedding in (
# (keyword_items, False), (embedding_items, True)
# ):
# for item in src_items:
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# if not item_id:
# continue
# existing = combined_items.get(item_id)
# if not existing:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = 0
# combined_items[item_id]["embedding_score"] = 0
# # Update normalized score from the right source
# if is_embedding:
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# else:
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
#
# # Calculate scores and apply forgetting weights
# for item_id, item in combined_items.items():
# bm25_score = float(item.get("bm25_score", 0) or 0)
# embedding_score = float(item.get("embedding_score", 0) or 0)
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
#
# # Estimate time elapsed in days
# dt = _parse_datetime(item.get("created_at"))
# if dt is None:
# time_elapsed_days = 0.0
# else:
# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
#
# # Memory strength (currently set to default value)
# memory_strength = 1.0
# forgetting_weight = engine.calculate_weight(
# time_elapsed=time_elapsed_days, memory_strength=memory_strength
# )
# final_score = combined_score * forgetting_weight
# item["combined_score"] = final_score
#
# sorted_items = sorted(
# combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
# )[:limit]
#
# reranked[category] = sorted_items
#
# return reranked
Returns:
Reranked results with combined scores
"""
reranked = {}
for category in ["statements", "chunks", "entities","summaries"]:
keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, [])
def rerank_with_activation(
# Normalize scores within each search type
keyword_items = normalize_scores(keyword_items, "score")
embedding_items = normalize_scores(embedding_items, "score")
# Create a combined pool of unique items
combined_items = {}
# Add keyword results with BM25 scores
for item in keyword_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
combined_items[item_id]["embedding_score"] = 0 # Default
# Add or update with embedding results
for item in embedding_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id:
if item_id in combined_items:
# Update existing item with embedding score
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
# New item from embedding search only
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0 # Default
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# Calculate combined scores and rank
for item_id, item in combined_items.items():
bm25_score = item.get("bm25_score", 0)
embedding_score = item.get("embedding_score", 0)
# Combined score: weighted average of normalized scores
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
item["combined_score"] = combined_score
# Keep original score for reference
if "score" not in item and bm25_score > 0:
item["score"] = bm25_score
elif "score" not in item and embedding_score > 0:
item["score"] = embedding_score
# Sort by combined score and limit results
sorted_items = sorted(
combined_items.values(),
key=lambda x: x.get("combined_score", 0),
reverse=True
)[:limit]
reranked[category] = sorted_items
return reranked
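
The core of `rerank_hybrid_results` is a weighted sum of the two normalized scores, with 0 as the default when an item appears in only one result set. A minimal sketch over two id-to-score mappings follows. `fuse_scores` is a hypothetical name, and the `alpha` range check is borrowed from `rerank_with_activation`, since `rerank_hybrid_results` itself does not validate it.

```python
def fuse_scores(
    bm25: dict[str, float],
    embedding: dict[str, float],
    alpha: float = 0.6,
    limit: int = 10,
) -> list[tuple[str, float]]:
    """Combine normalized BM25 and embedding scores and return the top items."""
    if not 0 <= alpha <= 1:
        raise ValueError(f"alpha must be within [0, 1], got {alpha}")
    combined = {
        # An id missing from one result set contributes 0 for that component
        item_id: alpha * bm25.get(item_id, 0.0) + (1 - alpha) * embedding.get(item_id, 0.0)
        for item_id in set(bm25) | set(embedding)
    }
    return sorted(combined.items(), key=lambda kv: kv[1], reverse=True)[:limit]
```

With `alpha=0.6`, an item found by both searches can outrank one with a perfect BM25 score that the embedding search missed.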
def rerank_with_forgetting_curve(
keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6,
limit: int = 10,
forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8,
now: datetime | None = None,
) -> Dict[str, List[Dict[str, Any]]]:
"""
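Two-stage ranking: filter by content relevance first, then sort by activation value.
Stage 1: content_score = alpha*BM25 + (1-alpha)*Embedding; keep the top limit*3 candidates
Stage 2: sort the candidates by activation_score and keep the top limit
Nodes without an activation value are used to fill any shortfall
Score fields in the returned results:
- bm25_score: normalized BM25 score
- embedding_score: normalized embedding score
- content_score: content relevance = alpha*bm25 + (1-alpha)*embedding
- activation_score: normalized ACT-R activation value
- base_score: stage-1 base score (equal to content_score)
- final_score: the value used for the final ordering
* nodes with an activation value: final_score = activation_score
* nodes without an activation value: final_score = base_score
Args:
keyword_results: BM25 search results
embedding_results: embedding (vector) search results
alpha: BM25 weight (default: 0.6)
limit: maximum number of results per category
forgetting_config: forgetting engine configuration (currently unused)
activation_boost_factor: how strongly activation affects memory strength (default: 0.8)
now: current time (used for the forgetting calculation)
Returns:
Reranked results with score metadata, sorted by final_score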
Rerank hybrid results with a forgetting curve applied to combined scores.
The forgetting curve reduces scores for older memories or weaker connections.
Args:
keyword_results: Results from keyword/BM25 search
embedding_results: Results from embedding search
alpha: Weight for BM25 scores (1-alpha for embedding scores)
limit: Maximum number of results to return per category
forgetting_config: Configuration for the forgetting engine
now: Optional current time override for testing
Returns:
Reranked results with combined and final scores (after forgetting)
"""
# Validate the weight range
if not (0 <= alpha <= 1):
raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")
# Initialize the forgetting engine (if needed)
engine = None
if forgetting_config:
engine = ForgettingEngine(forgetting_config)
engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
now_dt = now or datetime.now()
reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities", "summaries"]:
for category in ["statements", "chunks", "entities","summaries"]:
keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, [])
# Step 1: normalize scores
# Normalize scores within each search type
keyword_items = normalize_scores(keyword_items, "score")
embedding_items = normalize_scores(embedding_items, "score")
# Step 2: merge results by ID
combined_items: Dict[str, Dict[str, Any]] = {}
# Add keyword results
for item in keyword_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if not item_id:
continue
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
combined_items[item_id]["embedding_score"] = 0 # Default
# Add or update with embedding results
for item in embedding_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if not item_id:
continue
if item_id in combined_items:
# Update the embedding score of the existing item
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
# New item from embedding search only
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0 # Default
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# Step 3: normalize activation scores
# Build the list of activation values for all items
items_list = list(combined_items.values())
items_list = normalize_scores(items_list, "activation_value")
# Write the normalized activation scores back into combined_items
for item in items_list:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id and item_id in combined_items:
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0)
# Step 4: compute the base and final scores
for item_id, item in combined_items.items():
bm25_norm = float(item.get("bm25_score", 0) or 0)
emb_norm = float(item.get("embedding_score", 0) or 0)
act_norm = float(item.get("normalized_activation_value", 0) or 0)
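# Stage 1 considers only content relevance (BM25 + Embedding)
# alpha is the BM25 weight; (1-alpha) is the Embedding weight
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
base_score = content_score # Stage 1 uses the content score
# Store the activation score for use in stage 2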
item["activation_score"] = act_norm
item["content_score"] = content_score
item["base_score"] = base_score
# Step 5: apply the forgetting curve (optional)
if engine:
# Compute memory strength as influenced by activation
importance = float(item.get("importance_score", 0.5) or 0.5)
# Read activation_value
activation_val = item.get("activation_value")
# Apply the forgetting curve only to nodes that have an activation value
if activation_val is not None and isinstance(activation_val, (int, float)):
activation_val = float(activation_val)
# Memory strength: importance_score × (1 + activation_value × boost_factor)
memory_strength = importance * (1 + activation_val * activation_boost_factor)
# Compute the elapsed time (in days)
dt = _parse_datetime(item.get("created_at"))
if dt is None:
time_elapsed_days = 0.0
else:
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
# Get the forgetting weight
forgetting_weight = engine.calculate_weight(
time_elapsed=time_elapsed_days,
memory_strength=memory_strength
)
# Apply it to the base score
item["forgetting_weight"] = forgetting_weight
item["final_score"] = base_score * forgetting_weight
# Combine two result sets by ID
for src_items, is_embedding in (
(keyword_items, False), (embedding_items, True)
):
for item in src_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if not item_id:
continue
existing = combined_items.get(item_id)
if not existing:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0
combined_items[item_id]["embedding_score"] = 0
# Update normalized score from the right source
if is_embedding:
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
# Nodes without an activation value skip the forgetting curve and keep their original score
item["final_score"] = base_score
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# Calculate scores and apply forgetting weights
for item_id, item in combined_items.items():
bm25_score = float(item.get("bm25_score", 0) or 0)
embedding_score = float(item.get("embedding_score", 0) or 0)
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# Estimate time elapsed in days
dt = _parse_datetime(item.get("created_at"))
if dt is None:
time_elapsed_days = 0.0
else:
# Forgetting curve not used
item["final_score"] = base_score
# Step 6: two-stage sorting and limiting
# Stage 1: sort by content relevance (base_score) and keep the Top-K
first_stage_limit = limit * 3 # Configurable; keep 3x the candidates
first_stage_sorted = sorted(
combined_items.values(),
key=lambda x: float(x.get("base_score", 0) or 0), # Sort by content score
reverse=True
)[:first_stage_limit]
# Stage 2: separate nodes with and without an activation value
items_with_activation = []
items_without_activation = []
for item in first_stage_sorted:
activation_score = item.get("activation_score")
# Check for a valid activation value (not None)
if activation_score is not None and isinstance(activation_score, (int, float)):
items_with_activation.append(item)
else:
items_without_activation.append(item)
# Sort the nodes that have an activation value by that value first
sorted_with_activation = sorted(
items_with_activation,
key=lambda x: float(x.get("activation_score", 0) or 0),
reverse=True
)
# If fewer than limit nodes have an activation value, fill up with nodes that lack one
if len(sorted_with_activation) < limit:
needed = limit - len(sorted_with_activation)
# Nodes without an activation value keep their stage-1 content-relevance order
sorted_items = sorted_with_activation + items_without_activation[:needed]
else:
sorted_items = sorted_with_activation[:limit]
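# Two-stage sorting is done; update final_score to reflect the actual ordering basis
# Stage 1: select candidates by content_score (done)
# Stage 2: sort by activation_score (done)
#
# final_score semantics: the value that determined the node's position in the final results
# - Nodes with an activation value: final_score = activation_score (the stage-2 sort key)
# - Nodes without an activation value: final_score = base_score (the content-relevance score is kept)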
for item in sorted_items:
activation_score = item.get("activation_score")
if activation_score is not None and isinstance(activation_score, (int, float)):
# Has an activation value: use the activation as the final score
item["final_score"] = activation_score
else:
# No activation value: use the content-relevance score
item["final_score"] = item.get("base_score", 0)
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
# Memory strength (currently set to default value)
memory_strength = 1.0
forgetting_weight = engine.calculate_weight(
time_elapsed=time_elapsed_days, memory_strength=memory_strength
)
# print(f"Forgetting weight for {item_id}: {forgetting_weight}")
# print(f"Time elapsed days for {item_id}: {time_elapsed_days}")
final_score = combined_score * forgetting_weight
item["combined_score"] = final_score
sorted_items = sorted(
combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
)[:limit]
reranked[category] = sorted_items
return reranked
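The two-stage ordering described in the removed docstring (shortlist by content score, then order by activation, then backfill) can be illustrated on its own. This is a hedged sketch under stated assumptions: `two_stage_rank` and `pool_factor` are hypothetical names, and items are assumed to already carry `content_score` and an optional `activation_score`.

```python
from typing import Any, Dict, List


def two_stage_rank(items: List[Dict[str, Any]], limit: int = 10, pool_factor: int = 3) -> List[Dict[str, Any]]:
    """Shortlist by content relevance, then order the shortlist by activation."""
    # Stage 1: keep the top limit * pool_factor candidates by content score.
    pool = sorted(items, key=lambda x: x["content_score"], reverse=True)[: limit * pool_factor]
    # Stage 2: candidates that carry an activation score are ordered by it.
    with_act = [it for it in pool if it.get("activation_score") is not None]
    without_act = [it for it in pool if it.get("activation_score") is None]
    with_act.sort(key=lambda x: x["activation_score"], reverse=True)
    # Backfill with non-activated candidates, which keep their stage-1 order.
    ranked = (with_act + without_act)[:limit]
    for it in ranked:
        act = it.get("activation_score")
        # final_score records whichever value decided the item's position.
        it["final_score"] = act if act is not None else it["content_score"]
    return ranked
```

One consequence of this design is that a highly activated item with weak content relevance never reaches stage 2 if it falls outside the stage-1 shortlist.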
@@ -794,7 +560,6 @@ async def run_hybrid_search(
output_path: str | None,
memory_config: "MemoryConfig",
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
):
@@ -842,7 +607,7 @@ async def run_hybrid_search(
if search_type in ["keyword", "hybrid"]:
# Keyword-based search
logger.info("[PERF] Starting keyword search...")
logger.info("Starting keyword search...")
keyword_start = time.time()
keyword_task = asyncio.create_task(
search_graph(
@@ -856,7 +621,7 @@ async def run_hybrid_search(
if search_type in ["embedding", "hybrid"]:
# Embedding-based search
logger.info("[PERF] Starting embedding search...")
logger.info("Starting embedding search...")
embedding_start = time.time()
# Load the embedder configuration from the database (by ID) and build a RedBearModelConfig
@@ -872,13 +637,13 @@ async def run_hybrid_search(
type="llm"
)
config_load_time = time.time() - config_load_start
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
logger.info(f"Config loading took {config_load_time:.4f}s")
# Init embedder
embedder_init_start = time.time()
embedder = OpenAIEmbedderClient(model_config=rb_config)
embedder_init_time = time.time() - embedder_init_start
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
logger.info(f"Embedder init took {embedder_init_time:.4f}s")
embedding_task = asyncio.create_task(
search_graph_by_embedding(
@@ -895,7 +660,7 @@ async def run_hybrid_search(
keyword_results = await keyword_task
keyword_latency = time.time() - keyword_start
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
logger.info(f"Keyword search completed in {keyword_latency:.4f}s")
if search_type == "keyword":
results = keyword_results
else:
@@ -905,7 +670,7 @@ async def run_hybrid_search(
embedding_results = await embedding_task
embedding_latency = time.time() - embedding_start
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
logger.info(f"Embedding search completed in {embedding_latency:.4f}s")
if search_type == "embedding":
results = embedding_results
else:
@@ -920,37 +685,33 @@ async def run_hybrid_search(
"search_timestamp": datetime.now().isoformat()
}
# Apply two-stage reranking with ACTR activation calculation
# Apply reranking (optionally with forgetting curve)
rerank_start = time.time()
logger.info("[PERF] Using two-stage reranking with ACTR activation")
# Load the forgetting engine configuration
config_start = time.time()
try:
pc = get_pipeline_config(memory_config)
forgetting_cfg = pc.forgetting_engine
except Exception as e:
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
forgetting_cfg = ForgettingEngineConfig()
config_time = time.time() - config_start
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
# Always use activation-based reranking (two stages: retrieval + ACTR calculation)
rerank_compute_start = time.time()
reranked_results = rerank_with_activation(
keyword_results=keyword_results,
embedding_results=embedding_results,
alpha=rerank_alpha,
limit=limit,
forgetting_config=forgetting_cfg,
activation_boost_factor=activation_boost_factor,
)
rerank_compute_time = time.time() - rerank_compute_start
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
if use_forgetting_rerank:
# Load forgetting parameters from pipeline config
try:
pc = get_pipeline_config(memory_config)
forgetting_cfg = pc.forgetting_engine
except Exception as e:
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
forgetting_cfg = ForgettingEngineConfig()
reranked_results = rerank_with_forgetting_curve(
keyword_results=keyword_results,
embedding_results=embedding_results,
alpha=rerank_alpha,
limit=limit,
forgetting_config=forgetting_cfg,
)
else:
reranked_results = rerank_hybrid_results(
keyword_results=keyword_results,
embedding_results=embedding_results,
alpha=rerank_alpha, # Configurable weight for BM25 vs embedding
limit=limit
)
rerank_latency = time.time() - rerank_start
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
logger.info(f"Reranking completed in {rerank_latency:.4f}s")
# Optional: apply reranker placeholder if enabled via config
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
@@ -976,7 +737,6 @@ async def run_hybrid_search(
"search_query": query_text,
"search_timestamp": datetime.now().isoformat(),
"reranking_alpha": rerank_alpha,
"activation_boost_factor": activation_boost_factor,
"forgetting_rerank": use_forgetting_rerank,
"llm_rerank": llm_rerank_applied,
}
@@ -991,10 +751,8 @@ async def run_hybrid_search(
else:
results["latency_metrics"] = latency_metrics
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
logger.info(f"[PERF] =========================================")
logger.info(f"Total search completed in {total_latency:.4f}s")
logger.info(f"Latency breakdown: {latency_metrics}")
# Sanitize results: drop large/unused fields
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs


@@ -42,6 +42,7 @@ from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation,
embedding_generation_all,
generate_entity_embeddings_from_triplets,
)
@@ -178,7 +179,7 @@ class ExtractionOrchestrator:
for dialog in dialog_data_list:
for chunk in dialog.chunks:
all_statements_list.extend(chunk.statements)
len(all_statements_list)
total_statements = len(all_statements_list)
# Step 2: run triplet extraction, temporal extraction, emotion extraction, and base embedding generation in parallel
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
@@ -200,9 +201,9 @@ class ExtractionOrchestrator:
all_entities_list.extend(triplet_info.entities)
all_triplets_list.extend(triplet_info.triplets)
len(all_entities_list)
len(all_triplets_list)
sum(len(temporal_map) for temporal_map in temporal_maps)
total_entities = len(all_entities_list)
total_triplets = len(all_triplets_list)
total_temporal = sum(len(temporal_map) for temporal_map in temporal_maps)
# Step 3: generate entity embeddings (depends on the triplet extraction results)
logger.info("步骤 3/6: 生成实体嵌入")
@@ -384,7 +385,7 @@ class ExtractionOrchestrator:
# Tracks the number of completed statements
completed_statements = 0
len(all_statements)
total_statements = len(all_statements)
# Process all statements in parallel, globally
async def extract_for_statement(stmt_data, stmt_index):
@@ -496,7 +497,7 @@ class ExtractionOrchestrator:
# Tracks the number of completed temporal extractions
completed_temporal = 0
len(all_statements)
total_temporal_statements = len(all_statements)
# Process all statements in parallel, globally
async def extract_for_statement(stmt_data, stmt_index):
@@ -1081,12 +1082,10 @@ class ExtractionOrchestrator:
statement_id=statement.id, # Required statement_id field
entity_type=getattr(entity, 'type', 'unknown'), # Use type rather than entity_type
description=getattr(entity, 'description', ''), # Required description field
example=getattr(entity, 'example', ''), # New: pass through the example field
fact_summary=getattr(entity, 'fact_summary', ''), # Required fact_summary field
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # Required connect_strength field
aliases=getattr(entity, 'aliases', []) or [], # Pass through the aliases obtained during triplet extraction
name_embedding=getattr(entity, 'name_embedding', None),
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # New: pass through the semantic-memory flag
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,


@@ -1,7 +1,6 @@
import asyncio
import json
from datetime import datetime
from typing import List, Optional, Tuple
from typing import List, Optional
from uuid import uuid4
from app.core.logging_config import get_memory_logger
@@ -29,118 +28,6 @@ class MemorySummaryResponse(RobustLLMResponse):
)
async def generate_title_and_type_for_summary(
content: str,
llm_client
) -> Tuple[str, str]:
"""
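Generate a title and type for a MemorySummary
Call this method when creating a MemorySummary node to generate its title and type
Args:
content: the text content of the summary
llm_client: LLM client instance
Returns:
A (title, type) tuple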
"""
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
# Define the set of valid types
VALID_TYPES = {
"conversation", # Conversation
"project_work", # Project/work
"learning", # Learning
"decision", # Decision
"important_event" # Important event
}
DEFAULT_TYPE = "conversation" # Default type
try:
if not content:
logger.warning("content为空无法生成标题和类型")
return ("空内容", DEFAULT_TYPE)
# 1. Render the Jinja2 prompt template
prompt = await render_episodic_title_and_type_prompt(content)
# 2. Call the LLM to generate the title and type
messages = [
{"role": "user", "content": prompt}
]
response = await llm_client.chat(messages=messages)
# 3. Parse the LLM response
content_response = response.content
if isinstance(content_response, list):
if len(content_response) > 0:
if isinstance(content_response[0], dict):
text = content_response[0].get('text', content_response[0].get('content', str(content_response[0])))
full_response = str(text)
else:
full_response = str(content_response[0])
else:
full_response = ""
elif isinstance(content_response, dict):
full_response = str(content_response.get('text', content_response.get('content', str(content_response))))
else:
full_response = str(content_response) if content_response is not None else ""
# 4. Parse the JSON response
try:
# Try to extract JSON from the response
# Strip any markdown code-fence markers
json_str = full_response.strip()
if json_str.startswith("```json"):
json_str = json_str[7:]
if json_str.startswith("```"):
json_str = json_str[3:]
if json_str.endswith("```"):
json_str = json_str[:-3]
json_str = json_str.strip()
result_data = json.loads(json_str)
title = result_data.get("title", "未知标题")
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
# 5. Validate and normalize the type
# Lowercase the type and strip whitespace
episodic_type_normalized = str(episodic_type_raw).lower().strip()
# Check whether it is in the set of valid types
if episodic_type_normalized in VALID_TYPES:
episodic_type = episodic_type_normalized
else:
# Try to map common Chinese type names to English
type_mapping = {
"对话": "conversation",
"项目": "project_work",
"工作": "project_work",
"项目/工作": "project_work",
"学习": "learning",
"决策": "decision",
"重要事件": "important_event",
"事件": "important_event"
}
episodic_type = type_mapping.get(episodic_type_raw, DEFAULT_TYPE)
logger.warning(
f"LLM返回的类型 '{episodic_type_raw}' 不在有效集合中,"
f"已归一化为 '{episodic_type}'"
)
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
return (title, episodic_type)
except json.JSONDecodeError:
logger.error(f"无法解析LLM响应为JSON: {full_response}")
return ("解析失败", DEFAULT_TYPE)
except Exception as e:
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
return ("错误", DEFAULT_TYPE)
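The fence-stripping and type-validation steps of the removed function can be exercised without an LLM client. The sketch below is an illustration only: `parse_title_and_type` is a hypothetical helper, the English fallback titles are placeholders rather than the strings the original returned, and the Chinese-to-English type mapping is omitted.

```python
import json
from typing import Tuple

VALID_TYPES = {"conversation", "project_work", "learning", "decision", "important_event"}
DEFAULT_TYPE = "conversation"
FENCE = "`" * 3  # built indirectly so the example can itself sit inside a fenced block


def parse_title_and_type(raw: str) -> Tuple[str, str]:
    """Strip an optional markdown code fence, parse JSON, and validate the type."""
    text = raw.strip()
    if text.startswith(FENCE + "json"):
        text = text[len(FENCE) + len("json"):]
    elif text.startswith(FENCE):
        text = text[len(FENCE):]
    if text.endswith(FENCE):
        text = text[: -len(FENCE)]
    try:
        data = json.loads(text.strip())
    except json.JSONDecodeError:
        return ("parse failed", DEFAULT_TYPE)  # placeholder fallback title
    if not isinstance(data, dict):
        return ("parse failed", DEFAULT_TYPE)
    title = str(data.get("title", "untitled"))  # placeholder default title
    kind = str(data.get("type", DEFAULT_TYPE)).lower().strip()
    # Unknown types fall back to the default instead of raising.
    return (title, kind if kind in VALID_TYPES else DEFAULT_TYPE)
```

Falling back to a default type keeps node creation from failing on a malformed model response, at the cost of occasionally mislabeling a summary as a conversation.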
async def _process_chunk_summary(
dialog: DialogData,
chunk,
@@ -172,27 +59,13 @@ async def _process_chunk_summary(
)
summary_text = structured.summary.strip()
# Generate title and type for the summary
title = None
episodic_type = None
try:
title, episodic_type = await generate_title_and_type_for_summary(
content=summary_text,
llm_client=llm_client
)
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
except Exception as e:
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
# Continue without title and type
# Embed the summary
embedding = (await embedder.response([summary_text]))[0]
# Build node per chunk
# Note: title is stored in the 'name' field, type is stored in 'memory_type' field
node = MemorySummaryNode(
id=uuid4().hex,
name=title if title else f"MemorySummaryChunk_{chunk.id}",
name=f"MemorySummaryChunk_{chunk.id}",
group_id=dialog.group_id,
user_id=dialog.user_id,
apply_id=dialog.apply_id,
@@ -202,7 +75,6 @@ async def _process_chunk_summary(
dialog_id=dialog.id,
chunk_ids=[chunk.id],
content=summary_text,
memory_type=episodic_type,
summary_embedding=embedding,
metadata={"ref_id": dialog.ref_id},
config_id=dialog.config_id, # Add config_id


@@ -1,40 +1,8 @@
"""遗忘引擎模块
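This module implements the forgetting mechanism for memories, based on an improved Ebbinghaus forgetting curve and the ACT-R cognitive architecture theory
This module implements the forgetting mechanism for memories, based on an improved Ebbinghaus forgetting curve.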
"""
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
ACTRCalculator,
calculate_activation,
generate_forgetting_curve
)
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
AccessHistoryManager,
ConsistencyCheckResult
)
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import (
ForgettingStrategy
)
from app.core.memory.storage_services.forgetting_engine.forgetting_scheduler import (
ForgettingScheduler
)
from app.core.memory.storage_services.forgetting_engine.config_utils import (
calculate_forgetting_rate,
load_actr_config_from_db,
create_actr_calculator_from_config
)
__all__ = [
"ForgettingEngine",
"ACTRCalculator",
"calculate_activation",
"generate_forgetting_curve",
"AccessHistoryManager",
"ConsistencyCheckResult",
"ForgettingStrategy",
"ForgettingScheduler",
"calculate_forgetting_rate",
"load_actr_config_from_db",
"create_actr_calculator_from_config"
]
__all__ = ["ForgettingEngine"]


@@ -1,732 +0,0 @@
"""
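Access history manager module
This module implements tracking, updating, and consistency guarantees for access history.
It atomically updates every activation-related field whenever a knowledge node is accessed.
Classes:
AccessHistoryManager: access history manager providing concurrency-safe access recording and consistency checks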
"""
import asyncio
import logging
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
ACTRCalculator,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
class ConsistencyCheckResult(Enum):
"""一致性检查结果枚举"""
CONSISTENT = "consistent" # Data is consistent
INCONSISTENT_HISTORY_TIME = "inconsistent_history_time" # access_history[-1] != last_access_time
INCONSISTENT_HISTORY_COUNT = "inconsistent_history_count" # len(access_history) != access_count
MISSING_ACTIVATION = "missing_activation" # Access history exists but there is no activation value
INVALID_ACTIVATION_RANGE = "invalid_activation_range" # Activation value is outside the valid range
class AccessHistoryManager:
"""
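Access history manager
Tracks the access history of knowledge nodes and atomically updates all related fields on access:
- activation_value: activation value
- access_history: array of access timestamps
- last_access_time: time of the last access
- access_count: number of accesses
Features:
- Atomic updates: uses Neo4j transactions so that all fields are updated together or rolled back
- Concurrency safety: uses optimistic locking to prevent concurrent conflicts
- Consistency guarantees: provides consistency checks and automatic repair
- Smart trimming: automatically trims access histories that grow too long
Attributes:
connector: Neo4j connector instance
actr_calculator: ACT-R activation calculator instance
max_retries: maximum number of retries on a concurrency conflict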
"""
def __init__(
self,
connector: Neo4jConnector,
actr_calculator: ACTRCalculator,
max_retries: int = 3
):
"""
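Initialize the access history manager
Args:
connector: Neo4j connector instance
actr_calculator: ACT-R activation calculator instance
max_retries: maximum number of retries on a concurrency conflict (default: 3)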
"""
self.connector = connector
self.actr_calculator = actr_calculator
self.max_retries = max_retries
async def record_access(
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None,
current_time: Optional[datetime] = None
) -> Dict[str, Any]:
"""
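Record a node access and atomically update all related fields
This is the core method. It implements:
1. First access: initialize access_history and compute the initial activation value
2. Subsequent accesses: append to the access history and recompute the activation value
3. History trimming: automatically trim the history when it grows too long
4. Atomicity: all fields are updated in a single transaction
5. Concurrency safety: uses an optimistic-lock retry mechanism
Args:
node_id: node ID
node_label: node label (Statement, ExtractedEntity, MemorySummary)
group_id: group ID (optional, used for filtering)
current_time: current time (optional, defaults to the system time)
Returns:
Dict[str, Any]: the updated node data, containing:
- id: node ID
- activation_value: the updated activation value
- access_history: the updated access history
- last_access_time: time of the last access
- access_count: number of accesses
- importance_score: importance score
Raises:
ValueError: if the node does not exist or the node label is invalid
RuntimeError: if the operation still fails after all retries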
"""
if current_time is None:
current_time = datetime.now()
current_time_iso = current_time.isoformat()
# Validate the node label
valid_labels = ["Statement", "ExtractedEntity", "MemorySummary"]
if node_label not in valid_labels:
raise ValueError(
f"Invalid node_label: {node_label}. Must be one of {valid_labels}"
)
# Use an optimistic-lock retry mechanism to handle concurrent conflicts
for attempt in range(self.max_retries):
try:
# Step 1: read the current node state
node_data = await self._fetch_node(node_id, node_label, group_id)
if not node_data:
raise ValueError(
f"Node not found: {node_label} with id={node_id}"
)
# Step 2: compute the new access history and activation value
update_data = await self._calculate_update(
node_data=node_data,
current_time=current_time,
current_time_iso=current_time_iso
)
# Step 3: atomically update the node (using a transaction)
updated_node = await self._atomic_update(
node_id=node_id,
node_label=node_label,
update_data=update_data,
group_id=group_id
)
logger.info(
f"成功记录访问: {node_label}[{node_id}], "
f"activation={update_data['activation_value']:.4f}, "
f"access_count={update_data['access_count']}"
)
return updated_node
except Exception as e:
if attempt < self.max_retries - 1:
logger.warning(
f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}: {str(e)}"
)
continue
else:
logger.error(
f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], "
f"错误: {str(e)}"
)
raise RuntimeError(
f"Failed to record access after {self.max_retries} attempts: {str(e)}"
)
async def record_batch_access(
self,
node_ids: List[str],
node_label: str,
group_id: Optional[str] = None,
current_time: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""
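Record accesses for multiple nodes in a batch
Updates the access history of several nodes in one batch to improve performance.
Each node is updated independently; a failed node does not affect the others.
Args:
node_ids: list of node IDs
node_label: node label (all nodes must be of the same type)
group_id: group ID (optional)
current_time: current time (optional)
Returns:
List[Dict[str, Any]]: the nodes that were updated successfully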
"""
import time
batch_start = time.time()
if current_time is None:
current_time = datetime.now()
# PERFORMANCE FIX: Process all nodes in parallel instead of sequentially
tasks = []
for node_id in node_ids:
task = self.record_access(
node_id=node_id,
node_label=node_label,
group_id=group_id,
current_time=current_time
)
tasks.append(task)
# Execute all tasks in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True)
# Collect successful results and count failures
results = []
failed_count = 0
for node_id, result in zip(node_ids, task_results):
if isinstance(result, Exception):
failed_count += 1
logger.warning(
f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(result)}"
)
else:
results.append(result)
batch_duration = time.time() - batch_start
logger.info(
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
f"失败 {failed_count}, 耗时 {batch_duration:.4f}s"
)
return results
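The batch method relies on `asyncio.gather(..., return_exceptions=True)` so that one failing node does not cancel the rest. A minimal sketch of that pattern follows; `run_batch` and the `worker` callable are hypothetical names used only for illustration.

```python
import asyncio
from typing import Any, Awaitable, Callable, List, Tuple


async def run_batch(
    ids: List[str],
    worker: Callable[[str], Awaitable[Any]],
) -> Tuple[List[Any], List[str]]:
    """Run worker(id) concurrently for every id; collect results and failed ids."""
    # return_exceptions=True makes gather return raised exceptions as values,
    # in the same order as the inputs, instead of propagating the first one.
    outcomes = await asyncio.gather(*(worker(i) for i in ids), return_exceptions=True)
    succeeded: List[Any] = []
    failed: List[str] = []
    for node_id, outcome in zip(ids, outcomes):
        if isinstance(outcome, Exception):
            failed.append(node_id)
        else:
            succeeded.append(outcome)
    return succeeded, failed
```

Note that the coroutines run concurrently on one event loop rather than in parallel threads, so the gain comes from overlapping I/O waits such as database round trips.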
async def check_consistency(
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
"""
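Check the consistency of a node's data
Verifies the following consistency rules:
1. access_history[-1] == last_access_time
2. len(access_history) == access_count
3. If there is an access history, there must be an activation value
4. The activation value must lie in the valid range [offset, 1.0]
Args:
node_id: node ID
node_label: node label
group_id: group ID (optional)
Returns:
Tuple[ConsistencyCheckResult, Optional[str]]:
- the consistency check result enum
- an error description (if inconsistent)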
"""
node_data = await self._fetch_node(node_id, node_label, group_id)
if not node_data:
return ConsistencyCheckResult.CONSISTENT, None
access_history = node_data.get('access_history') or []
last_access_time = node_data.get('last_access_time')
access_count = node_data.get('access_count', 0)
activation_value = node_data.get('activation_value')
# Check 1: access_history[-1] == last_access_time
if access_history and last_access_time:
if access_history[-1] != last_access_time:
return (
ConsistencyCheckResult.INCONSISTENT_HISTORY_TIME,
f"access_history[-1]={access_history[-1]} != "
f"last_access_time={last_access_time}"
)
# Check 2: len(access_history) == access_count
if len(access_history) != access_count:
return (
ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT,
f"len(access_history)={len(access_history)} != "
f"access_count={access_count}"
)
# Check 3: an access history requires an activation value
if access_history and activation_value is None:
return (
ConsistencyCheckResult.MISSING_ACTIVATION,
"Node has access_history but activation_value is None"
)
# Check 4: activation value range
if activation_value is not None:
offset = self.actr_calculator.offset
if not (offset <= activation_value <= 1.0):
return (
ConsistencyCheckResult.INVALID_ACTIVATION_RANGE,
f"activation_value={activation_value} out of range "
f"[{offset}, 1.0]"
)
return ConsistencyCheckResult.CONSISTENT, None
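The four rules can be tested without a database by applying them to a plain dictionary. The following is a sketch under assumptions: `Check` and `check_node` are hypothetical names mirroring the enum and method above, and `offset` is passed in directly instead of being read from an ACT-R calculator.

```python
from enum import Enum
from typing import Any, Dict, Optional, Tuple


class Check(Enum):
    CONSISTENT = "consistent"
    INCONSISTENT_HISTORY_TIME = "inconsistent_history_time"
    INCONSISTENT_HISTORY_COUNT = "inconsistent_history_count"
    MISSING_ACTIVATION = "missing_activation"
    INVALID_ACTIVATION_RANGE = "invalid_activation_range"


def check_node(node: Dict[str, Any], offset: float = 0.0) -> Tuple[Check, Optional[str]]:
    """Apply the four consistency rules in order and report the first violation."""
    history = node.get("access_history") or []
    last = node.get("last_access_time")
    count = node.get("access_count", 0)
    activation = node.get("activation_value")
    # Rule 1: the newest history entry must equal last_access_time.
    if history and last and history[-1] != last:
        return Check.INCONSISTENT_HISTORY_TIME, f"{history[-1]} != {last}"
    # Rule 2: the history length must equal access_count.
    if len(history) != count:
        return Check.INCONSISTENT_HISTORY_COUNT, f"{len(history)} != {count}"
    # Rule 3: a node with history must have an activation value.
    if history and activation is None:
        return Check.MISSING_ACTIVATION, "history present but activation_value is None"
    # Rule 4: the activation value must lie in [offset, 1.0].
    if activation is not None and not (offset <= activation <= 1.0):
        return Check.INVALID_ACTIVATION_RANGE, f"{activation} not in [{offset}, 1.0]"
    return Check.CONSISTENT, None
```

Because the rules are checked in order, a node that violates several of them is reported only under the first one it fails.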
async def check_batch_consistency(
self,
node_label: str,
group_id: Optional[str] = None,
limit: int = 1000
) -> Dict[str, Any]:
"""
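Check the consistency of multiple nodes in a batch
Args:
node_label: node label
group_id: group ID (optional)
limit: maximum number of nodes to check
Returns:
Dict[str, Any]: consistency check report, containing:
- total_checked: total number of nodes checked
- consistent_count: number of consistent nodes
- inconsistent_count: number of inconsistent nodes
- inconsistencies: list of details for the inconsistent nodes
- consistency_rate: consistency rate (0-1)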
"""
# Query all relevant nodes
query = f"""
MATCH (n:{node_label})
WHERE n.access_history IS NOT NULL
"""
if group_id:
query += " AND n.group_id = $group_id"
query += """
RETURN n.id as id
LIMIT $limit
"""
params = {"limit": limit}
if group_id:
params["group_id"] = group_id
results = await self.connector.execute_query(query, **params)
node_ids = [r['id'] for r in results]
# Check each node
inconsistencies = []
consistent_count = 0
for node_id in node_ids:
result, message = await self.check_consistency(
node_id=node_id,
node_label=node_label,
group_id=group_id
)
if result == ConsistencyCheckResult.CONSISTENT:
consistent_count += 1
else:
inconsistencies.append({
'node_id': node_id,
'result': result.value,
'message': message
})
total_checked = len(node_ids)
inconsistent_count = len(inconsistencies)
consistency_rate = consistent_count / total_checked if total_checked > 0 else 1.0
report = {
'total_checked': total_checked,
'consistent_count': consistent_count,
'inconsistent_count': inconsistent_count,
'inconsistencies': inconsistencies,
'consistency_rate': consistency_rate
}
logger.info(
f"一致性检查完成: {node_label}, "
f"一致率={consistency_rate:.2%}, "
f"不一致节点={inconsistent_count}/{total_checked}"
)
return report
async def repair_inconsistency(
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
) -> bool:
"""
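Automatically repair inconsistent data on a node
Repair strategy:
1. If access_history[-1] != last_access_time: use access_history[-1]
2. If len(access_history) != access_count: use len(access_history)
3. If there is a history but no activation value: recompute the activation value
4. If the activation value is out of range: recompute the activation value
Args:
node_id: node ID
node_label: node label
group_id: group ID (optional)
Returns:
bool: True if the repair succeeded, False otherwise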
"""
try:
# Check consistency
result, message = await self.check_consistency(
node_id=node_id,
node_label=node_label,
group_id=group_id
)
if result == ConsistencyCheckResult.CONSISTENT:
logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]")
return True
# Fetch the node data
node_data = await self._fetch_node(node_id, node_label, group_id)
if not node_data:
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
return False
access_history = node_data.get('access_history') or []
importance_score = node_data.get('importance_score', 0.5)
# Prepare the repair data
repair_data = {}
# Repair last_access_time
if access_history:
repair_data['last_access_time'] = access_history[-1]
# Repair access_count
repair_data['access_count'] = len(access_history)
# Repair activation_value
if access_history:
current_time = datetime.now()
last_access_dt = datetime.fromisoformat(access_history[-1])
access_history_dt = [
datetime.fromisoformat(ts) for ts in access_history
]
activation_value = self.actr_calculator.calculate_memory_activation(
access_history=access_history_dt,
current_time=current_time,
last_access_time=last_access_dt,
importance_score=importance_score
)
repair_data['activation_value'] = activation_value
# Apply the repair
query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
query += " WHERE n.group_id = $group_id"
query += """
SET n += $repair_data
RETURN n
"""
params = {
'node_id': node_id,
'repair_data': repair_data
}
if group_id:
params['group_id'] = group_id
await self.connector.execute_query(query, **params)
logger.info(
f"成功修复节点不一致: {node_label}[{node_id}], "
f"问题类型={result.value}"
)
return True
except Exception as e:
logger.error(
f"修复节点失败: {node_label}[{node_id}], 错误: {str(e)}"
)
return False
# ==================== Private helper methods ====================
async def _fetch_node(
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""
Fetch node data
Args:
node_id: node ID
node_label: node label
group_id: group ID (optional)
Returns:
Optional[Dict[str, Any]]: the node data, or None if the node does not exist
"""
query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
query += " WHERE n.group_id = $group_id"
query += """
RETURN n.id as id,
n.importance_score as importance_score,
n.activation_value as activation_value,
n.access_history as access_history,
n.last_access_time as last_access_time,
n.access_count as access_count
"""
params = {'node_id': node_id}
if group_id:
params['group_id'] = group_id
results = await self.connector.execute_query(query, **params)
if results:
return results[0]
return None
async def _calculate_update(
self,
node_data: Dict[str, Any],
current_time: datetime,
current_time_iso: str
) -> Dict[str, Any]:
"""
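Compute the update data
Args:
node_data: current node data
current_time: current time (datetime object)
current_time_iso: current time (ISO-format string)
Returns:
Dict[str, Any]: the update data, containing every field that needs updating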
"""
access_history = node_data.get('access_history') or []
# Handle None importance_score - default to 0.5
importance_score = node_data.get('importance_score')
if importance_score is None:
importance_score = 0.5
# Append the new access time
new_access_history = access_history + [current_time_iso]
# 修剪访问历史(如果过长)
access_history_dt = [
datetime.fromisoformat(ts) for ts in new_access_history
]
trimmed_history_dt = self.actr_calculator.trim_access_history(
access_history=access_history_dt,
current_time=current_time
)
trimmed_history = [ts.isoformat() for ts in trimmed_history_dt]
# Compute the new activation value
activation_value = self.actr_calculator.calculate_memory_activation(
access_history=trimmed_history_dt,
current_time=current_time,
last_access_time=current_time, # the latest access is the current time
importance_score=importance_score
)
# Return every field that needs to be updated
return {
'activation_value': activation_value,
'access_history': trimmed_history,
'last_access_time': current_time_iso,
'access_count': len(trimmed_history)
}
async def _atomic_update(
self,
node_id: str,
node_label: str,
update_data: Dict[str, Any],
group_id: Optional[str] = None
) -> Dict[str, Any]:
"""
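Atomically update a node (using optimistic locking).
Uses a Neo4j transaction and a version number so that all fields are updated together or rolled back.
Implements optimistic locking to prevent concurrent conflicts.
Args:
node_id: Node ID
node_label: Node label
update_data: Update payload
group_id: Group ID (optional)
Returns:
Dict[str, Any]: Node data after the update
Raises:
RuntimeError: If the update fails or a version conflict occurs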
"""
# Define the transaction function
async def update_transaction(tx, node_id, node_label, update_data, group_id):
# Step 1: read the current node and get its version number
read_query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
read_query += " WHERE n.group_id = $group_id"
read_query += """
RETURN n.id as id,
n.version as version,
n.activation_value as activation_value,
n.access_history as access_history,
n.last_access_time as last_access_time,
n.access_count as access_count,
n.importance_score as importance_score
"""
read_params = {'node_id': node_id}
if group_id:
read_params['group_id'] = group_id
read_result = await tx.run(read_query, **read_params)
current_node = await read_result.single()
if not current_node:
raise RuntimeError(f"Node not found: {node_label}[{node_id}]")
# Get the current version number (0 if absent)
current_version = current_node.get('version', 0) or 0
new_version = current_version + 1
# Step 2: update the node using optimistic locking
# Build the full query according to the node type
content_field_map = {
'Statement': 'n.statement as statement',
'MemorySummary': 'n.content as content',
'ExtractedEntity': 'null as content_placeholder' # placeholder, filtered out later
}
# Check the node type explicitly; unsupported types raise an error
if node_label not in content_field_map:
raise ValueError(
f"Unsupported node_label: {node_label}. "
f"Supported labels are: {list(content_field_map.keys())}"
)
content_field = content_field_map[node_label]
# Build the WHERE clause
where_conditions = []
if group_id:
where_conditions.append("n.group_id = $group_id")
# Add the version check
if current_version > 0:
where_conditions.append("n.version = $current_version")
else:
where_conditions.append("(n.version IS NULL OR n.version = 0)")
where_clause = " AND ".join(where_conditions) if where_conditions else "true"
# Build the full update query
update_query = f"""
MATCH (n:{node_label} {{id: $node_id}})
WHERE {where_clause}
SET n.activation_value = $activation_value,
n.access_history = $access_history,
n.last_access_time = $last_access_time,
n.access_count = $access_count,
n.version = $new_version
RETURN n.id as id,
n.activation_value as activation_value,
n.access_history as access_history,
n.last_access_time as last_access_time,
n.access_count as access_count,
n.importance_score as importance_score,
n.version as version,
{content_field}
"""
update_params = {
'node_id': node_id,
'current_version': current_version,
'new_version': new_version,
'activation_value': update_data['activation_value'],
'access_history': update_data['access_history'],
'last_access_time': update_data['last_access_time'],
'access_count': update_data['access_count']
}
if group_id:
update_params['group_id'] = group_id
update_result = await tx.run(update_query, **update_params)
updated_node = await update_result.single()
if not updated_node:
raise RuntimeError(
f"Version conflict detected for {node_label}[{node_id}]. "
f"Expected version {current_version}, but node was modified by another transaction."
)
# Convert to a dict and drop the placeholder field
result_dict = dict(updated_node)
result_dict.pop('content_placeholder', None)
return result_dict
# Run the transaction
try:
result = await self.connector.execute_write_transaction(
update_transaction,
node_id=node_id,
node_label=node_label,
update_data=update_data,
group_id=group_id
)
return result
except Exception as e:
logger.error(
f"Atomic update failed: {node_label}[{node_id}], error: {str(e)}"
)
raise RuntimeError(
f"Failed to atomically update node: {str(e)}"
) from e
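
The optimistic-locking pattern used by `_atomic_update` (read the version, write only if the version is unchanged, then bump it) can be illustrated without a database. The sketch below is a minimal in-memory analogue, not the project's implementation: `InMemoryNodeStore`, `VersionConflict`, and `compare_and_set` are hypothetical names introduced only for this example, and a plain dict stands in for Neo4j.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


class VersionConflict(RuntimeError):
    """Raised when the stored version differs from the expected one."""


@dataclass
class InMemoryNodeStore:
    """Hypothetical stand-in for the graph database, keyed by node id."""
    nodes: Dict[str, Dict[str, Any]] = field(default_factory=dict)

    def read(self, node_id: str) -> Dict[str, Any]:
        # Return a copy so callers cannot mutate stored state directly.
        return dict(self.nodes[node_id])

    def compare_and_set(self, node_id: str, expected_version: int,
                        update: Dict[str, Any]) -> Dict[str, Any]:
        node = self.nodes[node_id]
        # A missing version is treated as 0, as in the version check above.
        stored_version = node.get("version") or 0
        if stored_version != expected_version:
            raise VersionConflict(
                f"expected version {expected_version}, found {stored_version}"
            )
        node.update(update)
        node["version"] = expected_version + 1
        return dict(node)


store = InMemoryNodeStore({"n1": {"activation_value": 0.4}})

# Two writers read the same snapshot (version 0).
snapshot_a = store.read("n1")
snapshot_b = store.read("n1")

# Writer A commits first and bumps the version to 1.
result_a = store.compare_and_set(
    "n1", snapshot_a.get("version") or 0, {"activation_value": 0.9}
)

# Writer B still expects version 0, so its write is rejected.
try:
    store.compare_and_set(
        "n1", snapshot_b.get("version") or 0, {"activation_value": 0.1}
    )
    conflict_detected = False
except VersionConflict:
    conflict_detected = True
```

In this sketch the rejected writer would normally re-read the node and retry; whether the real caller retries after the `RuntimeError` is not shown in this part of the diff.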

View File

@@ -1,359 +0,0 @@
"""
ACT-R Memory Activation Calculator
This module implements the unified Memory Activation model based on ACT-R
(Adaptive Control of Thought-Rational) cognitive architecture theory.
The calculator integrates BLA (Base-Level Activation) computation into the
Memory Activation formula, providing a single coherent model for memory strength
calculation that reflects both recency and frequency of access.
Formula: R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d)))
Where:
- R(i): Memory activation value (0 to 1)
- offset: Minimum retention rate (prevents complete forgetting)
- λ: Forgetting rate (lambda_time / lambda_mem)
- t: Time since last access
- I: Importance score (0 to 1)
- t_k: Time since k-th access
- d: Decay constant (typically 0.5)
Reference: Anderson, J. R. (2007). How Can the Human Mind Occur in the Physical Universe?
"""
import math
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
class ACTRCalculator:
"""
Unified ACT-R Memory Activation Calculator.
This calculator implements the Memory Activation model that combines
recency and frequency effects into a single activation value computation.
It replaces the separate BLA calculation with an integrated approach.
Attributes:
decay_constant: Decay parameter d (typically 0.5)
forgetting_rate: Lambda parameter λ controlling forgetting speed
offset: Minimum retention rate (baseline memory strength)
max_history_length: Maximum number of access records to keep
"""
def __init__(
self,
decay_constant: float = 0.5,
forgetting_rate: float = 0.3,
offset: float = 0.1,
max_history_length: int = 100
):
"""
Initialize the ACT-R calculator.
Args:
decay_constant: Decay parameter d (default 0.5)
forgetting_rate: Forgetting rate λ (default 0.3)
offset: Minimum retention rate (default 0.1)
max_history_length: Maximum access history length (default 100)
"""
self.decay_constant = decay_constant
self.forgetting_rate = forgetting_rate
self.offset = offset
self.max_history_length = max_history_length
def calculate_memory_activation(
self,
access_history: List[datetime],
current_time: datetime,
last_access_time: datetime,
importance_score: float = 0.5
) -> float:
"""
Calculate memory activation value using the unified Memory Activation formula.
This method computes R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d)))
The formula integrates:
- Recency effect: Recent accesses contribute more (via t)
- Frequency effect: Multiple accesses strengthen memory (via Σ)
- Importance weighting: Important memories decay slower (via I)
Args:
access_history: List of access timestamps (datetime objects)
current_time: Current time for calculation
last_access_time: Time of most recent access
importance_score: Importance weight (0 to 1, default 0.5)
Returns:
float: Memory activation value between offset and 1.0
Raises:
ValueError: If access_history is empty or importance_score is outside [0, 1]
"""
if not access_history:
raise ValueError("access_history cannot be empty")
if not (0.0 <= importance_score <= 1.0):
raise ValueError(f"importance_score must be between 0 and 1, got {importance_score}")
# Calculate time since last access (in days)
time_since_last = (current_time - last_access_time).total_seconds() / 86400.0
time_since_last = max(time_since_last, 0.0001) # Clamp to a small positive value
# Calculate BLA component: Σ(I·t_k^(-d))
bla_sum = 0.0
for access_time in access_history:
# Calculate time since this access (in days)
time_diff = (current_time - access_time).total_seconds() / 86400.0
time_diff = max(time_diff, 0.0001) # Avoid division by zero
# Add weighted power-law term: I * t_k^(-d)
bla_sum += importance_score * (time_diff ** (-self.decay_constant))
# Avoid division by zero in case of numerical issues
if bla_sum <= 0:
bla_sum = 0.0001
# Calculate Memory Activation: R(i) = offset + (1-offset) * exp(-λ*t / BLA)
exponent = -self.forgetting_rate * time_since_last / bla_sum
# Clamp exponent to avoid numerical overflow/underflow
exponent = max(min(exponent, 100), -100)
activation = self.offset + (1 - self.offset) * math.exp(exponent)
# Ensure activation is within valid range [offset, 1.0]
return max(self.offset, min(1.0, activation))
def trim_access_history(
self,
access_history: List[datetime],
current_time: datetime
) -> List[datetime]:
"""
Intelligently trim access history to prevent unbounded growth.
Strategy:
- Keep all records if under max_history_length
- If over limit, keep most recent 50% and sample from older records
- Preserves both recent accesses (high importance) and historical pattern
Args:
access_history: List of access timestamps (sorted or unsorted)
current_time: Current time for calculation
Returns:
List[datetime]: Trimmed access history
"""
if len(access_history) <= self.max_history_length:
return access_history
# Sort by time (most recent first)
sorted_history = sorted(access_history, reverse=True)
# Calculate split point (keep most recent 50%)
keep_recent_count = self.max_history_length // 2
# Keep most recent 50%
recent_records = sorted_history[:keep_recent_count]
# Sample from older records
older_records = sorted_history[keep_recent_count:]
sample_count = self.max_history_length - keep_recent_count
if len(older_records) <= sample_count:
# If older records fit, keep them all
sampled_older = older_records
else:
# Sample evenly from older records
step = len(older_records) / sample_count
sampled_older = [
older_records[int(i * step)]
for i in range(sample_count)
]
# Combine and return
trimmed_history = recent_records + sampled_older
return sorted(trimmed_history, reverse=True)
def get_forgetting_curve( # Predict activation to decide when to review; test different configurations to choose a suitable d
self,
initial_time: datetime,
importance_score: float = 0.5,
days: int = 60
) -> List[Dict[str, Any]]:
"""
Generate forgetting curve data for visualization.
This method simulates how memory activation decays over time
for a single initial access, useful for understanding and
visualizing the forgetting behavior.
Args:
initial_time: Time of initial memory creation/access
importance_score: Importance weight (0 to 1, default 0.5)
days: Number of days to simulate (default 60)
Returns:
List of dictionaries with keys:
- 'day': Day number (0 to days)
- 'activation': Memory activation value
- 'retention_rate': Same as activation (for compatibility)
"""
curve_data = []
access_history = [initial_time]
for day in range(days + 1):
current_time = initial_time + timedelta(days=day)
try:
activation = self.calculate_memory_activation(
access_history=access_history,
current_time=current_time,
last_access_time=initial_time,
importance_score=importance_score
)
except ValueError:
# Handle edge cases
activation = self.offset
curve_data.append({
'day': day,
'activation': activation,
'retention_rate': activation # Alias for compatibility
})
return curve_data
def calculate_forgetting_score(
self,
access_history: List[datetime],
current_time: datetime,
last_access_time: datetime,
importance_score: float = 0.5
) -> float:
"""
Calculate forgetting score (inverse of activation).
Forgetting score = 1 - activation value
Higher score means more likely to be forgotten.
Args:
access_history: List of access timestamps
current_time: Current time for calculation
last_access_time: Time of most recent access
importance_score: Importance weight (0 to 1, default 0.5)
Returns:
float: Forgetting score between 0 and (1 - offset)
"""
activation = self.calculate_memory_activation(
access_history=access_history,
current_time=current_time,
last_access_time=last_access_time,
importance_score=importance_score
)
return 1.0 - activation
def should_forget(
self,
access_history: List[datetime],
current_time: datetime,
last_access_time: datetime,
importance_score: float = 0.5,
threshold: float = 0.3
) -> bool:
"""
Determine if a memory should be forgotten based on activation threshold.
Args:
access_history: List of access timestamps
current_time: Current time for calculation
last_access_time: Time of most recent access
importance_score: Importance weight (0 to 1, default 0.5)
threshold: Activation threshold below which memory should be forgotten
Returns:
bool: True if activation < threshold (should forget), False otherwise
"""
activation = self.calculate_memory_activation(
access_history=access_history,
current_time=current_time,
last_access_time=last_access_time,
importance_score=importance_score
)
return activation < threshold
# Convenience functions for quick calculations
def calculate_activation(
access_history: List[datetime],
current_time: datetime,
last_access_time: datetime,
importance_score: float = 0.5,
decay_constant: float = 0.5,
forgetting_rate: float = 0.3,
offset: float = 0.1
) -> float:
"""
Quick function to calculate activation without creating a calculator instance.
Args:
access_history: List of access timestamps
current_time: Current time for calculation
last_access_time: Time of most recent access
importance_score: Importance weight (0 to 1, default 0.5)
decay_constant: Decay parameter d (default 0.5)
forgetting_rate: Forgetting rate λ (default 0.3)
offset: Minimum retention rate (default 0.1)
Returns:
float: Memory activation value between offset and 1.0
"""
calculator = ACTRCalculator(
decay_constant=decay_constant,
forgetting_rate=forgetting_rate,
offset=offset
)
return calculator.calculate_memory_activation(
access_history=access_history,
current_time=current_time,
last_access_time=last_access_time,
importance_score=importance_score
)
def generate_forgetting_curve(
initial_time: datetime,
importance_score: float = 0.5,
days: int = 60,
decay_constant: float = 0.5,
forgetting_rate: float = 0.3,
offset: float = 0.1
) -> List[Dict[str, Any]]:
"""
Quick function to generate forgetting curve data.
Args:
initial_time: Time of initial memory creation/access
importance_score: Importance weight (0 to 1, default 0.5)
days: Number of days to simulate (default 60)
decay_constant: Decay parameter d (default 0.5)
forgetting_rate: Forgetting rate λ (default 0.3)
offset: Minimum retention rate (default 0.1)
Returns:
List of dictionaries with forgetting curve data
"""
calculator = ACTRCalculator(
decay_constant=decay_constant,
forgetting_rate=forgetting_rate,
offset=offset
)
return calculator.get_forgetting_curve(
initial_time=initial_time,
importance_score=importance_score,
days=days
)
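
As a worked check of the formula R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d))), the following standalone sketch recomputes it for two small histories. It is an illustration rather than a copy of `ACTRCalculator`: the function name `memory_activation` is hypothetical, and it simplifies the interface by taking the most recent timestamp in the history as the last access time instead of accepting it as a separate argument.

```python
import math
from datetime import datetime, timedelta
from typing import List


def memory_activation(access_history: List[datetime], now: datetime,
                      importance: float = 0.5, decay: float = 0.5,
                      forgetting_rate: float = 0.3, offset: float = 0.1) -> float:
    """R = offset + (1 - offset) * exp(-lambda * t / sum(I * t_k ** -d)), times in days."""
    eps = 0.0001  # same floor as the module; avoids raising 0 to a negative power
    last_access = max(access_history)  # assumption: latest timestamp is the last access
    t = max((now - last_access).total_seconds() / 86400.0, eps)
    bla = sum(
        importance * max((now - ts).total_seconds() / 86400.0, eps) ** (-decay)
        for ts in access_history
    )
    return offset + (1 - offset) * math.exp(-forgetting_rate * t / bla)


t0 = datetime(2024, 1, 1)
now = t0 + timedelta(days=1)

# One access, one day ago: sum = 0.5 * 1 ** -0.5 = 0.5, exponent = -0.3 / 0.5 = -0.6
single = memory_activation([t0], now)

# Two accesses, four days and one day ago:
# sum = 0.5 * (4 ** -0.5 + 1 ** -0.5) = 0.75, exponent = -0.3 / 0.75 = -0.4
repeated = memory_activation([t0 - timedelta(days=3), t0], now)
```

With the default parameters the second history yields the larger value, which is the frequency effect the docstring describes: additional past accesses increase the denominator and slow the decay.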

View File

@@ -1,195 +0,0 @@
"""
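Forgetting engine configuration utilities.
This module provides helper functions that load configuration from the database and create forgetting engine components.
Functions:
calculate_forgetting_rate: Compute the forgetting rate (lambda_time / lambda_mem)
load_actr_config_from_db: Load ACT-R configuration parameters from the database
create_actr_calculator_from_config: Create an ACTRCalculator instance from configuration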
"""
import logging
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
from app.repositories.data_config_repository import DataConfigRepository
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
logger = logging.getLogger(__name__)
def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
"""
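Compute the forgetting rate.
Formula: forgetting_rate = lambda_time / lambda_mem
This combines two independent lambda parameters into a single forgetting rate
used in the ACT-R activation calculation.
Args:
lambda_time: Time decay parameter (0-1)
lambda_mem: Memory decay parameter (0-1)
Returns:
float: Forgetting rate
Raises:
ValueError: If lambda_mem is 0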
Examples:
>>> calculate_forgetting_rate(0.5, 0.5)
1.0
>>> calculate_forgetting_rate(0.3, 0.5)
0.6
"""
if lambda_mem == 0:
raise ValueError("lambda_mem must not be 0")
forgetting_rate = lambda_time / lambda_mem
logger.debug(
f"Computed forgetting rate: lambda_time={lambda_time}, "
f"lambda_mem={lambda_mem}, "
f"forgetting_rate={forgetting_rate:.4f}"
)
return forgetting_rate
def load_actr_config_from_db(
db: Session,
config_id: Optional[int] = None
) -> Dict[str, Any]:
"""
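Load ACT-R configuration parameters from the database.
Reads the parameters from the PostgreSQL data_config table
and computes derived parameters (such as forgetting_rate).
Args:
db: Database session
config_id: Configuration ID (required despite the Optional type; None raises ValueError)
Returns:
Dict[str, Any]: Configuration dictionary containing:
- decay_constant: Decay constant d
- lambda_time: Time decay parameter
- lambda_mem: Memory decay parameter
- forgetting_rate: Forgetting rate (computed as lambda_time / lambda_mem)
- offset: Offset
- max_history_length: Maximum access history length
- forgetting_threshold: Forgetting threshold
- min_days_since_access: Minimum number of days since last access
- enable_llm_summary: Whether to generate summaries with an LLM
- max_merge_batch_size: Maximum number of node pairs merged per run
- forgetting_interval_hours: Interval between forgetting cycles, in hours
Note: llm_id is not included in the returned configuration; forgetting_strategy reads it directly from the database when needed
Raises:
ValueError: If config_id is None or the specified config_id does not exist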
"""
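# config_id is required
if config_id is None:
logger.error("config_id not specified; cannot load configuration")
raise ValueError("config_id must not be empty; a valid configuration ID is required")
# Load the configuration from the database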
try:
repository = DataConfigRepository()
db_config = repository.get_by_id(db, config_id)
if db_config is None:
logger.error(f"Configuration not found: config_id={config_id}")
raise ValueError(f"Configuration not found: config_id={config_id}")
# Read the configuration parameters (trusting the database defaults)
lambda_time = db_config.lambda_time
lambda_mem = db_config.lambda_mem
decay_constant = db_config.decay_constant
offset = db_config.offset
max_history_length = db_config.max_history_length
forgetting_threshold = db_config.forgetting_threshold
min_days_since_access = db_config.min_days_since_access
enable_llm_summary = db_config.enable_llm_summary
max_merge_batch_size = db_config.max_merge_batch_size
forgetting_interval_hours = db_config.forgetting_interval_hours
# Compute forgetting_rate
forgetting_rate = calculate_forgetting_rate(lambda_time, lambda_mem)
config = {
'decay_constant': decay_constant,
'lambda_time': lambda_time,
'lambda_mem': lambda_mem,
'forgetting_rate': forgetting_rate,
'offset': offset,
'max_history_length': max_history_length,
'forgetting_threshold': forgetting_threshold,
'min_days_since_access': min_days_since_access,
'enable_llm_summary': enable_llm_summary,
'max_merge_batch_size': max_merge_batch_size,
'forgetting_interval_hours': forgetting_interval_hours
# Note: llm_id is not included in the configuration response; it is used internally only
}
logger.info(
f"Loaded ACT-R configuration: config_id={config_id}, "
f"forgetting_rate={forgetting_rate:.4f}"
)
return config
except Exception as e:
logger.error(f"Failed to load ACT-R configuration: config_id={config_id}, error: {str(e)}")
raise
def create_actr_calculator_from_config(
db: Session,
config_id: Optional[int] = None
) -> ACTRCalculator:
"""
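Create an ACTRCalculator instance from database configuration.
This is the recommended way to create an ACTRCalculator, since it uses the parameters stored in the database.
Args:
db: Database session
config_id: Configuration ID (required despite the Optional type; None raises ValueError)
Returns:
ACTRCalculator: Configured ACT-R calculator instance
Raises:
ValueError: If config_id is None or the specified config_id does not exist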
Examples:
>>> from sqlalchemy.orm import Session
>>> db = Session()
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
>>> # Use the calculator
>>> activation = calculator.calculate_memory_activation(...)
"""
# Load the configuration
config = load_actr_config_from_db(db, config_id)
# Create the calculator
calculator = ACTRCalculator(
decay_constant=config['decay_constant'],
forgetting_rate=config['forgetting_rate'],
offset=config['offset'],
max_history_length=config['max_history_length']
)
logger.info(
f"Created ACTRCalculator: config_id={config_id}, "
f"decay_constant={config['decay_constant']}, "
f"forgetting_rate={config['forgetting_rate']:.4f}, "
f"offset={config['offset']}"
)
return calculator
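
The mapping from stored parameters to calculator arguments can be exercised without a database session. The sketch below assumes a plain dataclass in place of a `data_config` row; `ForgettingParams` and `calculator_kwargs` are hypothetical names for illustration, and the default values are copied from the `ACTRCalculator.__init__` signature shown earlier in this diff.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ForgettingParams:
    """Hypothetical plain-data stand-in for a data_config row."""
    lambda_time: float
    lambda_mem: float
    decay_constant: float = 0.5
    offset: float = 0.1
    max_history_length: int = 100

    @property
    def forgetting_rate(self) -> float:
        # Derived parameter: lambda_time / lambda_mem, guarded against zero.
        if self.lambda_mem == 0:
            raise ValueError("lambda_mem must not be 0")
        return self.lambda_time / self.lambda_mem

    def calculator_kwargs(self) -> dict:
        """Keyword arguments in the shape ACTRCalculator.__init__ accepts."""
        return {
            "decay_constant": self.decay_constant,
            "forgetting_rate": self.forgetting_rate,
            "offset": self.offset,
            "max_history_length": self.max_history_length,
        }


params = ForgettingParams(lambda_time=0.3, lambda_mem=0.5)
kwargs = params.calculator_kwargs()

# A zero lambda_mem is rejected before any calculator is built.
try:
    ForgettingParams(lambda_time=0.3, lambda_mem=0.0).calculator_kwargs()
    zero_rejected = False
except ValueError:
    zero_rejected = True
```

Keeping the derived rate as a property means it cannot drift out of sync with the two lambda values, which is one possible design; the module above stores it in the returned dict instead.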

View File

@@ -1,351 +0,0 @@
"""
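Forgetting scheduler module.
This module schedules and manages forgetting cycles. It is responsible for:
1. Manually triggering a forgetting cycle
2. Processing forgettable nodes in batches (with a batch size limit)
3. Ordering by activation value (lowest activation first)
4. Progress tracking and logging
5. Generating the forgetting report
Note: periodic scheduling has moved to Celery Beat; see run_forgetting_cycle_task in app/tasks.py
Classes:
ForgettingScheduler: Forgetting scheduler that manages forgetting cycles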
"""
import logging
from typing import Dict, Any, Optional
from datetime import datetime
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
class ForgettingScheduler:
"""
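Forgetting scheduler.
Manages the execution of forgetting cycles, with batch processing, priority ordering, and progress tracking.
Core features:
1. Run a forgetting cycle: identify forgettable nodes and merge them in batches
2. Priority ordering: process the node pairs with the lowest activation first
3. Batch limit: cap the number of node pairs processed per run
4. Progress tracking: log once for every 10% completed
5. Forgetting report: generate a detailed execution report
Note: periodic scheduling has moved to a Celery Beat scheduled task
Attributes:
forgetting_strategy: Forgetting strategy executor instance
connector: Neo4j connector instance
is_running: Whether a forgetting cycle is currently running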
"""
def __init__(
self,
forgetting_strategy: ForgettingStrategy,
connector: Neo4jConnector
):
"""
Initialize the forgetting scheduler.
Args:
forgetting_strategy: Forgetting strategy executor instance
connector: Neo4j connector instance
"""
self.forgetting_strategy = forgetting_strategy
self.connector = connector
self.is_running = False
logger.info("Initialized forgetting scheduler")
async def run_forgetting_cycle(
self,
group_id: Optional[str] = None,
max_merge_batch_size: int = 100,
min_days_since_access: int = 30,
config_id: Optional[int] = None,
db = None
) -> Dict[str, Any]:
"""
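Run one complete forgetting cycle.
Args:
group_id: Group ID (optional), used to restrict the cycle to one group's nodes
max_merge_batch_size: Maximum number of node pairs merged per run (default 100)
min_days_since_access: Minimum number of days since last access (default 30)
config_id: Configuration ID (optional), used to look up llm_id
db: Database session (optional), used to look up llm_id
Returns:
Dict[str, Any]: Forgetting report containing:
- merged_count: Number of node pairs merged
- nodes_before: Total node count before forgetting
- nodes_after: Total node count after forgetting
- reduction_rate: Node reduction rate (0-1)
- duration_seconds: Execution time in seconds
- start_time: Start time (ISO format)
- end_time: End time (ISO format)
- failed_count: Number of failed merges
- success_rate: Success rate (0-1)
Raises:
RuntimeError: If a forgetting cycle is already running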
"""
# Check whether a forgetting cycle is already running
if self.is_running:
raise RuntimeError("A forgetting cycle is already running; wait for it to finish")
self.is_running = True
start_time = datetime.now()
start_time_iso = start_time.isoformat()
logger.info(
f"Starting forgetting cycle: group_id={group_id}, "
f"max_batch={max_merge_batch_size}, "
f"min_days={min_days_since_access}"
)
try:
# Step 1: count nodes before forgetting
nodes_before = await self._count_knowledge_nodes(group_id)
logger.info(f"Total nodes before forgetting: {nodes_before}")
# Step 2: identify forgettable node pairs
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
group_id=group_id,
min_days_since_access=min_days_since_access
)
total_forgettable = len(forgettable_pairs)
logger.info(f"Identified {total_forgettable} forgettable node pairs")
if total_forgettable == 0:
logger.info("No forgettable node pairs; forgetting cycle finished")
end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
report = {
'merged_count': 0,
'nodes_before': nodes_before,
'nodes_after': nodes_before,
'reduction_rate': 0.0,
'duration_seconds': duration,
'start_time': start_time_iso,
'end_time': end_time.isoformat(),
'failed_count': 0,
'success_rate': 1.0
}
return report
# Step 3: sort by activation value (lowest first)
# avg_activation is already computed and sorted in find_forgettable_nodes;
# sorting again here only guarantees ascending order
sorted_pairs = sorted(
forgettable_pairs,
key=lambda x: x['avg_activation']
)
# Step 4: apply the batch size limit
pairs_to_process = sorted_pairs[:max_merge_batch_size]
actual_batch_size = len(pairs_to_process)
logger.info(
f"Processing {actual_batch_size} node pairs "
f"(limit: {max_merge_batch_size})"
)
# Step 5: merge nodes in batches, logging progress every 10%
merged_count = 0
failed_count = 0
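skipped_count = 0 # number of node pairs skipped (node already handled)
progress_interval = max(1, actual_batch_size // 10) # log once every 10%
# Track processed node IDs to avoid handling a node twice
processed_statement_ids = set()
processed_entity_ids = set()
# Filter out duplicate node pairs up front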
unique_pairs = []
for pair in pairs_to_process:
statement_id = pair['statement_id']
entity_id = pair['entity_id']
# Skip the pair if either node is already marked as processed
if statement_id in processed_statement_ids or entity_id in processed_entity_ids:
skipped_count += 1
logger.debug(
f"Pre-filter: skipping duplicate node pair Statement[{statement_id}] + Entity[{entity_id}]"
)
continue
# Mark both nodes as processed
processed_statement_ids.add(statement_id)
processed_entity_ids.add(entity_id)
unique_pairs.append(pair)
logger.info(
f"Pre-filter done: {actual_batch_size} pairs originally, {len(unique_pairs)} after deduplication, "
f"{skipped_count} duplicate pairs skipped"
)
# Update the batch size actually processed
actual_batch_size = len(unique_pairs)
progress_interval = max(1, actual_batch_size // 10) # recompute the progress interval
for idx, pair in enumerate(unique_pairs, start=1):
statement_id = pair['statement_id']
entity_id = pair['entity_id']
try:
# Prepare node data
statement_node = {
'statement_id': statement_id,
'statement_text': pair['statement_text'],
'statement_activation': pair['statement_activation'],
'statement_importance': pair['statement_importance'],
'group_id': group_id
}
entity_node = {
'entity_id': entity_id,
'entity_name': pair['entity_name'],
'entity_type': pair['entity_type'],
'entity_activation': pair['entity_activation'],
'entity_importance': pair['entity_importance'],
'group_id': group_id
}
# Merge the nodes
await self.forgetting_strategy.merge_nodes_to_summary(
statement_node=statement_node,
entity_node=entity_node,
config_id=config_id,
db=db
)
merged_count += 1
# Progress tracking: log once every 10%
if actual_batch_size > 0 and (idx % progress_interval == 0 or idx == actual_batch_size):
progress_pct = (idx / actual_batch_size) * 100
logger.info(
f"Forgetting progress: {idx}/{actual_batch_size} "
f"({progress_pct:.1f}%), "
f"merged: {merged_count}, failed: {failed_count}"
)
except Exception as e:
failed_count += 1
# Check whether the error means the nodes no longer exist
if "nodes may not exist" in str(e):
logger.warning(
f"Nodes for pair ({idx}/{actual_batch_size}) do not exist (possibly deleted by another operation): "
f"Statement[{statement_id}] + Entity[{entity_id}]"
)
else:
logger.error(
f"Failed to merge node pair ({idx}/{actual_batch_size}): "
f"Statement[{statement_id}] + Entity[{entity_id}], "
f"error: {str(e)}"
)
# Continue with the remaining pairs
continue
# Step 6: count nodes after forgetting
nodes_after = await self._count_knowledge_nodes(group_id)
logger.info(f"Total nodes after forgetting: {nodes_after}")
# Step 7: build the forgetting report
end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
# Compute the node reduction rate
if nodes_before > 0:
reduction_rate = (nodes_before - nodes_after) / nodes_before
else:
reduction_rate = 0.0
# Compute the success rate
if actual_batch_size > 0:
success_rate = merged_count / actual_batch_size
else:
success_rate = 1.0
report = {
'merged_count': merged_count,
'nodes_before': nodes_before,
'nodes_after': nodes_after,
'reduction_rate': reduction_rate,
'duration_seconds': duration,
'start_time': start_time_iso,
'end_time': end_time.isoformat(),
'failed_count': failed_count,
'success_rate': success_rate
}
logger.info(
f"Forgetting cycle complete: "
f"merged {merged_count} node pairs, "
f"{failed_count} failed, "
f"nodes reduced by {nodes_before - nodes_after} "
f"({reduction_rate:.2%}), "
f"took {duration:.2f}s"
)
return report
except Exception as e:
logger.error(f"Forgetting cycle failed: {str(e)}")
raise
finally:
self.is_running = False
# ==================== Private helper methods ====================
async def _count_knowledge_nodes(
self,
group_id: Optional[str] = None
) -> int:
"""
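Count the total number of knowledge-layer nodes.
Counts Statement, ExtractedEntity, and MemorySummary nodes.
Args:
group_id: Group ID (optional), used to restrict the count to one group's nodes
Returns:
int: Total number of knowledge-layer nodes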
"""
query = """
MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
"""
if group_id:
query += " AND n.group_id = $group_id"
query += """
RETURN count(n) as total
"""
params = {}
if group_id:
params['group_id'] = group_id
results = await self.connector.execute_query(query, **params)
if results:
return results[0]['total']
return 0
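
Steps 3 to 5 of `run_forgetting_cycle` (sort ascending by `avg_activation`, cap at the batch limit, then drop pairs that reuse an already selected statement or entity) can be isolated as a pure function for testing. The sketch below follows that order; `select_pairs` is a hypothetical helper name and the sample data is invented for illustration.

```python
from typing import Any, Dict, List, Tuple


def select_pairs(pairs: List[Dict[str, Any]],
                 max_batch: int) -> Tuple[List[Dict[str, Any]], int]:
    """Sort by avg_activation ascending, cap at max_batch, then drop pairs
    that reuse a statement or entity already selected. Returns (kept, skipped)."""
    ordered = sorted(pairs, key=lambda p: p["avg_activation"])[:max_batch]
    seen_statements, seen_entities = set(), set()
    kept, skipped = [], 0
    for pair in ordered:
        if pair["statement_id"] in seen_statements or pair["entity_id"] in seen_entities:
            skipped += 1
            continue
        seen_statements.add(pair["statement_id"])
        seen_entities.add(pair["entity_id"])
        kept.append(pair)
    return kept, skipped


candidates = [
    {"statement_id": "s1", "entity_id": "e1", "avg_activation": 0.20},
    {"statement_id": "s1", "entity_id": "e2", "avg_activation": 0.10},
    {"statement_id": "s2", "entity_id": "e2", "avg_activation": 0.15},
    {"statement_id": "s3", "entity_id": "e3", "avg_activation": 0.25},
]
kept, skipped = select_pairs(candidates, max_batch=3)
```

Because the cap is applied before deduplication, the batch can end up smaller than `max_batch`: here the pair (s3, e3) falls outside the first three and is not considered, even though two of those three are then skipped as duplicates.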

View File

@@ -1,643 +0,0 @@
"""
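Forgetting strategy executor module.
This module implements a forgetting strategy based on ACT-R activation values. It is responsible for:
1. Identifying low-activation node pairs (Statement-Entity)
2. Merging low-activation nodes into MemorySummary nodes
3. Generating high-quality summaries with an LLM (optional)
4. Preserving provenance information and deleting the original nodes
Classes:
ForgettingStrategy: Forgetting strategy executor that identifies and merges nodes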
"""
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
logger = logging.getLogger(__name__)
class ForgettingStrategy:
"""
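Forgetting strategy executor.
Identifies and merges low-value memory nodes based on ACT-R activation values.
Implements the full forgetting cycle: identify → merge → delete.
Core features:
1. Identify forgettable nodes: Statement-Entity pairs whose activation is below the threshold and that have not been accessed for a long time
2. Node merging: create a MemorySummary node that inherits the higher activation value and importance
3. LLM summary generation: generate a semantic summary with an LLM (can fall back to simple concatenation)
4. Provenance: record the original node IDs to keep the summary traceable
Attributes:
connector: Neo4j connector instance
actr_calculator: ACT-R activation calculator instance
forgetting_threshold: Forgetting threshold (nodes with activation below it may be forgotten)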
"""
def __init__(
self,
connector: Neo4jConnector,
actr_calculator: ACTRCalculator,
forgetting_threshold: float = 0.3,
enable_llm_summary: bool = True
):
"""
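Initialize the forgetting strategy executor.
Args:
connector: Neo4j connector instance
actr_calculator: ACT-R activation calculator instance
forgetting_threshold: Forgetting threshold (default 0.3)
enable_llm_summary: Whether to enable LLM summary generation (default True)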
"""
self.connector = connector
self.actr_calculator = actr_calculator
self.forgetting_threshold = forgetting_threshold
self.enable_llm_summary = enable_llm_summary
logger.info(
f"Initialized forgetting strategy executor: threshold={forgetting_threshold}, "
f"enable_llm_summary={enable_llm_summary}"
)
async def calculate_forgetting_score(
self,
activation_value: float
) -> float:
"""
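Compute the forgetting score.
Forgetting score = 1 - activation value.
The higher the score, the more likely the node is to be forgotten.
Note: the activation value already incorporates the importance_score weight,
so importance does not need to be considered separately.
Args:
activation_value: Activation value of the node (0-1)
Returns:
float: Forgetting score (0-1); higher means more likely to be forgotten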
"""
return 1.0 - activation_value
async def find_forgettable_nodes(
self,
group_id: Optional[str] = None,
min_days_since_access: int = 30
) -> List[Dict[str, Any]]:
"""
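Identify forgettable node pairs.
Finds Statement-Entity node pairs that satisfy all of the following:
1. Both nodes have an activation value below the forgetting threshold
2. Neither node has been accessed for at least min_days_since_access days
3. A relationship exists between the Statement and the Entity
Args:
group_id: Group ID (optional), used to restrict the search to one group's nodes
min_days_since_access: Minimum number of days since last access (default 30)
Returns:
List[Dict[str, Any]]: List of forgettable node pairs; each element contains:
- statement_id: Statement node ID
- statement_text: Statement text content
- statement_activation: Statement activation value
- statement_importance: Statement importance score
- statement_last_access: Statement last access time
- entity_id: Entity node ID
- entity_name: Entity name
- entity_type: Entity type
- entity_activation: Entity activation value
- entity_importance: Entity importance score
- entity_last_access: Entity last access time
- avg_activation: Average activation value (used for sorting)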
"""
# Compute the time cutoff
cutoff_time = datetime.now() - timedelta(days=min_days_since_access)
cutoff_time_iso = cutoff_time.isoformat()
# Build the query
query = """
MATCH (s:Statement)-[r]-(e:ExtractedEntity)
WHERE s.activation_value IS NOT NULL
AND e.activation_value IS NOT NULL
AND s.activation_value < $threshold
AND e.activation_value < $threshold
AND s.last_access_time < $cutoff_time
AND e.last_access_time < $cutoff_time
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
"""
if group_id:
query += " AND s.group_id = $group_id AND e.group_id = $group_id"
query += """
RETURN s.id as statement_id,
s.statement as statement_text,
s.activation_value as statement_activation,
s.importance_score as statement_importance,
s.last_access_time as statement_last_access,
e.id as entity_id,
e.name as entity_name,
e.entity_type as entity_type,
e.activation_value as entity_activation,
e.importance_score as entity_importance,
e.last_access_time as entity_last_access,
(s.activation_value + e.activation_value) / 2.0 as avg_activation
ORDER BY avg_activation ASC
"""
params = {
'threshold': self.forgetting_threshold,
'cutoff_time': cutoff_time_iso
}
if group_id:
params['group_id'] = group_id
results = await self.connector.execute_query(query, **params)
logger.info(
f"Identified {len(results)} forgettable node pairs "
f"(threshold={self.forgetting_threshold}, "
f"min_days={min_days_since_access})"
)
return results
async def merge_nodes_to_summary(
self,
statement_node: Dict[str, Any],
entity_node: Dict[str, Any],
config_id: Optional[int] = None,
db = None
) -> str:
"""
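Merge a Statement node and an Entity node into a MemorySummary node.
Merge process:
1. Generate the summary content (with an LLM or by simple concatenation)
2. Create a MemorySummary node that inherits the higher activation value and importance score
3. Delete the original Statement and Entity nodes
4. Preserve provenance information (original_statement_id, original_entity_id)
Args:
statement_node: Statement node data; must contain:
- statement_id: Node ID
- statement_text: Text content
- statement_activation: Activation value
- statement_importance: Importance score
entity_node: Entity node data; must contain:
- entity_id: Node ID
- entity_name: Entity name
- entity_type: Entity type
- entity_activation: Activation value
- entity_importance: Importance score
config_id: Configuration ID (optional), used to look up llm_id
db: Database session (optional), used to look up llm_id
Returns:
str: ID of the created MemorySummary node
Raises:
ValueError: If the node data is incomplete
RuntimeError: If the merge operation fails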
"""
# Validate the input data
required_statement_keys = [
'statement_id', 'statement_text',
'statement_activation', 'statement_importance'
]
required_entity_keys = [
'entity_id', 'entity_name', 'entity_type',
'entity_activation', 'entity_importance'
]
for key in required_statement_keys:
if key not in statement_node:
raise ValueError(f"Statement node is missing required field: {key}")
for key in required_entity_keys:
if key not in entity_node:
raise ValueError(f"Entity node is missing required field: {key}")
# Validate the entity type: Person entities must not be merged
if entity_node.get('entity_type') == 'Person':
raise ValueError(
f"Merging Person entities is not allowed: entity_id={entity_node.get('entity_id')}, "
f"entity_name={entity_node.get('entity_name')}"
)
# Extract node fields
statement_id = statement_node['statement_id']
statement_text = statement_node['statement_text']
statement_activation = statement_node['statement_activation']
statement_importance = statement_node['statement_importance']
entity_id = entity_node['entity_id']
entity_name = entity_node['entity_name']
entity_type = entity_node['entity_type']
entity_activation = entity_node['entity_activation']
entity_importance = entity_node['entity_importance']
# Get group_id (from the statement or entity node)
group_id = statement_node.get('group_id') or entity_node.get('group_id')
# Generate the summary content
summary_text = await self._generate_summary(
statement_text=statement_text,
entity_name=entity_name,
entity_type=entity_type,
config_id=config_id,
db=db
)
# Generate the title and type (using the LLM)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import generate_title_and_type_for_summary
# Get the LLM client
llm_client = None
if config_id is not None and db is not None:
try:
llm_client = await self._get_llm_client(db, config_id)
except Exception as e:
logger.warning(f"获取 LLM 客户端失败: {str(e)}")
# Generate the title and type
try:
if llm_client is not None:
title, episodic_type = await generate_title_and_type_for_summary(
content=summary_text,
llm_client=llm_client
)
logger.info(f"成功为MemorySummary生成标题和类型: title={title}, type={episodic_type}")
else:
logger.warning("LLM 客户端不可用,使用默认标题和类型")
title = "未命名"
episodic_type = "conversation"
except Exception as e:
logger.error(f"生成标题和类型失败,使用默认值: {str(e)}")
title = "未命名"
episodic_type = "conversation"
# Compute the inherited activation and importance (take the higher value of each)
inherited_activation = max(statement_activation, entity_activation)
inherited_importance = max(statement_importance, entity_importance)
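# Create the MemorySummary node
current_time = datetime.now()
current_time_iso = current_time.isoformat()
# Generate a new MemorySummary ID
import uuid
summary_id = f"summary_{uuid.uuid4().hex[:16]}"
# Create the MemorySummary and delete the original nodes in a single transaction
async def merge_transaction(tx, **params):
"""Transaction function: create the summary node and delete the original nodes."""
query = """
// First check that both nodes exist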
OPTIONAL MATCH (s:Statement {id: $statement_id})
OPTIONAL MATCH (e:ExtractedEntity {id: $entity_id})
// If either node is missing, no row survives the filter and the steps below are skipped
WITH s, e
WHERE s IS NOT NULL AND e IS NOT NULL
// Create the MemorySummary node
CREATE (ms:MemorySummary {
id: $summary_id,
summary: $summary_text,
name: $title,
memory_type: $episodic_type,
original_statement_id: $statement_id,
original_entity_id: $entity_id,
activation_value: $inherited_activation,
importance_score: $inherited_importance,
access_history: [$current_time],
last_access_time: $current_time,
access_count: 1,
version: 1,
group_id: $group_id,
created_at: datetime($current_time),
merged_at: datetime($current_time)
})
// Move the Statement's outgoing edges to the MemorySummary (only edges whose target node still exists)
WITH ms, s, e
CALL (ms, s, e) {
OPTIONAL MATCH (s)-[r_out]->(target)
WHERE target <> e AND r_out IS NOT NULL AND target IS NOT NULL
FOREACH (_ IN CASE WHEN target IS NOT NULL THEN [1] ELSE [] END |
MERGE (ms)-[new_rel:DERIVED_FROM]->(target)
ON CREATE SET
new_rel = properties(r_out),
new_rel.original_relationship_type = type(r_out),
new_rel.merged_from_statement = true,
new_rel.merge_count = 1
ON MATCH SET
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
)
}
// Move the Statement's incoming edges to the MemorySummary (only edges whose source node still exists)
WITH ms, s, e
CALL (ms, s, e) {
OPTIONAL MATCH (source)-[r_in]->(s)
WHERE r_in IS NOT NULL AND source IS NOT NULL
FOREACH (_ IN CASE WHEN source IS NOT NULL THEN [1] ELSE [] END |
MERGE (source)-[new_rel:DERIVED_FROM]->(ms)
ON CREATE SET
new_rel = properties(r_in),
new_rel.original_relationship_type = type(r_in),
new_rel.merged_from_statement = true,
new_rel.merge_count = 1
ON MATCH SET
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
)
}
// Move the Entity's outgoing edges to the MemorySummary (only edges whose target node still exists)
WITH ms, s, e
CALL (ms, s, e) {
OPTIONAL MATCH (e)-[r_out]->(target)
WHERE target <> s AND r_out IS NOT NULL AND target IS NOT NULL
FOREACH (_ IN CASE WHEN target IS NOT NULL THEN [1] ELSE [] END |
MERGE (ms)-[new_rel:DERIVED_FROM]->(target)
ON CREATE SET
new_rel = properties(r_out),
new_rel.original_relationship_type = type(r_out),
new_rel.merged_from_entity = true,
new_rel.merge_count = 1
ON MATCH SET
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
)
}
// Move the Entity's incoming edges to the MemorySummary (only edges whose source node still exists)
WITH ms, s, e
CALL (ms, s, e) {
OPTIONAL MATCH (source)-[r_in]->(e)
WHERE source <> s AND r_in IS NOT NULL AND source IS NOT NULL
FOREACH (_ IN CASE WHEN source IS NOT NULL THEN [1] ELSE [] END |
MERGE (source)-[new_rel:DERIVED_FROM]->(ms)
ON CREATE SET
new_rel = properties(r_in),
new_rel.original_relationship_type = type(r_in),
new_rel.merged_from_entity = true,
new_rel.merge_count = 1
ON MATCH SET
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
)
}
// Delete the original nodes
WITH ms, s, e
DETACH DELETE s, e
RETURN ms.id as summary_id
"""
result = await tx.run(query, **params)
record = await result.single()
if not record:
raise RuntimeError("Failed to create MemorySummary node - nodes may not exist")
return record['summary_id']
params = {
'summary_id': summary_id,
'summary_text': summary_text,
'title': title,
'episodic_type': episodic_type,
'statement_id': statement_id,
'entity_id': entity_id,
'inherited_activation': inherited_activation,
'inherited_importance': inherited_importance,
'current_time': current_time_iso,
'group_id': group_id
}
try:
created_summary_id = await self.connector.execute_write_transaction(
merge_transaction,
**params
)
logger.info(
f"成功融合节点: Statement[{statement_id}] + Entity[{entity_id}] "
f"-> MemorySummary[{created_summary_id}], "
f"activation={inherited_activation:.4f}, "
f"importance={inherited_importance:.4f}"
)
return created_summary_id
except Exception as e:
# Log detailed error information, including the exception type and stack trace
import traceback
error_details = traceback.format_exc()
logger.error(
f"融合节点失败: Statement[{statement_id}] + Entity[{entity_id}], "
f"错误类型: {type(e).__name__}, "
f"错误信息: {str(e)}, "
f"详细堆栈:\n{error_details}"
)
raise RuntimeError(
f"融合节点失败: {str(e)}"
) from e
# ==================== Private helper methods ====================
async def _generate_summary(
self,
statement_text: str,
entity_name: str,
entity_type: str,
config_id: Optional[int] = None,
db = None
) -> str:
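"""
Generate the summary text.
Prefers a high-quality LLM-generated summary; if the LLM is unavailable or fails,
falls back to simple text concatenation.
Args:
statement_text: text content of the Statement
entity_name: name of the Entity
entity_type: type of the Entity
config_id: configuration ID (optional, used to look up llm_id)
db: database session (optional, used to look up llm_id)
Returns:
str: the generated summary text (at most 200 characters)
"""
# If LLM summaries are disabled in the configuration, use simple concatenation directly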
if not self.enable_llm_summary:
logger.info("LLM 摘要生成已禁用,使用简单拼接")
return self._simple_concatenation(
statement_text, entity_name, entity_type
)
# Try to get an LLM client
llm_client = None
if config_id is not None and db is not None:
try:
llm_client = await self._get_llm_client(db, config_id)
except Exception as e:
logger.warning(f"获取 LLM 客户端失败: {str(e)}")
# Without an LLM client, use simple concatenation directly
if llm_client is None:
logger.info("未能获取 LLM 客户端,使用简单拼接")
return self._simple_concatenation(
statement_text, entity_name, entity_type
)
# Try to generate the summary with the LLM
try:
summary = await self._generate_llm_summary(
statement_text=statement_text,
entity_name=entity_name,
entity_type=entity_type,
llm_client=llm_client
)
# Cap the length at 200 characters
if len(summary) > 200:
summary = f"{summary[:197]}..."
logger.info(f"使用 LLM 生成摘要: {summary}")
return summary
except Exception as e:
logger.warning(
f"LLM 摘要生成失败,降级到简单拼接: {str(e)}"
)
return self._simple_concatenation(
statement_text, entity_name, entity_type
)
async def _get_llm_client(self, db, config_id: int):
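"""
Get an LLM client based on the configuration stored in the database.
Args:
db: database session
config_id: configuration ID
Returns:
An LLM client instance, or None if one cannot be obtained
"""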
try:
from app.repositories.data_config_repository import DataConfigRepository
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# Read the configuration from the database
repository = DataConfigRepository()
db_config = repository.get_by_id(db, config_id)
if db_config is None or db_config.llm_id is None:
logger.warning(f"配置 {config_id} 不存在或未设置 llm_id")
return None
# Create the LLM client
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(str(db_config.llm_id))
logger.info(f"成功获取 LLM 客户端: config_id={config_id}, llm_id={db_config.llm_id}")
return llm_client
except Exception as e:
logger.error(f"获取 LLM 客户端失败: {str(e)}")
return None
async def _generate_llm_summary(
self,
statement_text: str,
entity_name: str,
entity_type: str,
llm_client
) -> str:
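"""
Generate a high-quality summary with the LLM.
Args:
statement_text: text content of the Statement
entity_name: name of the Entity
entity_type: type of the Entity
llm_client: LLM client instance
Returns:
str: the summary text generated by the LLM
Raises:
Exception: if the LLM call fails
"""
# Build the prompt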
prompt = f"""请为以下记忆片段生成一个简洁的摘要不超过200个字符
实体名称: {entity_name}
实体类型: {entity_type}
陈述内容: {statement_text}
要求:
1. 摘要应该保留核心语义信息
2. 长度不超过200个字符
3. 使用简洁、自然的中文表达
4. 只返回摘要文本,不要包含其他内容
摘要:"""
# Call the LLM (passing the prompt string directly)
response = await llm_client.chat(prompt)
# Extract the summary text
if isinstance(response, str):
summary = response.strip()
elif hasattr(response, 'content'):
summary = response.content.strip()
else:
summary = str(response).strip()
return summary
def _simple_concatenation(
self,
statement_text: str,
entity_name: str,
entity_type: str
) -> str:
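"""
Generate the summary by simple text concatenation.
Fallback strategy: used when the LLM is unavailable.
Format: [entity type]entity name: statement text
Args:
statement_text: text content of the Statement
entity_name: name of the Entity
entity_type: type of the Entity
Returns:
str: the concatenated summary text (at most 200 characters)
"""
# Build the simple summary
summary = f"[{entity_type}]{entity_name}: {statement_text}"
# Cap the length at 200 characters (note: this counts characters, not bytes)
if len(summary) > 200:
# Truncate and append an ellipsis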
summary = f"{summary[:197]}..."
return summary
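
The fallback path and the score inheritance above can be exercised in isolation. The following is a minimal sketch, assuming standalone functions rather than methods; the names `truncate_summary`, `simple_concatenation` and `inherit_scores` are illustrative and do not appear in the commit. It mirrors the 200-character cap and the take-the-maximum rule described in the docstrings.

```python
MAX_SUMMARY_CHARS = 200  # cap used by the methods above (characters, not bytes)


def truncate_summary(text: str, limit: int = MAX_SUMMARY_CHARS) -> str:
    """Cut text to `limit` characters, ending with '...' when it was truncated."""
    if len(text) > limit:
        return f"{text[:limit - 3]}..."
    return text


def simple_concatenation(statement_text: str, entity_name: str, entity_type: str) -> str:
    """Illustrative standalone fallback: '[type]name: statement', capped in length."""
    return truncate_summary(f"[{entity_type}]{entity_name}: {statement_text}")


def inherit_scores(statement_node: dict, entity_node: dict) -> tuple:
    """Return (activation, importance), taking the higher value of each pair."""
    activation = max(statement_node["statement_activation"],
                     entity_node["entity_activation"])
    importance = max(statement_node["statement_importance"],
                     entity_node["entity_importance"])
    return activation, importance
```

Keeping the cap in one helper would avoid repeating the `[:197]` slice in both the LLM path and the fallback path; whether the project wants that refactor is a separate question.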

View File

@@ -2,39 +2,52 @@
"memory_verify": {
"source_data": [
{
"statement_name": "我是 2023 年春天去北京工作的后来基本一直都在北京上班也没怎么换过城市。不过后来公司调整2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X银行卡是 6222023847595898这些一直没变。对了其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合。"
"statement_name": "用户是2023年春天去北京工作的。",
"statement_id": "62beac695b1346f4871740a45db88782"
},
{
"statement_name": "用户后来基本一直都在北京上班。"
"statement_name": "用户后来基本一直都在北京上班。",
"statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b"
},
{
"statement_name": "用户从2023年开始就一直在北京生活。"
"statement_name": "用户从2023年开始就一直在北京生活。",
"statement_id": "e612a44da4db483993c350df7c97a1a1"
},
{
"statement_name": "用户从来没有长期离开过北京。"
"statement_name": "用户从来没有长期离开过北京。",
"statement_id": "b3c787a2e33c49f7981accabbbb4538a"
},
{
"statement_name": "由于公司调整用户在2024年上半年被调到上海待了差不多半年。"
"statement_name": "由于公司调整用户在2024年上半年被调到上海待了差不多半年。",
"statement_id": "64cde4230cb24a4da726e7db9e7aa616"
},
{
"statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。"
"statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。",
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669"
},
{
"statement_name": "用户在入职时使用的身份信息是之前的身份证号为11010119950308123X。"
"statement_name": "用户在入职时使用的身份信息是之前的身份证号为11010119950308123X。",
"statement_id": "030afd362e9b4110b139e68e5d3e7143"
},
{
"statement_name": "用户的银行卡号是6222023847595898。"
"statement_name": "用户的银行卡号是6222023847595898。",
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f"
},
{
"statement_name": "用户的身份信息和银行卡信息一直没变。"
"statement_name": "用户的身份信息和银行卡信息一直没变。",
"statement_id": "b3ca618e1e204b83bebd70e75cf2073f"
},
{
"statement_name": "用户认为在上海的那段时间更多算是远程配合。"
"statement_name": "用户认为在上海的那段时间更多算是远程配合。",
"statement_id": "150af89d2c154e6eb41ff1a91e37f962"
}
],
"databasets": [
{
"entity1_name": "Person",
"description": "表示人类个体的通用类型",
"statement_id": "62beac695b1346f4871740a45db88782",
"entity2_name": "用户",
"entity2": {
"description": "叙述者,讲述个人工作与生活经历的个体",
"name": "用户"
@@ -42,6 +55,9 @@
},
{
"entity1_name": "用户",
"description": "叙述者,讲述个人工作与生活经历的个体",
"statement_id": "62beac695b1346f4871740a45db88782",
"entity2_name": "身份信息",
"entity2": {
"description": "用于个人身份识别的数据",
"name": "身份信息"
@@ -49,6 +65,9 @@
},
{
"entity1_name": "用户",
"description": "叙述者,讲述个人工作与生活经历的个体",
"statement_id": "62beac695b1346f4871740a45db88782",
"entity2_name": "6222023847595898",
"entity2": {
"description": "用户的银行卡号码",
"name": "6222023847595898"
@@ -57,24 +76,33 @@
{
"entity1_name": "用户",
"description": "叙述者,讲述个人工作与生活经历的个体",
"statement_id": "62beac695b1346f4871740a45db88782",
"entity2_name": "上海办公室",
"entity2": {
"entity_idx": 1,
"aliases": ["上海办"],
"description": "位于上海的工作办公场所",
"name": "上海办公室"
}
},
{
"entity1_name": "用户",
"description": "叙述者,讲述个人工作与生活经历的个体",
"statement_id": "62beac695b1346f4871740a45db88782",
"entity2_name": "北京",
"entity2": {
"aliases": ["京", "京城", "北平"],
"description": "中国的首都城市,用户主要工作和生活所在地",
"name": "北京"
}
},
{
"entity1_name": "11010119950308123X",
"description": "具体的身份证号码值",
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
"entity2_name": "身份证号",
"entity2": {
"description": "中华人民共和国公民的身份号码",
"name": "身份证号"
}
}

View File

@@ -387,7 +387,7 @@ class ReflectionEngine:
result_data['memory_verifies'] = memory_verifies
result_data['quality_assessments'] = quality_assessments
conflicts_found=''
REMOVE_KEYS = {"created_at", "expired_at", "relationship", "predicate", "statement_id", "id", "relationship_statement_id"}
# Clean conflict_data, memory_verify and quality_assessment
cleaned_conflict_data = []
for item in conflict_data:
@@ -396,23 +396,7 @@ class ReflectionEngine:
'conflict': item['conflict']
}
cleaned_conflict_data.append(cleaned_item)
cleaned_conflict_data_=[]
for item in conflict_data:
cleaned_data = []
for row in item.get("data", []):
# Drop created_at / expired_at and the other keys in REMOVE_KEYS
cleaned_row = {
k: v
for k, v in row.items()
if k not in REMOVE_KEYS
}
cleaned_data.append(cleaned_row)
cleaned_item = {
"data": cleaned_data,
"conflict": item.get("conflict"),
}
cleaned_conflict_data_.append(cleaned_item)
print(cleaned_conflict_data_)
# 3. Resolve conflicts
solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data)
if not solved_data:
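
The loop shown in this hunk strips bookkeeping keys from every row of each conflict item. A minimal sketch of the same filtering as a standalone function follows; `strip_keys` is an illustrative name, not part of `ReflectionEngine`, and the key set is copied from `REMOVE_KEYS` above.

```python
REMOVE_KEYS = {
    "created_at", "expired_at", "relationship", "predicate",
    "statement_id", "id", "relationship_statement_id",
}


def strip_keys(conflict_data, remove_keys=REMOVE_KEYS):
    """Return a copy of conflict_data with the unwanted keys removed from each row."""
    cleaned = []
    for item in conflict_data:
        # Rebuild every row without the keys listed in remove_keys.
        rows = [
            {k: v for k, v in row.items() if k not in remove_keys}
            for row in item.get("data", [])
        ]
        cleaned.append({"data": rows, "conflict": item.get("conflict")})
    return cleaned
```

The function builds new dictionaries, so the input list is left untouched.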

View File

@@ -316,96 +316,3 @@ async def render_emotion_suggestions_prompt(
})
return rendered_prompt
async def render_user_summary_prompt(
user_id: str,
entities: str,
statements: str
) -> str:
"""
Renders the user summary prompt using the user_summary.jinja2 template.
Args:
user_id: User identifier
entities: Core entities with frequency information
statements: Representative statement samples
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("user_summary.jinja2")
rendered_prompt = template.render(
user_id=user_id,
entities=entities,
statements=statements
)
# Log the rendered result to the prompt log
log_prompt_rendering('user summary', rendered_prompt)
# Optional: log template rendering details
log_template_rendering('user_summary.jinja2', {
'user_id': user_id,
'entities_len': len(entities),
'statements_len': len(statements)
})
return rendered_prompt
async def render_memory_insight_prompt(
domain_distribution: str = None,
active_periods: str = None,
social_connections: str = None
) -> str:
"""
Renders the memory insight prompt using the memory_insight.jinja2 template.
Args:
domain_distribution: Core domain distribution information
active_periods: Active period information
social_connections: Social connection information
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("memory_insight.jinja2")
rendered_prompt = template.render(
domain_distribution=domain_distribution,
active_periods=active_periods,
social_connections=social_connections
)
# Log the rendered result to the prompt log
log_prompt_rendering('memory insight', rendered_prompt)
# Optional: log template rendering details
log_template_rendering('memory_insight.jinja2', {
'has_domain_distribution': bool(domain_distribution),
'has_active_periods': bool(active_periods),
'has_social_connections': bool(social_connections)
})
return rendered_prompt
async def render_episodic_title_and_type_prompt(content: str) -> str:
"""
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
Args:
content: The content of the episodic memory summary to analyze
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("episodic_type_classification.jinja2")
rendered_prompt = template.render(content=content)
# Log the rendered result to the prompt log
log_prompt_rendering('episodic title and type classification', rendered_prompt)
# Optional: log template rendering details
log_template_rendering('episodic_type_classification.jinja2', {
'content_len': len(content) if content else 0
})
return rendered_prompt

View File

@@ -1,57 +0,0 @@
=== Task ===
Generate a concise title and classify the episodic memory into the most appropriate category.
=== Requirements ===
- Extract a clear, concise title (10-20 characters) that captures the core content
- Classify into exactly one category based on the primary theme
- Be specific and avoid ambiguity
- Output must be valid JSON conforming to the schema below
=== Input ===
{{ content }}
=== Category Definitions ===
1. **conversation**: Daily communication, chat, discussion, and social interactions
- Keywords: chat, communication, discussion, dialogue, exchange
2. **project_work**: Work-related tasks, projects, meetings, and collaboration
- Keywords: project, task, work, meeting, collaboration, business, client
3. **learning**: Acquiring new knowledge, skill development, reading, and research
- Keywords: learning, reading, research, knowledge, skill, course, training
4. **decision**: Making important decisions, choices, and planning
- Keywords: decision, choice, planning, consideration, evaluation, weighing
5. **important_event**: Major events, milestones, and special experiences
- Keywords: important, major, milestone, special, memorable, celebration
=== Analysis Steps ===
1. Read the episodic memory content carefully
2. Identify the core theme and context
3. Extract a concise title
4. Compare against category definitions and keywords
5. Select the best matching category
6. If multiple categories apply, choose the primary one
=== Output Schema ===
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure
2. Escape any quotation marks within string values using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
Return only a JSON object with title and type fields:
{
"title": "Generated title here",
"type": "Category type here"
}
The type field must be exactly one of:
- conversation
- project_work
- learning
- decision
- important_event
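
A caller of this prompt has to cope with replies that are not valid JSON or that name an unknown category. The following is a minimal sketch of such a check, assuming the reply arrives as a raw string; `parse_title_and_type` is an illustrative name, and the fallback values reuse the defaults that `merge_nodes_to_summary` applies when no title can be generated.

```python
import json

ALLOWED_TYPES = {"conversation", "project_work", "learning", "decision", "important_event"}
DEFAULT_TITLE = "未命名"       # default title used in merge_nodes_to_summary
DEFAULT_TYPE = "conversation"  # default type used in merge_nodes_to_summary


def parse_title_and_type(raw: str):
    """Parse the model reply and fall back to defaults on any malformed field."""
    try:
        data = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return DEFAULT_TITLE, DEFAULT_TYPE
    if not isinstance(data, dict):
        return DEFAULT_TITLE, DEFAULT_TYPE

    title = data.get("title")
    if isinstance(title, str) and title.strip():
        title = title.strip()
    else:
        title = DEFAULT_TITLE

    episodic_type = data.get("type")
    # Only the five categories listed in the template are accepted.
    if not isinstance(episodic_type, str) or episodic_type not in ALLOWED_TYPES:
        episodic_type = DEFAULT_TYPE
    return title, episodic_type
```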

View File

@@ -86,5 +86,5 @@
- **quality_assessment**:
quality_assessment=true时输出评估对象否则为null注意- summary输出的结果不允许含有expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
- **memory_verify**: memory_verify=true时输出隐私检测对象否则为null
(注意:- summary输出的结果不允许含有expired_at设为2024-01-01T00:00:00Z、memory_verify=true\memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容)
(注意:- summary输出的结果不允许含有expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
模式参考:{{ json_schema }}

View File

@@ -12,34 +12,7 @@ Extract entities and knowledge triplets from the given statement.
===Guidelines===
**Entity Extraction:**
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
- **Semantic Memory Classification (is_explicit_memory):**
* Set to `true` if the entity represents **explicit/semantic memory**:
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
- **Knowledge:** "Python Programming Language", "Theory of Relativity", "Python编程语言", "相对论"
- **Definitions:** "API (Application Programming Interface)", "REST API", "应用程序接口"
- **Principles:** "SOLID Principles", "First Law of Thermodynamics", "SOLID原则", "热力学第一定律"
- **Theories:** "Evolution Theory", "Quantum Mechanics", "进化论", "量子力学"
- **Methods/Techniques:** "Agile Development", "Machine Learning Algorithm", "敏捷开发", "机器学习算法"
- **Technical Terms:** "Neural Network", "Database", "神经网络", "数据库"
* Set to `false` for:
- **People:** "John Smith", "Dr. Wang", "张明", "王博士"
- **Organizations:** "Microsoft", "Harvard University", "微软", "哈佛大学"
- **Locations:** "Beijing", "Central Park", "北京", "中央公园"
- **Events:** "2024 Conference", "Project Meeting", "2024会议", "项目会议"
- **Specific objects:** "iPhone 15", "Building A", "iPhone 15", "A栋"
- **Example Generation (IMPORTANT for semantic memory entities):**
* For entities where `is_explicit_memory=true`, generate a **concise example (around 20 characters)** to help understand the concept
* The example should be:
- **Specific and concrete**: Use real-world scenarios or applications
- **Brief**: Around 20 characters (can be slightly longer if needed for clarity)
- **In the same language as the entity name**
* Examples:
- Entity: "机器学习" → example: "如:用神经网络识别图片中的猫狗"
- Entity: "SOLID Principles" → example: "e.g., Single Responsibility, Open-Closed"
- Entity: "Photosynthesis" → example: "e.g., plants convert sunlight to energy"
- Entity: "人工智能" → example: "如:智能客服、自动驾驶"
* For non-semantic entities (`is_explicit_memory=false`), the example field can be empty
- Extract entities with their types, context-independent descriptions, and aliases
- **Aliases Extraction (Important):**
* **CRITICAL: Extract aliases ONLY in the SAME LANGUAGE as the input text**
* **DO NOT translate or add aliases in different languages**
@@ -111,27 +84,21 @@ Output:
"name": "I",
"type": "Person",
"description": "The user",
"example": "",
"aliases": [],
"is_explicit_memory": false
"aliases": []
},
{
"entity_idx": 1,
"name": "Paris",
"type": "Location",
"description": "Capital city of France",
"example": "",
"aliases": [],
"is_explicit_memory": false
"aliases": []
},
{
"entity_idx": 2,
"name": "Louvre",
"type": "Location",
"description": "World-famous museum located in Paris",
"example": "",
"aliases": ["Louvre Museum"],
"is_explicit_memory": false
"aliases": ["Louvre Museum"]
}
]
}
@@ -163,27 +130,21 @@ Output:
"name": "John Smith",
"type": "Person",
"description": "Individual person name",
"example": "",
"aliases": [],
"is_explicit_memory": false
"aliases": []
},
{
"entity_idx": 1,
"name": "Google",
"type": "Organization",
"description": "American technology company",
"example": "",
"aliases": ["Google LLC", "Alphabet Inc."],
"is_explicit_memory": false
"aliases": ["Google LLC", "Alphabet Inc."]
},
{
"entity_idx": 2,
"name": "AI product development",
"type": "Concept",
"type": "WorkRole",
"description": "Artificial intelligence product development work",
"example": "e.g., developing chatbots, recommendation systems",
"aliases": [],
"is_explicit_memory": true
"aliases": []
}
]
}
@@ -215,27 +176,21 @@ Output:
"name": "我",
"type": "Person",
"description": "用户本人",
"example": "",
"aliases": [],
"is_explicit_memory": false
"aliases": []
},
{
"entity_idx": 1,
"name": "巴黎",
"type": "Location",
"description": "法国首都城市",
"example": "",
"aliases": [],
"is_explicit_memory": false
"aliases": []
},
{
"entity_idx": 2,
"name": "卢浮宫",
"type": "Location",
"description": "位于巴黎的世界著名博物馆",
"example": "",
"aliases": [],
"is_explicit_memory": false
"aliases": []
}
]
}
@@ -267,27 +222,21 @@ Output:
"name": "张明",
"type": "Person",
"description": "个人姓名",
"example": "",
"aliases": [],
"is_explicit_memory": false
"aliases": []
},
{
"entity_idx": 1,
"name": "腾讯",
"type": "Organization",
"description": "中国科技公司",
"example": "",
"aliases": ["腾讯控股", "腾讯公司"],
"is_explicit_memory": false
"aliases": ["腾讯控股", "腾讯公司"]
},
{
"entity_idx": 2,
"name": "AI产品开发",
"type": "Concept",
"type": "WorkRole",
"description": "人工智能产品研发工作",
"example": "如:开发智能客服机器人、推荐系统",
"aliases": [],
"is_explicit_memory": true
"aliases": []
}
]
}
@@ -302,9 +251,7 @@ Output:
"name": "Tripod",
"type": "Equipment",
"description": "Photography equipment accessory",
"example": "",
"aliases": ["Camera Tripod"],
"is_explicit_memory": false
"aliases": ["Camera Tripod"]
}
]
}
@@ -319,9 +266,7 @@ Output:
"name": "三脚架",
"type": "Equipment",
"description": "摄影器材配件",
"example": "",
"aliases": ["相机三脚架"],
"is_explicit_memory": false
"aliases": ["相机三脚架"]
}
]
}

View File

@@ -1,152 +0,0 @@
{% macro tidy(name) -%}
{{ name.replace('_', ' ')}}
{%- endmacro %}
===Task===
Your task is to generate a comprehensive memory insight report based on the provided data analysis. The report should include four distinct sections that capture different aspects of the user's memory patterns and characteristics.
===Inputs===
{% if domain_distribution %}
- 核心领域分布: {{ domain_distribution }}
{% endif %}
{% if active_periods %}
- 活跃时段: {{ active_periods }}
{% endif %}
{% if social_connections %}
- 社交关联: {{ social_connections }}
{% endif %}
===Report Generation Requirements===
**General Guidelines:**
1. Base your analysis ONLY on the provided data - do not speculate or fabricate information
2. Use objective third-person descriptions with a professional and analytical tone
3. Avoid excessive adjectives and empty phrases
4. Strictly follow the output format specified below
5. If a dimension lacks data, skip that section or provide a brief note
**Section-Specific Requirements:**
1. **总体概述 (Overview)** (100-150 Chinese characters)
- Focus on: Overall analysis of user profile based on interaction logs
- Describe the user's main role, work network, and collaboration spirit
- Use professional, data-driven language style
- Example reference: "通过对156次交互日志的深度分析,系统发现张三是一位主要从事用户档案和数据分析的产品经理。他的工作网络体现出鲜明的目标导向和团队协作精神。"
2. **行为模式 (Behavior Pattern)** (80-120 Chinese characters)
- Focus on: Work patterns, time regularity, and behavioral characteristics
- Describe weekly work patterns and time preferences
- Use objective, analytical language
- Example reference: "张三的工作模式呈现出鲜明的周期性:周一通常用于规划和会议,周三周四专注于产品设计和用户研究,周五进行总结和复盘。他倾向于在上午进行头脑风暴,下午处理执行性工作。"
3. **关键发现 (Key Findings)** (3-4 bullet points, 30-50 characters each)
- Focus on: Specific, insightful observations about user behavior and preferences
- Use bullet points (•) format
- Each finding should be concrete and data-supported
- Example reference:
"• 在产品决策中张三总是优先考虑用户反应这在68%的决策记录中得到体现
• 他善于使用数据可视化工具来支持论点,这种习惯在项目管理中发挥了重要作用
• 团队成员对他的评价中,"思路清晰"和"思路敏捷"两个关键词出现频率最高
• 他对AI机器学习领域保持持续关注近3个月参加了7次相关培训"
4. **成长轨迹 (Growth Trajectory)** (100-150 Chinese characters)
- Focus on: User's growth journey, key milestones, and capability improvements
- Organize content chronologically
- Highlight role changes and achievements
- Use positive, encouraging tone
- Example reference: "从入职时的产品经理成长为高级产品经理,张三在产品规划、团队管理和技术理解三个方面都有显著提升。特别是在最近一年,他开始独立主导更复杂的项目,展现出更强的战略思维能力。他的成长轨迹显示出对新技术的持续学习和对产品思维的不断深化。"
===Output Format (MUST STRICTLY FOLLOW)===
【总体概述】
[100-150 characters describing overall user profile and work network based on interaction analysis]
【行为模式】
[80-120 characters describing work patterns, time regularity, and behavioral characteristics]
【关键发现】
• [First key finding with data support, 30-50 characters]
• [Second key finding with data support, 30-50 characters]
• [Third key finding with data support, 30-50 characters]
• [Fourth key finding with data support, 30-50 characters]
【成长轨迹】
[100-150 characters describing growth journey, milestones, and capability improvements]
===Example===
Example Input:
- 核心领域分布: 产品管理(38%), 数据分析(24%), 团队协作(21%)
- 活跃时段: 用户在每年的 4 和 10 月最为活跃
- 社交关联: 与用户"李明"拥有最多共同记忆(47条),时间范围主要在 2020-2023
Example Output:
【总体概述】
通过对156次交互日志的深度分析系统发现张三是一位主要从事用户档案和数据分析的产品经理。他的工作网络体现出鲜明的目标导向和团队协作精神在产品管理、数据分析和团队协作三个领域都有深入的实践。
【行为模式】
张三的工作模式呈现出鲜明的周期性周一通常用于规划和会议周三周四专注于产品设计和用户研究周五进行总结和复盘。他倾向于在上午进行头脑风暴下午处理执行性工作。每年4月和10月是他最活跃的时期。
【关键发现】
• 在产品决策中张三总是优先考虑用户反应这在68%的决策记录中得到体现
• 他善于使用数据可视化工具来支持论点,这种习惯在项目管理中发挥了重要作用
• 团队成员对他的评价中,"思路清晰"和"思路敏捷"两个关键词出现频率最高
• 他对AI机器学习领域保持持续关注近3个月参加了7次相关培训
【成长轨迹】
从入职时的产品经理成长为高级产品经理张三在产品规划、团队管理和技术理解三个方面都有显著提升。特别是在最近一年他开始独立主导更复杂的项目展现出更强的战略思维能力。他与李明的47条共同记忆见证了他的成长历程。
===End of Example===
===Reflection Process===
After generating the report, perform the following self-review steps:
**Step 1: Data Grounding Check**
- Verify all statements are supported by the provided data
- Ensure no fabricated or speculated information is included
- Confirm all claims can be traced back to the input data
**Step 2: Format Compliance**
- Verify each section follows the specified format with section headers
- Check character count limits for each section
- Ensure proper use of section markers (【】)
- Verify bullet points format for Key Findings section
**Step 3: Tone and Style Review**
- Confirm objective third-person perspective is maintained
- Check for excessive adjectives or empty phrases
- Verify professional and analytical tone throughout
**Step 4: Completeness Check**
- Ensure all four sections are present and complete
- Verify each section addresses its specific focus area
- Confirm the report provides actionable insights
===Output Requirements===
**LANGUAGE REQUIREMENT:**
- The output language should ALWAYS be Chinese (Simplified)
- All section content must be in Chinese
- Section headers must use the specified Chinese format: 【总体概述】【行为模式】【关键发现】【成长轨迹】
**FORMAT REQUIREMENT:**
- Each section must start with its header on a new line
- Content follows immediately after the header
- Sections are separated by blank lines
- Key Findings section must use bullet points (•)
- Strictly adhere to character limits for each section
**CONTENT REQUIREMENT:**
- Only use provided data points
- Do not fabricate or speculate information
- If data is insufficient for a section, provide a brief note or skip
- Maintain professional, analytical tone throughout
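
Because the template fixes the four section headers, the generated report can be split back into fields mechanically. The following is a minimal sketch under that assumption; `parse_report` and `key_findings` are illustrative names and are not taken from the commit.

```python
import re

SECTION_HEADERS = ["总体概述", "行为模式", "关键发现", "成长轨迹"]


def parse_report(text: str) -> dict:
    """Split a report on its 【...】 headers and return {header: body}."""
    # re.split with one capture group yields [preamble, header1, body1, header2, body2, ...]
    parts = re.split(r"【(.+?)】", text)
    sections = {}
    for header, body in zip(parts[1::2], parts[2::2]):
        if header in SECTION_HEADERS:
            sections[header] = body.strip()
    return sections


def key_findings(sections: dict) -> list:
    """Return the bullet items of the Key Findings section without the bullet mark."""
    findings = []
    for line in sections.get("关键发现", "").splitlines():
        stripped = line.strip()
        if stripped.startswith("•"):
            findings.append(stripped[1:].strip())
    return findings
```

Headers outside the expected four are ignored, and a missing section simply does not appear in the result, so the caller can decide how to handle an incomplete report.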

View File

@@ -9,7 +9,9 @@
## 任务目标
作为数据冲突解决专家,分析冲突原因,按类型分组处理,为每种冲突生成独立解决方案。
**数据关系**: statement_databasets中的statement_id对应data中的记录statement_created_at为用户输入时间。
**处理模式**:
- memory_verify=false: 仅处理数据冲突
- memory_verify=true: 处理数据冲突 + 隐私脱敏
@@ -109,8 +111,7 @@
- 隐私保护优先: 所有输出记录必须完成隐私脱敏
- 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %}
- 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空
- 输出的结果reflexion字段中的reason字段和solution不允许含有expired_at设为2024-01-01T00:00:00Z、memory_verify=true、memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容
如果是FACT只记录事实冲突相关的数据如果是TIME只记录时间冲突相关的数据如果是HYBRID则记录所有冲突相关的数据
- 输出的结果reflexion字段中的reason字段和solution不允许含有expired_at设为2024-01-01T00:00:00Z、memory_verify=true)等原数据字段以及涉及需要修改的字段以及内容
**变更记录格式**:
```json
@@ -157,9 +158,8 @@
"conflict": true
},
"reflexion": {
"reason": "该冲突类型的原因分析如果是FACT就是存在事实冲突分析该冲突原因如果是TIME就是存在时间冲突分析该冲突原因如果是HYBRID可以输出存在时间与事实的混合冲突再添加上原因分析
不可以随意分配冲突类型以及原因不允许输出字段比如statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict等类似这种",
"solution": "该冲突类型的解决方案不允许输出字段比如statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict等类似这种"
"reason": "该冲突类型的原因分析",
"solution": "该冲突类型的解决方案"
},
"resolved": {
"original_memory_id": "被设为失效的记忆id",
@@ -182,5 +182,4 @@
- **resolved.change**: 包含详细变更信息
- 无需修改的冲突类型resolved为null
- 与baseline不匹配的冲突类型不包含在results中
模式参考: {{ json_schema }}
模式参考: {{ json_schema }}

View File

@@ -1,114 +0,0 @@
{% macro tidy(name) -%}
{{ name.replace('_', ' ')}}
{%- endmacro %}
===Task===
Your task is to generate a comprehensive user profile based on the provided entities and statements. The profile should include four distinct sections that capture different aspects of the user's identity and characteristics.
===Inputs===
{% if user_id %}
- User ID: {{ user_id }}
{% endif %}
{% if entities %}
- Core Entities & Frequency: {{ entities }}
{% endif %}
{% if statements %}
- Representative Statement Samples: {{ statements }}
{% endif %}
===Profile Generation Requirements===
**General Guidelines:**
1. Base your analysis ONLY on the provided data - do not speculate or fabricate information
2. Use objective third-person descriptions with a restrained and neutral tone
3. Avoid excessive adjectives and empty phrases
4. Strictly follow the output format specified below
**Section-Specific Requirements:**
1. **Basic Introduction** (4-5 sentences, max 150 Chinese characters)
- Focus on: identity, occupation, location, and other basic demographic information
- Provide factual background about who the user is
2. **Personality Traits** (2-3 sentences, max 80 Chinese characters)
- Focus on: personality characteristics, behavioral habits, communication style
- Describe observable patterns in how the user interacts and behaves
3. **Core Values** (1-2 sentences, max 50 Chinese characters)
- Focus on: values, beliefs, goals, and aspirations
- Capture what matters most to the user and what drives their decisions
4. **One-Sentence Summary** (1 sentence, max 40 Chinese characters)
- Provide a highly condensed characterization of the user's core traits
- Similar to a personal tagline or motto that captures their essence
===Output Format (MUST STRICTLY FOLLOW)===
【基本介绍】
[4-5 sentences describing the user's basic identity, occupation, and location]
【性格特点】
[2-3 sentences describing the user's personality traits, behavioral habits, and communication style]
【核心价值观】
[1-2 sentences describing the user's values, beliefs, and goals]
【一句话总结】
[1 sentence providing a highly condensed summary of the user's core characteristics]
===Example===
Example Input:
- User ID: user_12345
- Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7)
- Representative Statement Samples: 我在深圳从事产品经理工作已经5年了 | 我相信好的产品源于对用户需求的深刻理解 | 我喜欢在团队中起到协调作用 | 数据驱动决策是我的工作原则
Example Output:
【基本介绍】
我是张三一名充满热情的高级产品经理。在过去的5年里我专注于AI和数据驱动的产品设计致力于创造能够真正改善用户生活的产品。我相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
【性格特点】
性格开朗,善于沟通,注重细节。喜欢在团队中起到协调作用,帮助大家达成共识。面对挑战时保持乐观,相信每个问题都有解决方案。
【核心价值观】
用户至上、数据驱动、持续学习、团队协作
【一句话总结】
"让每一个产品决策都充满温度。"
===End of Example===
===Internal Quality Checks (DO NOT OUTPUT)===
Before generating your final output, internally verify:
1. All content is grounded in provided data (no fabrication)
2. Format follows the specified structure with correct headers
3. Tone is objective, third-person, and neutral
4. All four sections are complete and within character limits
**IMPORTANT: These checks are for your internal use only. DO NOT include them in your output.**
===Output Requirements===
**CRITICAL: Your response must ONLY contain the four sections below. Do not include any reflection, self-review, or meta-commentary.**
**LANGUAGE REQUIREMENT:**
- The output language should ALWAYS be Chinese (Simplified)
- All section content must be in Chinese
- Section headers must use the specified Chinese format: 【基本介绍】【性格特点】【核心价值观】【一句话总结】
**FORMAT REQUIREMENT:**
- Each section must start with its header on a new line
- Content follows immediately after the header
- Sections are separated by blank lines
- Strictly adhere to character limits for each section
- **DO NOT include any text after the 【一句话总结】 section**
- **DO NOT output reflection steps, self-review, or verification notes**

@@ -54,9 +54,9 @@ async def get_reflexion_data(host_id: uuid.UUID) -> List[Any]:
Returns:
List of memory data that falls within the reflexion range.
"""
if REFLEXION_RANGE == "partial":
if REFLEXION_RANGE == "retrieval":
return await get_data(host_id)
elif REFLEXION_RANGE == "all":
elif REFLEXION_RANGE == "database":
return []
else:
raise ValueError(f"未知的反思范围: {REFLEXION_RANGE}")

@@ -64,8 +64,8 @@ def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese"
def by_textln(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None, **kwargs):
textln_api = os.environ.get("TEXTLN_APISERVER", "https://api.textin.com/ai/service/v1/pdf_to_markdown")
app_id = os.environ.get("TEXTLN_APP_ID", "")
secret_code = os.environ.get("TEXTLN_SECRET_CODE", "")
app_id = os.environ.get("TEXTLN_APP_ID", "fa3f24380683ad53e6c620c0f0878a09")
secret_code = os.environ.get("TEXTLN_SECRET_CODE", "6130caac9aabc6eb26433758d7898f4a")
pdf_parser = TextLnParser(textln_api=textln_api, app_id=app_id, secret_code=secret_code)
sections, tables = pdf_parser.parse_pdf(
@@ -672,7 +672,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
excel_parser = ExcelParser()
if parser_config.get("html4excel"):
sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
parser_config["chunk_token_num"] = 0
else:
sections = [(_, "") for _ in excel_parser(binary) if _]
parser_config["chunk_token_num"] = 12800
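The `by_textln` hunk earlier in this file's diff reads `TEXTLN_APP_ID` and `TEXTLN_SECRET_CODE` from the environment, and one side of the diff supplies literal fallback values. Below is a minimal sketch of an alternative, assuming the caller would rather fail fast when either variable is unset. The helper name `load_textln_credentials` and its `env` parameter are hypothetical and do not appear in the repository.

```python
import os
from typing import Mapping, Optional, Tuple


def load_textln_credentials(env: Optional[Mapping[str, str]] = None) -> Tuple[str, str]:
    """Return (app_id, secret_code) or raise if either is missing.

    Hypothetical helper: `env` defaults to os.environ and exists only so the
    function can be exercised without touching the real environment.
    """
    env = os.environ if env is None else env
    app_id = env.get("TEXTLN_APP_ID", "")
    secret_code = env.get("TEXTLN_SECRET_CODE", "")

    # Collect the names of every variable that is unset or empty.
    missing = [
        name
        for name, value in (("TEXTLN_APP_ID", app_id), ("TEXTLN_SECRET_CODE", secret_code))
        if not value
    ]
    if missing:
        raise RuntimeError("Missing required environment variables: " + ", ".join(missing))
    return app_id, secret_code
```

With this shape, no credential value needs to live in source control; a deployment sets both variables and a misconfigured one surfaces at startup rather than at the first parse request.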

@@ -5,7 +5,6 @@ from io import BytesIO
import pandas as pd
from openpyxl import Workbook, load_workbook
from PIL import Image
from app.core.rag.nlp import find_codec
@@ -29,7 +28,7 @@ class RAGExcelParser:
try:
file_like_object.seek(0)
df = pd.read_csv(file_like_object, on_bad_lines='skip')
df = pd.read_csv(file_like_object)
return RAGExcelParser._dataframe_to_workbook(df)
except Exception as e_csv:
@@ -66,6 +65,7 @@ class RAGExcelParser:
# if contains multiple sheets use _dataframes_to_workbook
if isinstance(df, dict) and len(df) > 1:
return RAGExcelParser._dataframes_to_workbook(df)
df = RAGExcelParser._clean_dataframe(df)
wb = Workbook()
ws = wb.active
@@ -77,14 +77,15 @@ class RAGExcelParser:
for row_num, row in enumerate(df.values, 2):
for col_num, value in enumerate(row, 1):
ws.cell(row=row_num, column=col_num, value=value)
return wb
return wb
@staticmethod
def _dataframes_to_workbook(dfs: dict):
wb = Workbook()
default_sheet = wb.active
wb.remove(default_sheet)
for sheet_name, df in dfs.items():
df = RAGExcelParser._clean_dataframe(df)
ws = wb.create_sheet(title=sheet_name)
@@ -95,52 +96,6 @@ class RAGExcelParser:
ws.cell(row=row_num, column=col_num, value=value)
return wb
@staticmethod
def _extract_images_from_worksheet(ws, sheetname=None):
"""
Extract images from a worksheet and enrich them with vision-based descriptions.
Returns: List[dict]
"""
images = getattr(ws, "_images", [])
if not images:
return []
raw_items = []
for img in images:
try:
img_bytes = img._data()
pil_img = Image.open(BytesIO(img_bytes)).convert("RGB")
anchor = img.anchor
if hasattr(anchor, "_from") and hasattr(anchor, "_to"):
r1, c1 = anchor._from.row + 1, anchor._from.col + 1
r2, c2 = anchor._to.row + 1, anchor._to.col + 1
if r1 == r2 and c1 == c2:
span = "single_cell"
else:
span = "multi_cell"
else:
r1, c1 = anchor._from.row + 1, anchor._from.col + 1
r2, c2 = r1, c1
span = "single_cell"
item = {
"sheet": sheetname or ws.title,
"image": pil_img,
"image_description": "",
"row_from": r1,
"col_from": c1,
"row_to": r2,
"col_to": c2,
"span_type": span,
}
raw_items.append(item)
except Exception:
continue
return raw_items
def html(self, fnm, chunk_rows=256):
from html import escape
@@ -173,7 +128,7 @@ class RAGExcelParser:
tb = ""
tb += f"<table><caption>{sheetname}</caption>"
tb += tb_rows_0
for r in list(rows[1 + chunk_i * chunk_rows: min(1 + (chunk_i + 1) * chunk_rows, len(rows))]):
for r in list(rows[1 + chunk_i * chunk_rows : min(1 + (chunk_i + 1) * chunk_rows, len(rows))]):
tb += "<tr>"
for i, c in enumerate(r):
if c.value is None:
@@ -196,7 +151,7 @@ class RAGExcelParser:
except Exception as e:
logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file")
file_like_object.seek(0)
df = pd.read_csv(file_like_object, on_bad_lines='skip')
df = pd.read_csv(file_like_object)
df = df.replace(r"^\s*$", "", regex=True)
return df.to_markdown(index=False)
@@ -214,35 +169,19 @@ class RAGExcelParser:
continue
if not rows:
continue
# Get the header row
ti = list(rows[0])
header_fields = []
for cell in ti:
if cell.value:  # only add headers that have a value
header_fields.append(str(cell.value))
# If there are data rows, process them; otherwise process only the header
data_rows = rows[1:]
if data_rows:
for r in data_rows:
fields = []
for i, c in enumerate(r):
if not c.value:
continue
t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value)
fields.append(t)
line = "; ".join(fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
res.append(line)
else:
# Header-only case
if header_fields:
line = "; ".join(header_fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
res.append(line)
for r in list(rows[1:]):
fields = []
for i, c in enumerate(r):
if not c.value:
continue
t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value)
fields.append(t)
line = "; ".join(fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
res.append(line)
return res
@staticmethod
@@ -250,14 +189,14 @@ class RAGExcelParser:
if fnm.split(".")[-1].lower().find("xls") >= 0:
wb = RAGExcelParser._load_excel_to_workbook(BytesIO(binary))
total = 0
for sheetname in wb.sheetnames:
try:
ws = wb[sheetname]
total += len(list(ws.rows))
except Exception as e:
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
continue
try:
ws = wb[sheetname]
total += len(list(ws.rows))
except Exception as e:
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
continue
return total
if fnm.split(".")[-1].lower() in ["csv", "txt"]:
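The row-flattening loop in the `__call__` hunk of this file pairs each non-empty cell with its column header, joins the pairs with `"; "`, and appends the sheet name unless the name contains "sheet". The following is a rough sketch of that idea on plain Python lists, so it runs without openpyxl. The function name `rows_to_lines` and the `sep` and `sheet_sep` parameters (including their default values) are assumptions made for illustration, and a missing header is treated as empty here, which may differ from the original.

```python
from typing import List, Optional, Sequence


def rows_to_lines(
    header: Sequence[Optional[object]],
    rows: Sequence[Sequence[Optional[object]]],
    sheetname: str,
    sep: str = ": ",        # assumed header/value separator
    sheet_sep: str = " - ",  # assumed separator before the sheet name
) -> List[str]:
    """Flatten table rows into 'header: value; header: value' text lines."""
    lines = []
    for row in rows:
        fields = []
        for i, value in enumerate(row):
            # Mirrors the diff: falsy cells (None, "", and also 0) are skipped.
            if not value:
                continue
            title = str(header[i]) if i < len(header) and header[i] else ""
            fields.append((title + sep if title else "") + str(value))
        line = "; ".join(fields)
        # Default names such as "Sheet1" carry no meaning, so they are not appended.
        if "sheet" not in sheetname.lower():
            line += sheet_sep + sheetname
        lines.append(line)
    return lines
```

For example, `rows_to_lines(["name", "age"], [["Ann", 30]], "People")` yields one line that carries both labelled values followed by the sheet name.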

@@ -196,7 +196,7 @@ class EntityResolution(Extractor):
ans_list = []
records = [r.strip() for r in results.split(record_delimiter)]
for record in records:
pattern_int = fr"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
pattern_int = f"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
match_int = re.search(pattern_int, record)
res_int = int(str(match_int.group(1) if match_int else '0'))
if res_int > records_length:
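The two `pattern_int` lines in this hunk differ only in the string prefix: `fr"..."` versus `f"..."`. In a non-raw f-string, `\d` is an invalid escape sequence that recent Python versions warn about, whereas the raw form hands the backslash to `re` unchanged. Here is a small sketch of the raw form, assuming a hypothetical delimiter `<|>` (the actual value of `entity_index_delimiter` is not shown in the diff) and a hypothetical helper `extract_index`:

```python
import re

# Hypothetical delimiter for illustration; the real value is defined elsewhere.
entity_index_delimiter = "<|>"

# Raw f-string: {...} placeholders are still interpolated, but \d reaches `re` untouched.
pattern_int = fr"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"


def extract_index(record: str) -> int:
    """Return the integer found between two delimiters, or 0 when none is found."""
    match = re.search(pattern_int, record)
    return int(match.group(1)) if match else 0
```

The `0` fallback mirrors the `'0'` default used in the hunk when no match is found.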

@@ -4,7 +4,6 @@ from collections import defaultdict
from copy import deepcopy
import json_repair
import pandas as pd
import time
import trio
from app.core.rag.common.misc_utils import get_uuid
@@ -263,21 +262,21 @@ class KGSearch(Dealer):
relas = ""
return {
"page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms, comm_topn, max_token),
"vector": None,
"metadata": {
"doc_id": get_uuid(),
"file_id": "",
"file_name": "Related content in Knowledge Graph",
"file_created_at": int(time.time() * 1000),
"chunk_id": get_uuid(),
"content_ltks": "",
"page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms,
comm_topn, max_token),
"document_id": "",
"knowledge_id": kb_ids,
"sort_id": 0,
"status": 1,
"score": 1
},
"children": None
}
"docnm_kwd": "Related content in Knowledge Graph",
"kb_id": kb_ids,
"important_kwd": [],
"image_id": "",
"similarity": 1.,
"vector_similarity": 1.,
"term_similarity": 0,
"vector": [],
"positions": [],
}
def _community_retrieval_(self, entities, condition, kb_ids, idxnms, topn, max_token):
## Community retrieval

@@ -448,7 +448,7 @@ if __name__ == "__main__":
# Prepare the vision_model configuration
# Initialize QWenCV
vision_model = QWenCV(
key="",
key="sk-8e9e40cd171749858ce2d3722ea75669",
model_name="qwen-vl-max",
lang="Chinese",  # default to Chinese
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"

@@ -26,8 +26,6 @@ from app.core.rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr
from app.core.rag.common.string_utils import remove_redundant_spaces
from app.core.rag.common.float_utils import get_float
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed
def knowledge_retrieval(
@@ -50,7 +48,6 @@ def knowledge_retrieval(
- merge_strategy: "weight" or other strategies
- reranker_id: UUID of the reranker to use
- reranker_top_k: int
- use_graph: bool, whether to use a graph
Returns:
Rearranged document block list (in descending order of relevance)
@@ -62,7 +59,6 @@ def knowledge_retrieval(
merge_strategy = config.get("merge_strategy", "weight")
reranker_id = config.get("reranker_id")
reranker_top_k = config.get("reranker_top_k", 1024)
use_graph = config.get("use_graph", "false").lower() == "true"
file_names_filter = []
if user_ids:
@@ -71,10 +67,6 @@ def knowledge_retrieval(
if not knowledge_bases:
return []
kb_ids = []
workspace_ids = []
chat_model = None
embedding_model = None
all_results = []
# Search each knowledge base
for kb_config in knowledge_bases:
@@ -95,22 +87,6 @@ def knowledge_retrieval(
else:
continue
if str(db_knowledge.id) not in kb_ids:
kb_ids.append(str(db_knowledge.id))
if str(db_knowledge.workspace_id) not in workspace_ids:
workspace_ids.append(str(db_knowledge.workspace_id))
if not chat_model:
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
if not embedding_model:
embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# Retrieve according to the configured retrieval type
match kb_config["retrieve_type"]:
@@ -160,12 +136,6 @@ def knowledge_retrieval(
# Use the specified reranker for re-ranking
if reranker_id:
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
# use graph
if use_graph:
from app.core.rag.common.settings import kg_retriever
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
all_results.insert(0, doc)
return all_results
except Exception as e:

@@ -213,7 +213,7 @@ class ESConnection(DocStoreConnection):
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bqry.to_dict(),
# similarity=similarity
similarity=similarity,
)
if bqry and rank_feature:

@@ -1,7 +1,7 @@
"""Core tool management module"""
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
from app.core.tools.langchain_adapter import LangchainAdapter
from .base import BaseTool, ToolResult, ToolParameter
from .langchain_adapter import LangchainAdapter
# Optional imports, to avoid import errors
try:

@@ -191,14 +191,10 @@ class BaseTool(ABC):
execution_time=execution_time
)
def to_langchain_tool(self, operation: Optional[str] = None):
"""Convert to the Langchain tool format
Args:
operation: a specific operation (for tools that expose operations)
"""
from app.core.tools.langchain_adapter import LangchainAdapter
return LangchainAdapter.convert_tool(self, operation)
def to_langchain_tool(self):
"""Convert to the Langchain tool format"""
from .langchain_adapter import LangchainAdapter
return LangchainAdapter.convert_tool(self)
def __repr__(self):
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"

@@ -1,11 +1,11 @@
"""Built-in tools module"""
from app.core.tools.builtin.base import BuiltinTool
from app.core.tools.builtin.datetime_tool import DateTimeTool
from app.core.tools.builtin.json_tool import JsonTool
from app.core.tools.builtin.baidu_search_tool import BaiduSearchTool
from app.core.tools.builtin.mineru_tool import MinerUTool
from app.core.tools.builtin.textin_tool import TextInTool
from .base import BuiltinTool
from .datetime_tool import DateTimeTool
from .json_tool import JsonTool
from .baidu_search_tool import BaiduSearchTool
from .mineru_tool import MinerUTool
from .textin_tool import TextInTool
__all__ = [
"BuiltinTool",

@@ -4,7 +4,7 @@ from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from app.core.tools.builtin.base import BuiltinTool
from .base import BuiltinTool
class BaiduSearchTool(BuiltinTool):
@@ -110,7 +110,7 @@ class BaiduSearchTool(BuiltinTool):
execution_time = time.time() - start_time
return ToolResult.success_result(
data=result["results"],
data=result,
execution_time=execution_time
)

@@ -5,7 +5,7 @@ from typing import List
import pytz
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from app.core.tools.builtin.base import BuiltinTool
from .base import BuiltinTool
class DateTimeTool(BuiltinTool):
@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
type=ParameterType.STRING,
description="操作类型",
required=True,
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"]
),
ToolParameter(
name="input_value",
@@ -95,7 +95,7 @@ class DateTimeTool(BuiltinTool):
execution_time = time.time() - start_time
return ToolResult.success_result(
data=result["result_data"],
data=result,
execution_time=execution_time
)
@@ -123,14 +123,12 @@ class DateTimeTool(BuiltinTool):
utc_now = datetime.now(timezone.utc)
return {
"datetime": now.strftime(output_format),
"timestamp": int(now.timestamp()),
"timezone": timezone_str,
"iso_format": now.isoformat(),
"result_data": {
"datetime": now.strftime(output_format),
"timestamp": int(now.timestamp()),
"timestamp_ms": int(now.timestamp() * 1000),
"utc_datetime": utc_now.strftime(output_format),
}
"timestamp_ms": int(now.timestamp() * 1000),
"utc_datetime": utc_now.strftime(output_format)
}
@staticmethod
@@ -150,8 +148,7 @@ class DateTimeTool(BuiltinTool):
"original": input_value,
"formatted": dt.strftime(output_format),
"timestamp": int(dt.timestamp()),
"iso_format": dt.isoformat(),
"result_data": dt.strftime(output_format)
"iso_format": dt.isoformat()
}
@staticmethod
@@ -192,8 +189,7 @@ class DateTimeTool(BuiltinTool):
"original_timezone": from_timezone,
"converted": converted_dt.strftime(output_format),
"converted_timezone": to_timezone,
"timestamp": int(converted_dt.timestamp()),
"result_data": converted_dt.strftime(output_format)
"timestamp": int(converted_dt.timestamp())
}
@staticmethod
@@ -223,8 +219,7 @@ class DateTimeTool(BuiltinTool):
"timestamp": timestamp,
"datetime": dt.strftime(output_format),
"timezone": timezone_str,
"iso_format": dt.isoformat(),
"result_data": dt.strftime(output_format)
"iso_format": dt.isoformat()
}
@staticmethod
@@ -254,8 +249,7 @@ class DateTimeTool(BuiltinTool):
"datetime": input_value,
"timezone": timezone_str,
"timestamp": int(dt.timestamp()),
"iso_format": dt.isoformat(),
"result_data": int(dt.timestamp())
"iso_format": dt.isoformat()
}
def _calculate_datetime(self, kwargs) -> dict:
@@ -293,8 +287,7 @@ class DateTimeTool(BuiltinTool):
"calculation": calculation,
"result": calculated_dt.strftime(output_format),
"timezone": timezone_str,
"timestamp": int(calculated_dt.timestamp()),
"result_data": calculated_dt.strftime(output_format)
"timestamp": int(calculated_dt.timestamp())
}
@staticmethod

@@ -7,7 +7,7 @@ import xml.etree.ElementTree as ET
from xml.dom import minidom
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from app.core.tools.builtin.base import BuiltinTool
from .base import BuiltinTool
class JsonTool(BuiltinTool):
@@ -29,7 +29,8 @@ class JsonTool(BuiltinTool):
type=ParameterType.STRING,
description="操作类型",
required=True,
enum=["insert", "replace", "delete", "parse"]
enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge",
"extract", "insert", "replace", "delete", "parse"]
),
ToolParameter(
name="input_data",
@@ -69,7 +70,7 @@ class JsonTool(BuiltinTool):
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式用于insert、replace、delete、parse操作$.user.name或users[0].name",
description="JSON路径表达式用于extract、insert、replace、delete、parse操作$.user.name或users[0].name",
required=False
),
ToolParameter(
@@ -136,7 +137,7 @@ class JsonTool(BuiltinTool):
execution_time = time.time() - start_time
return ToolResult.success_result(
data=result["result_data"],
data=result,
execution_time=execution_time
)
@@ -671,8 +672,7 @@ class JsonTool(BuiltinTool):
"success": True,
"value": current,
"value_type": type(current).__name__,
"value_json": json.dumps(current, indent=2, ensure_ascii=False) if isinstance(current, (dict, list)) else str(current),
"result_data": json.dumps(current, indent=2, ensure_ascii=False) if isinstance(current, (dict, list)) else str(current)
"value_json": json.dumps(current, indent=2, ensure_ascii=False) if isinstance(current, (dict, list)) else str(current)
}
except (KeyError, IndexError, TypeError) as e:
@@ -681,8 +681,7 @@ class JsonTool(BuiltinTool):
"json_path": json_path,
"success": False,
"error": str(e),
"value": None,
"result_data": None
"value": None
}
def _analyze_json_structure(self, data: Any, depth: int = 0) -> Dict[str, Any]:

@@ -4,7 +4,7 @@ from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from app.core.tools.builtin.base import BuiltinTool
from .base import BuiltinTool
class MinerUTool(BuiltinTool):

@@ -1,216 +0,0 @@
"""Operation tool - a tool wrapper created for a specific operation"""
from typing import List
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.models import ToolType
class OperationTool(BaseTool):
"""Operation tool - wraps a specific operation of a base tool"""
def __init__(self, base_tool: BaseTool, operation: str):
self.base_tool = base_tool
self.operation = operation
super().__init__(base_tool.tool_id, base_tool.config)
@property
def name(self) -> str:
return f"{self.base_tool.name}_{self.operation}"
@property
def tool_type(self) -> ToolType:
"""Tool type"""
return ToolType.BUILTIN
@property
def description(self) -> str:
return f"{self.base_tool.description} - {self.operation}"
@property
def parameters(self) -> List[ToolParameter]:
"""Return the parameters for the specific operation"""
if self.base_tool.name == 'datetime_tool':
return self._get_datetime_params()
elif self.base_tool.name == 'json_tool':
return self._get_json_params()
else:
# By default, return every parameter except operation
return [p for p in self.base_tool.parameters if p.name != "operation"]
def _get_datetime_params(self) -> List[ToolParameter]:
"""Get the parameters for a specific datetime_tool operation"""
if self.operation == "now":
return [
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
)
]
elif self.operation == "format":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="input_format",
type=ParameterType.STRING,
description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
)
]
elif self.operation == "convert_timezone":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="input_format",
type=ParameterType.STRING,
description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="from_timezone",
type=ParameterType.STRING,
description="源时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
),
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
)
]
elif self.operation == "timestamp_to_datetime":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
)
]
else:
return []
def _get_json_params(self) -> List[ToolParameter]:
"""Get the parameters for a specific json_tool operation"""
base_params = [
ToolParameter(
name="input_data",
type=ParameterType.STRING,
description="输入数据JSON字符串、YAML字符串或XML字符串",
required=True
)
]
if self.operation == "insert":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
),
ToolParameter(
name="new_value",
type=ParameterType.STRING,
description="新值用于insert操作",
required=True
)
]
elif self.operation == "replace":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
),
ToolParameter(
name="old_text",
type=ParameterType.STRING,
description="要替换的原文本用于replace操作",
required=True
),
ToolParameter(
name="new_text",
type=ParameterType.STRING,
description="替换后的新文本用于replace操作",
required=True
)
]
elif self.operation == "delete":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
)
]
elif self.operation == "parse":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
)
]
else:
return base_params
async def execute(self, **kwargs) -> ToolResult:
"""Execute the specific operation"""
# Add the operation argument
kwargs["operation"] = self.operation
return await self.base_tool.execute(**kwargs)
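The removed `OperationTool` follows a small delegation pattern: it derives its name from the base tool plus the operation, and its `execute` injects the fixed `operation` argument before forwarding the call. Below is a stripped-down sketch of that pattern using only the standard library. `EchoTool` and `OperationWrapper` are hypothetical stand-ins, not classes from the repository, and the real `BaseTool` interface is richer than what is shown.

```python
import asyncio
from typing import Any, Dict


class EchoTool:
    """Hypothetical base tool that simply returns the arguments it receives."""

    name = "echo_tool"

    async def execute(self, **kwargs: Any) -> Dict[str, Any]:
        return kwargs


class OperationWrapper:
    """Binds one operation of a multi-operation tool, as OperationTool does above."""

    def __init__(self, base_tool: EchoTool, operation: str) -> None:
        self.base_tool = base_tool
        self.operation = operation

    @property
    def name(self) -> str:
        # Same naming scheme as the diff: "<base name>_<operation>".
        return f"{self.base_tool.name}_{self.operation}"

    async def execute(self, **kwargs: Any) -> Dict[str, Any]:
        # Inject the fixed operation, then delegate to the wrapped tool.
        kwargs["operation"] = self.operation
        return await self.base_tool.execute(**kwargs)


def run_now_operation() -> Dict[str, Any]:
    """Convenience entry point that runs the wrapper synchronously."""
    wrapper = OperationWrapper(EchoTool(), "now")
    return asyncio.run(wrapper.execute(to_timezone="UTC"))
```

A caller therefore never supplies `operation` itself; it sees one narrowly scoped tool per operation.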

@@ -4,7 +4,7 @@ from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from app.core.tools.builtin.base import BuiltinTool
from .base import BuiltinTool
class TextInTool(BuiltinTool):

@@ -1,8 +1,8 @@
"""Custom tools module"""
from app.core.tools.custom.base import CustomTool
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
from app.core.tools.custom.auth_manager import AuthManager
from .base import CustomTool
from .schema_parser import OpenAPISchemaParser
from .auth_manager import AuthManager
__all__ = [
"CustomTool",

@@ -1,5 +1,4 @@
"""Base class for custom tools"""
import json
import time
from typing import Dict, Any, List, Optional
import aiohttp
@@ -136,13 +135,6 @@ class CustomTool(BaseTool):
if not self.schema_content:
return operations
if isinstance(self.schema_content, str):
try:
self.schema_content = json.loads(self.schema_content)
except json.JSONDecodeError:
logger.error(f"无效的OpenAPI schema: {self.schema_content}")
return operations
paths = self.schema_content.get("paths", {})

@@ -10,6 +10,9 @@ from app.core.logging_config import get_business_logger
logger = get_business_logger()
# Create an alias for compatibility
# SchemaParser = OpenAPISchemaParser = None
class OpenAPISchemaParser:
"""OpenAPI schema parser - parses the OpenAPI 3.0 specification"""
@@ -210,9 +213,7 @@ class OpenAPISchemaParser:
if not isinstance(operation, dict):
continue
summary = operation.get("summary", "")
# Generate the operation ID
operation_id = operation.get("operationId")
if not operation_id:
@@ -222,7 +223,7 @@ class OpenAPISchemaParser:
operations[operation_id] = {
"method": method.upper(),
"path": path,
"summary": summary if summary else operation_id,
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": self._extract_parameters(operation),
"request_body": self._extract_request_body(operation),

@@ -21,35 +21,24 @@ class LangchainToolWrapper(LangchainBaseTool):
# Internal tool instance
tool_instance: BaseTool = Field(..., description="内部工具实例")
# Specific operation (for custom tools)
operation: Optional[str] = Field(None, description="特定操作")
class Config:
arbitrary_types_allowed = True
def __init__(self, tool_instance: BaseTool, operation: Optional[str] = None, **kwargs):
def __init__(self, tool_instance: BaseTool, **kwargs):
"""Initialize the Langchain tool wrapper
Args:
tool_instance: internal tool instance
operation: a specific operation (for custom tools)
"""
# Dynamically create the parameter schema
args_schema = LangchainAdapter._create_pydantic_schema(
tool_instance.parameters, operation
)
# Build the tool name
tool_name = tool_instance.name
if operation:
tool_name = f"{tool_instance.name}_{operation}"
args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters)
super().__init__(
name=tool_name,
name=tool_instance.name,
description=tool_instance.description,
args_schema=args_schema,
tool_instance=tool_instance,
operation=operation,
_tool_instance=tool_instance,
**kwargs
)
@@ -69,12 +58,8 @@ class LangchainToolWrapper(LangchainBaseTool):
) -> str:
"""Execute the tool asynchronously"""
try:
# If a specific operation is set, add it to the arguments
if self.operation:
kwargs["operation"] = self.operation
# Execute the internal tool
result = await self.tool_instance.safe_execute(**kwargs)
result = await self._tool_instance.safe_execute(**kwargs)
# Convert the result to the Langchain format
return LangchainAdapter._format_result_for_langchain(result)
@@ -88,82 +73,24 @@ class LangchainAdapter:
"""Langchain adapter - handles tool format conversion and standardization"""
@staticmethod
def convert_tool(tool: BaseTool, operation: Optional[str] = None) -> LangchainToolWrapper:
def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
"""Convert an internal tool to a Langchain tool
Args:
tool: internal tool instance
operation: a specific operation (for tools with operations) or an MCP tool name
Returns:
A Langchain-compatible tool wrapper
"""
try:
# Handle a specific tool name for MCP tools
if hasattr(tool, 'tool_type') and tool.tool_type.value == "mcp" and operation:
# Create an instance bound to the specific tool name for the MCP tool
mcp_tool = LangchainAdapter._create_mcp_tool_with_name(tool, operation)
wrapper = LangchainToolWrapper(tool_instance=mcp_tool)
logger.debug(f"MCP工具转换成功: {tool.name}_{operation} -> Langchain格式")
return wrapper
elif operation and LangchainAdapter._tool_supports_operations(tool):
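# Create an operation-specific instance for tools that support multiple operations
if tool.tool_type.value == "custom":
# Custom tools receive the operation argument directly
wrapper = LangchainToolWrapper(tool_instance=tool, operation=operation)
else:
# Built-in tools are wrapped with OperationTool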
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式")
return wrapper
else:
# Single tool
wrapper = LangchainToolWrapper(tool_instance=tool)
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
return wrapper
wrapper = LangchainToolWrapper(tool_instance=tool)
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
return wrapper
except Exception as e:
logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
raise
@staticmethod
def _tool_supports_operations(tool: BaseTool) -> bool:
"""Check whether the tool supports multiple operations"""
# Built-in tools that support operations
builtin_operation_tools = ['datetime_tool', 'json_tool']
# Check built-in tools
if tool.tool_type.value == "builtin" and tool.name in builtin_operation_tools:
return True
# Check custom tools: custom tools support multiple operations by parsing the OpenAPI schema
if tool.tool_type.value == "custom":
# Check whether the tool has more than one operation
if hasattr(tool, '_parsed_operations') and len(tool._parsed_operations) > 1:
return True
# Or check whether the parameters include an operation parameter
for param in tool.parameters:
if param.name == "operation" and param.enum:
return True
return False
@staticmethod
def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool:
"""Create a tool instance for a specific operation"""
if base_tool.tool_type.value == "builtin":
from app.core.tools.builtin.operation_tool import OperationTool
return OperationTool(base_tool, operation)
else:
raise ValueError(f"不支持的工具类型: {base_tool.tool_type.value}")
@staticmethod
def _create_mcp_tool_with_name(mcp_tool: BaseTool, tool_name: str) -> BaseTool:
"""Create an MCP tool instance bound to the given tool name"""
mcp_tool.set_current_tool(tool_name)
return mcp_tool
@staticmethod
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
"""Convert tools in batch
@@ -183,19 +110,15 @@ class LangchainAdapter:
except Exception as e:
logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
logger.info(f"批量转换完成: {len(converted_tools)} 个工具")
logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
return converted_tools
@staticmethod
def _create_pydantic_schema(
parameters: List[ToolParameter],
operation: Optional[str] = None
) -> Type[BaseModel]:
def _create_pydantic_schema(parameters: List[ToolParameter]) -> Type[BaseModel]:
"""Create a Pydantic schema from the tool parameters
Args:
parameters: list of tool parameters
operation: a specific operation (used to filter parameters)
Returns:
A Pydantic model class
@@ -204,12 +127,7 @@ class LangchainAdapter:
fields = {}
annotations = {}
# If an operation is specified, filter out the operation parameter
filtered_params = parameters
if operation:
filtered_params = [p for p in parameters if p.name != "operation"]
for param in filtered_params:
for param in parameters:
# Determine the Python type
python_type = LangchainAdapter._get_python_type(param.type)
@@ -232,7 +150,7 @@ class LangchainAdapter:
# Add validation constraints
if param.enum:
# Enum value constraint
field_kwargs["pattern"] = f"^({'|'.join(map(str, param.enum))})$"
field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$"
if param.minimum is not None:
field_kwargs["ge"] = param.minimum
@@ -241,7 +159,7 @@ class LangchainAdapter:
field_kwargs["le"] = param.maximum
if param.pattern:
field_kwargs["pattern"] = param.pattern
field_kwargs["regex"] = param.pattern
fields[param.name] = Field(**field_kwargs)
annotations[param.name] = python_type
@@ -251,10 +169,9 @@ class LangchainAdapter:
"ToolArgsSchema",
(BaseModel,),
{
"__module__": __name__,
"__annotations__": annotations,
"model_config": {"extra": "forbid"},
**fields
**fields,
"Config": type("Config", (), {"extra": "forbid"})
}
)
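The enum constraint in `_create_pydantic_schema` is built by joining the raw enum values with `|` inside `^(...)$`. Here is a hedged sketch of the same idea with each value passed through `re.escape`, so that values containing regex metacharacters such as `.` match literally. `enum_to_pattern` is a hypothetical helper, not a function in the repository, and whether the resulting string is passed as `pattern` or `regex` depends on the Pydantic major version in use.

```python
import re
from typing import Iterable


def enum_to_pattern(values: Iterable[object]) -> str:
    """Build an anchored alternation that matches exactly one of `values`.

    Hypothetical helper: every value is converted to str and escaped so that
    characters like '.', '+', or '|' are matched literally.
    """
    return "^(" + "|".join(re.escape(str(value)) for value in values) + ")$"
```

As a usage note, `enum_to_pattern(["now", "format"])` accepts `"now"` and rejects `"nowx"`, which is the behaviour the unescaped version already has for plain alphanumeric values; escaping only changes the outcome for values that contain metacharacters.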

@@ -1,20 +1,12 @@
"""MCP tools module - Model Context Protocol support"""
"""MCP tools module"""
# Main class exports
from .base import MCPTool, MCPToolManager, MCPError
from .client import SimpleMCPClient, MCPConnectionError
from .base import MCPTool
from .client import MCPClient, MCPConnectionPool
from .service_manager import MCPServiceManager
__all__ = [
# Core classes
"MCPTool",
"MCPToolManager",
"MCPError",
# Client classes
"SimpleMCPClient",
"MCPConnectionError",
# Service management (simplified)
"MCPClient",
"MCPConnectionPool",
"MCPServiceManager"
]

@@ -1,9 +1,11 @@
"""MCP tool base class - consolidated version"""
"""MCP tool base class"""
import time
from typing import List, Dict, Any
from typing import Dict, Any, List
import aiohttp
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.tools.base import BaseTool
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
@@ -13,188 +15,215 @@ class MCPTool(BaseTool):
"""MCP tool - Model Context Protocol tool"""
def __init__(self, tool_id: str, config: Dict[str, Any]):
"""Initialize the MCP tool
Args:
tool_id: tool ID
config: tool configuration
"""
super().__init__(tool_id, config)
self.server_url = config.get("server_url", "")
self.connection_config = config.get("connection_config", {})
self.available_tools = config.get("available_tools", [])
self._client = None
self._connected = False
@property
def name(self) -> str:
"""Tool name"""
return f"mcp_tool_{self.tool_id[:8]}"
@property
def description(self) -> str:
"""Tool description"""
return f"MCP tool - connected to {self.server_url}"
@property
def tool_type(self) -> ToolType:
"""Tool type"""
return ToolType.MCP
@property
def parameters(self) -> List[ToolParameter]:
"""Return the parameters for the given tool name"""
# If a tool name is set, take its parameters from available_tools
tool_name = getattr(self, '_current_tool_name', None)
if tool_name and self.available_tools:
for tool_info in self.available_tools:
if tool_info.get("tool_name") == tool_name:
arguments = tool_info.get("arguments", {})
return self._generate_parameters_from_schema(arguments)
"""Tool parameter definitions"""
params = []
# Fall back to the generic parameters
return [
ToolParameter(
# Add the tool-selection parameter
if len(self.available_tools) > 1:
params.append(ToolParameter(
name="tool_name",
type=ParameterType.STRING,
description="Name of the tool to execute",
required=True
),
description="Name of the MCP tool to call",
required=True,
enum=self.available_tools
))
# Add the generic parameters
params.extend([
ToolParameter(
name="arguments",
type=ParameterType.OBJECT,
description="Tool arguments",
description="Tool arguments as a JSON object",
required=False,
default={}
),
ToolParameter(
name="timeout",
type=ParameterType.INTEGER,
description="Timeout in seconds",
required=False,
default=30,
minimum=1,
maximum=300
)
]
def _generate_parameters_from_schema(self, arguments: Dict[str, Any]) -> List[ToolParameter]:
"""Generate the parameter list from an argument schema"""
properties = arguments.get("properties", {})
required_fields = arguments.get("required", [])
params = []
for param_name, param_def in properties.items():
param_type = self._convert_json_type_to_parameter_type(param_def.get("type", "string"))
params.append(ToolParameter(
name=param_name,
type=param_type,
description=param_def.get("description", f"Parameter: {param_name}"),
required=param_name in required_fields,
default=param_def.get("default"),
enum=param_def.get("enum"),
minimum=param_def.get("minimum"),
maximum=param_def.get("maximum")
))
])
return params
def _convert_json_type_to_parameter_type(self, json_type: str) -> ParameterType:
"""Convert a JSON Schema type to a ParameterType"""
type_mapping = {
"string": ParameterType.STRING,
"integer": ParameterType.INTEGER,
"number": ParameterType.NUMBER,
"boolean": ParameterType.BOOLEAN,
"array": ParameterType.ARRAY,
"object": ParameterType.OBJECT
}
return type_mapping.get(json_type, ParameterType.STRING)
def set_current_tool(self, tool_name: str):
"""Set the current tool name, used to look up its specific parameters"""
self._current_tool_name = tool_name
async def execute(self, **kwargs) -> ToolResult:
"""Execute the MCP tool"""
start_time = time.time()
try:
# Make sure we are connected
if not self._connected:
await self.connect()
# Determine which tool to call
tool_name = kwargs.get("tool_name")
if not tool_name and len(self.available_tools) == 1:
tool_name = self.available_tools[0]
if not tool_name:
raise Exception("No tool name specified")
raise ValueError("The MCP tool to call must be specified")
if tool_name not in self.available_tools:
raise ValueError(f"MCP tool does not exist: {tool_name}")
# Get the arguments
arguments = kwargs.get("arguments", {})
timeout = kwargs.get("timeout", 30)
from .client import SimpleMCPClient
# Call the MCP tool
result = await self._call_mcp_tool(tool_name, arguments, timeout)
client = SimpleMCPClient(self.server_url, self.connection_config)
async with client:
result = await client.call_tool(tool_name, arguments)
execution_time = time.time() - start_time
return ToolResult.success_result(
data=result,
execution_time=execution_time
)
except Exception as e:
execution_time = time.time() - start_time
logger.error(f"MCP tool execution failed: {kwargs.get('tool_name', 'unknown')}, error: {e}")
return ToolResult.error_result(
error=str(e),
error_code="MCP_EXECUTION_ERROR",
return ToolResult.success_result(
data=result,
execution_time=execution_time
)
class MCPError(Exception):
"""Base class for MCP errors"""
pass
class MCPToolManager:
"""MCP tool manager - simplified version"""
def __init__(self, db=None):
self.db = db
self._tool_cache: Dict[str, Dict[str, Any]] = {} # server_url -> tools_info
async def discover_tools(
self,
server_url: str,
connection_config: Dict[str, Any] = None
) -> tuple[bool, List[Dict[str, Any]], str | None]:
"""Discover the tools on an MCP server"""
try:
from .client import SimpleMCPClient
client = SimpleMCPClient(server_url, connection_config)
async with client:
tools = await client.list_tools()
# Cache the tool information
self._tool_cache[server_url] = {
"tools": tools,
"connection_config": connection_config,
"last_updated": time.time()
}
logger.info(f"Discovered {len(tools)} MCP tools: {server_url}")
return True, tools, None
except Exception as e:
error_msg = f"Tool discovery failed: {e}"
logger.error(error_msg)
return False, [], error_msg
execution_time = time.time() - start_time
return ToolResult.error_result(
error=str(e),
error_code="MCP_ERROR",
execution_time=execution_time
)
async def test_tool_connection(
self,
server_url: str,
connection_config: Dict[str, Any] = None
) -> Dict[str, Any]:
"""Test the tool connection"""
async def connect(self) -> bool:
"""Connect to the MCP server"""
try:
from .client import SimpleMCPClient
from .client import MCPClient
client = SimpleMCPClient(server_url, connection_config)
if self._connected:
return True
self._client = MCPClient(self.server_url, self.connection_config)
if await self._client.connect():
self._connected = True
# Refresh the list of available tools
await self._update_available_tools()
logger.info(f"MCP server connected: {self.server_url}")
return True
else:
logger.error(f"MCP server connection failed: {self.server_url}")
return False
except Exception as e:
logger.error(f"MCP server connection error: {self.server_url}, error: {e}")
self._connected = False
return False
async def _update_available_tools(self):
"""Refresh the list of available tools"""
try:
if self._client and self._connected:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
logger.info(f"MCP tool list updated: {len(self.available_tools)} tools")
except Exception as e:
logger.error(f"Failed to update the MCP tool list: {e}")
async def disconnect(self) -> bool:
"""Disconnect from the MCP server"""
try:
if self._client:
await self._client.disconnect()
self._client = None
self._connected = False
logger.info(f"MCP server disconnected: {self.server_url}")
return True
except Exception as e:
logger.error(f"Failed to disconnect from the MCP server: {e}")
return False
def get_health_status(self) -> Dict[str, Any]:
"""Get the health status of the MCP service"""
return {
"connected": self._connected,
"server_url": self.server_url,
"available_tools": self.available_tools,
"last_check": time.time()
}
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
"""Call an MCP tool"""
if not self._client or not self._connected:
raise Exception("MCP client is not connected")
try:
result = await self._client.call_tool(tool_name, arguments, timeout)
return result
except Exception as e:
logger.error(f"MCP tool call failed: {tool_name}, error: {e}")
raise
async def list_available_tools(self) -> List[Dict[str, Any]]:
"""List the available MCP tools"""
try:
if not self._connected:
await self.connect()
if self._client:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
return tools
return []
except Exception as e:
logger.error(f"Failed to get the MCP tool list: {e}")
return []
def test_connection(self) -> Dict[str, Any]:
"""Test the MCP connection"""
try:
# A synchronous connection test should be implemented here
# For simplicity, return basic information only
return {
"success": bool(self.server_url),
"server_url": self.server_url,
"connected": self._connected,
"available_tools_count": len(self.available_tools),
"message": "MCP configuration is valid" if self.server_url else "Missing server URL configuration"
}
async with client:
tools = await client.list_tools()
return {
"success": True,
"tools_count": len(tools),
"tools": [tool.get("name") for tool in tools],
"message": "Connection succeeded"
}
except Exception as e:
return {
"success": False,
"error": str(e),
"message": "Connection failed"
"error": str(e)
}
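`_generate_parameters_from_schema` in the diff above maps a JSON Schema `properties` object to a list of `ToolParameter` entries. A minimal, self-contained sketch of that mapping follows. The `Param` dataclass and the `params_from_schema` function are hypothetical stand-ins for illustration, not the project's `ToolParameter` schema; unknown JSON types fall back to `string`, as in the diff.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# JSON Schema type names this sketch recognizes; anything else becomes "string".
KNOWN_TYPES = {"string", "integer", "number", "boolean", "array", "object"}


@dataclass
class Param:
    """Hypothetical stand-in for a tool parameter descriptor."""
    name: str
    type: str
    description: str
    required: bool
    default: Any = None
    enum: Optional[List[Any]] = None


def params_from_schema(schema: Dict[str, Any]) -> List[Param]:
    """Turn a JSON Schema object definition into a flat parameter list."""
    required = set(schema.get("required", []))
    params = []
    for name, spec in schema.get("properties", {}).items():
        json_type = spec.get("type", "string")
        # A non-string "type" (for example a list of types) also falls back.
        known = isinstance(json_type, str) and json_type in KNOWN_TYPES
        params.append(Param(
            name=name,
            type=json_type if known else "string",
            description=spec.get("description", f"Parameter: {name}"),
            required=name in required,
            default=spec.get("default"),
            enum=spec.get("enum"),
        ))
    return params


if __name__ == "__main__":
    example = {
        "properties": {
            "city": {"type": "string", "description": "City name"},
            "days": {"type": "integer", "default": 3},
        },
        "required": ["city"],
    }
    for param in params_from_schema(example):
        print(param)
```

Keeping the conversion in a pure function like this makes it testable without a live MCP server.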

@@ -1,8 +1,9 @@
"""MCP client - simplified version"""
"""MCP client - Model Context Protocol client implementation"""
import asyncio
import json
import time
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional, Callable
from urllib.parse import urlparse
import aiohttp
import websockets
from websockets.exceptions import ConnectionClosed
@@ -17,260 +18,139 @@ class MCPConnectionError(Exception):
pass
class SimpleMCPClient:
"""Simplified MCP client"""
class MCPProtocolError(Exception):
"""MCP protocol error"""
pass
class MCPClient:
"""MCP client - supports HTTP and WebSocket connections"""
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
"""Initialize the MCP client
Args:
server_url: MCP server URL
connection_config: Connection configuration
"""
self.server_url = server_url
self.connection_config = connection_config or {}
self.timeout = self.connection_config.get("timeout", 30)
# Determine the connection type
self.is_websocket = server_url.startswith(("ws://", "wss://"))
self.is_sse = "/sse" in server_url.lower()
# Parse the URL to determine the connection type
parsed_url = urlparse(server_url)
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
# Connection state
self._connected = False
self._websocket = None
self._session = None
# Request management
self._request_id = 0
self._pending_requests = {}
self._server_capabilities = {}
self._endpoint_url = None # SSE endpoint URL
self._sse_task = None
self._pending_requests: Dict[str, asyncio.Future] = {}
# Connection pool configuration
self.max_connections = self.connection_config.get("max_connections", 10)
self.connection_timeout = self.connection_config.get("timeout", 30)
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
self.retry_delay = self.connection_config.get("retry_delay", 1)
# Health check
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
self._health_check_task = None
self._last_health_check = None
# Event callbacks
self._on_connect_callbacks: List[Callable] = []
self._on_disconnect_callbacks: List[Callable] = []
self._on_error_callbacks: List[Callable] = []
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.disconnect()
async def connect(self):
"""Establish the connection"""
async def connect(self) -> bool:
"""Connect to the MCP server
Returns:
Whether the connection succeeded
"""
try:
if self.is_websocket:
await self._connect_websocket()
if self._connected:
return True
logger.info(f"Connecting to the MCP server: {self.server_url}")
if self.connection_type == "websocket":
success = await self._connect_websocket()
else:
await self._connect_http()
success = await self._connect_http()
if success:
self._connected = True
await self._start_health_check()
await self._notify_connect_callbacks()
logger.info(f"MCP server connected: {self.server_url}")
return success
except Exception as e:
logger.error(f"MCP connection failed: {self.server_url}, error: {e}")
raise MCPConnectionError(f"Connection failed: {e}")
logger.error(f"Failed to connect to the MCP server: {self.server_url}, error: {e}")
await self._notify_error_callbacks(e)
return False
async def disconnect(self):
"""Close the connection"""
async def disconnect(self) -> bool:
"""Disconnect from the MCP server
Returns:
Whether the disconnect succeeded
"""
try:
if self._sse_task:
self._sse_task.cancel()
if self._websocket:
if not self._connected:
return True
logger.info(f"Disconnecting from the MCP server: {self.server_url}")
# Stop the health check
await self._stop_health_check()
# Cancel all pending requests
for future in self._pending_requests.values():
if not future.done():
future.cancel()
self._pending_requests.clear()
# Close the connection
if self.connection_type == "websocket" and self._websocket:
await self._websocket.close()
self._websocket = None
if self._session:
elif self._session:
await self._session.close()
self._session = None
self._connected = False
await self._notify_disconnect_callbacks()
logger.info(f"MCP server disconnected: {self.server_url}")
return True
except Exception as e:
logger.error(f"Disconnect failed: {e}")
logger.error(f"Failed to disconnect from the MCP server: {e}")
return False
async def _connect_websocket(self):
"""WebSocket connection"""
headers = self._build_headers()
self._websocket = await websockets.connect(
self.server_url,
extra_headers=headers,
timeout=self.timeout
)
asyncio.create_task(self._handle_websocket_messages())
await self._send_initialize()
async def _connect_http(self):
"""HTTP connection"""
headers = self._build_headers()
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(headers=headers, timeout=timeout)
if self.is_sse:
await self._initialize_sse_session()
elif "modelscope.net" in self.server_url:
await self._initialize_modelscope_session()
async def _initialize_sse_session(self):
"""Initialize the SSE MCP session, modeled on the Dify implementation"""
try:
# Open the SSE connection
response = await self._session.get(self.server_url)
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"SSE connection failed {response.status}: {error_text}")
# Start the SSE reader task
self._sse_task = asyncio.create_task(self._read_sse_stream(response))
# Wait for the endpoint URL
for _ in range(10):
if self._endpoint_url:
break
await asyncio.sleep(1)
if not self._endpoint_url:
raise MCPConnectionError("Could not obtain the endpoint URL")
# Send the initialize request to the endpoint
init_request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
}
}
init_response = await self._send_sse_request(init_request)
if "error" in init_response:
raise MCPConnectionError(f"Initialization failed: {init_response['error']}")
result = init_response.get("result", {})
self._server_capabilities = result.get("capabilities", {})
# Send the initialized notification
await self._send_sse_notification({"jsonrpc": "2.0", "method": "notifications/initialized"})
except aiohttp.ClientError as e:
raise MCPConnectionError(f"Failed to initialize the connection: {e}")
async def _read_sse_stream(self, response):
"""Read the SSE stream"""
try:
async for line in response.content:
line = line.decode('utf-8').strip()
if line.startswith('event:'):
continue
if line.startswith('data:'):
data = line[5:].strip() # Strip the whitespace after 'data:'
if not data or data == '[DONE]':
continue
try:
# Handle the endpoint event (relative or absolute path)
if not self._endpoint_url:
# For a relative path, build the full URL
if data.startswith('/'):
from urllib.parse import urlparse, urlunparse
parsed = urlparse(self.server_url)
self._endpoint_url = f"{parsed.scheme}://{parsed.netloc}{data}"
else:
self._endpoint_url = data
logger.info(f"Got endpoint URL: {self._endpoint_url}")
continue
# Handle message events
message = json.loads(data)
request_id = message.get("id")
if request_id and request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(message)
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"SSE stream read error: {e}")
async def _send_sse_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Send a request through the SSE endpoint"""
if not self._endpoint_url:
raise MCPConnectionError("Endpoint URL is not initialized")
request_id = request["id"]
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
async with self._session.post(self._endpoint_url, json=request) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"Request failed {response.status}: {error_text}")
return await asyncio.wait_for(future, timeout=self.timeout)
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise MCPConnectionError("Request timed out")
async def _send_sse_notification(self, notification: Dict[str, Any]):
"""Send a notification (no response expected)"""
if not self._endpoint_url:
raise MCPConnectionError("Endpoint URL is not initialized")
async with self._session.post(self._endpoint_url, json=notification) as response:
if response.status != 200:
logger.warning(f"Failed to send the notification: {response.status}")
async def _initialize_modelscope_session(self):
"""Initialize the ModelScope MCP session"""
init_request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
}
}
try:
async with self._session.post(self.server_url, json=init_request) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"Initialization failed {response.status}: {error_text}")
init_response = await response.json()
if "error" in init_response:
raise MCPConnectionError(f"Initialization failed: {init_response['error']}")
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
if session_id:
self._session.headers.update({"Mcp-Session-Id": session_id})
initialized_notification = {
"jsonrpc": "2.0",
"method": "notifications/initialized"
}
async with self._session.post(self.server_url, json=initialized_notification):
pass
except aiohttp.ClientError as e:
raise MCPConnectionError(f"Failed to initialize the connection: {e}")
def _build_headers(self) -> Dict[str, str]:
"""Build the request headers"""
# Base headers
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream"
}
# Merge the custom headers from connection_config
custom_headers = self.connection_config.get("headers", {})
if custom_headers:
headers.update(custom_headers)
# Apply the auth configuration (auth headers take precedence)
auth_config = self.connection_config.get("auth_config", {})
def _build_auth_headers(self) -> Dict[str, str]:
"""Build the authentication headers"""
headers = {}
auth_type = self.connection_config.get("auth_type", "none")
auth_config = self.connection_config.get("auth_config", {})
if auth_type == "bearer_token":
if auth_type == "api_key":
api_key = auth_config.get("api_key")
key_name = auth_config.get("key_name", "X-API-Key")
if api_key:
headers[key_name] = api_key
elif auth_type == "bearer_token":
token = auth_config.get("token")
if token:
headers["Authorization"] = f"Bearer {token}"
elif auth_type == "api_key":
key = auth_config.get("api_key")
header_name = auth_config.get("key_name", "X-API-Key")
if key:
headers[header_name] = key
elif auth_type == "basic_auth":
username = auth_config.get("username")
password = auth_config.get("password")
@@ -281,99 +161,504 @@ class SimpleMCPClient:
return headers
async def _send_initialize(self):
"""Send the initialization message (WebSocket)"""
init_message = {
async def _connect_websocket(self) -> bool:
"""Establish the WebSocket connection"""
try:
# WebSocket connection configuration
extra_headers = dict(self.connection_config.get("headers", {}))  # copy so the stored config is not mutated
auth_headers = self._build_auth_headers()
extra_headers.update(auth_headers)
self._websocket = await websockets.connect(
self.server_url,
extra_headers=extra_headers,
timeout=self.connection_timeout
)
# Start listening for messages
asyncio.create_task(self._websocket_message_handler())
# Send the initialization message
init_message = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {}
},
"clientInfo": {
"name": "ToolManagementSystem",
"version": "1.0.0"
}
}
}
await self._websocket.send(json.dumps(init_message))
# Wait for the initialization response
response = await asyncio.wait_for(
self._websocket.recv(),
timeout=self.connection_timeout
)
init_response = json.loads(response)
if "error" in init_response:
raise MCPProtocolError(f"Initialization failed: {init_response['error']}")
return True
except Exception as e:
logger.error(f"WebSocket connection failed: {e}")
return False
async def _connect_http(self) -> bool:
"""Establish the HTTP connection"""
try:
# HTTP session configuration
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
headers = dict(self.connection_config.get("headers", {}))  # copy so the stored config is not mutated
auth_headers = self._build_auth_headers()
headers.update(auth_headers)
self._session = aiohttp.ClientSession(
timeout=timeout,
headers=headers
)
# Test the connection
test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
async with self._session.get(test_url) as response:
if response.status == 200:
return True
else:
# Try the root path
async with self._session.get(self.server_url) as root_response:
return root_response.status < 400
except Exception as e:
logger.error(f"HTTP connection failed: {e}")
if self._session:
await self._session.close()
self._session = None
return False
async def _websocket_message_handler(self):
"""WebSocket message handler"""
try:
while self._websocket and not self._websocket.closed:
try:
message = await self._websocket.recv()
await self._handle_message(json.loads(message))
except ConnectionClosed:
break
except json.JSONDecodeError as e:
logger.error(f"Failed to parse the WebSocket message: {e}")
except Exception as e:
logger.error(f"Failed to handle the WebSocket message: {e}")
except Exception as e:
logger.error(f"WebSocket message handler error: {e}")
finally:
self._connected = False
await self._notify_disconnect_callbacks()
async def _handle_message(self, message: Dict[str, Any]):
"""Handle a received message"""
try:
# Check whether this is a response message
if "id" in message:
request_id = str(message["id"])
if request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(message)
# Handle notification messages
elif "method" in message:
await self._handle_notification(message)
except Exception as e:
logger.error(f"Failed to handle the message: {e}")
@staticmethod
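async def _handle_notification(message: Dict[str, Any]):
"""Handle a notification message"""
method = message.get("method")
params = message.get("params", {})
logger.debug(f"Received MCP notification: {method}, params: {params}")
# Specific notifications can be handled here as needed,
# for example tool list updates or server status changes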
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
"""Call an MCP tool
Args:
tool_name: Tool name
arguments: Tool arguments
timeout: Timeout in seconds
Returns:
Tool execution result
Raises:
MCPConnectionError: Connection error
MCPProtocolError: Protocol error
"""
if not self._connected:
raise MCPConnectionError("MCP client is not connected")
request_data = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"id": self._get_next_request_id(),
"method": "tools/call",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
"name": tool_name,
"arguments": arguments
}
}
await self._websocket.send(json.dumps(init_message))
response = await self._websocket.recv()
response_data = json.loads(response)
if "error" in response_data:
raise MCPConnectionError(f"初始化失败: {response_data['error']}")
result = response_data.get("result", {})
self._server_capabilities = result.get("capabilities", {})
await self._websocket.send(json.dumps({
"jsonrpc": "2.0",
"method": "notifications/initialized"
}))
try:
response = await self._send_request(request_data, timeout)
if "error" in response:
error = response["error"]
raise MCPProtocolError(f"Tool call failed: {error.get('message', 'unknown error')}")
return response.get("result", {})
except asyncio.TimeoutError:
raise MCPProtocolError(f"Tool call timed out: {tool_name}")
async def list_tools(self) -> List[Dict[str, Any]]:
"""Get the tool list"""
request = {
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
"""Get the list of available tools
Args:
timeout: Timeout in seconds
Returns:
Tool list
Raises:
MCPConnectionError: Connection error
MCPProtocolError: Protocol error
"""
if not self._connected:
raise MCPConnectionError("MCP client is not connected")
request_data = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"id": self._get_next_request_id(),
"method": "tools/list"
}
if self.is_websocket:
await self._websocket.send(json.dumps(request))
response = await self._websocket.recv()
response_data = json.loads(response)
elif self.is_sse:
response_data = await self._send_sse_request(request)
else:
async with self._session.post(self.server_url, json=request) as response:
response_data = await response.json()
if "error" in response_data:
raise MCPConnectionError(f"Failed to get the tool list: {response_data['error']}")
result = response_data.get("result", {})
return result.get("tools", [])
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""Call a tool"""
request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/call",
"params": {"name": tool_name, "arguments": arguments}
}
if self.is_websocket:
await self._websocket.send(json.dumps(request))
response = await self._websocket.recv()
response_data = json.loads(response)
elif self.is_sse:
response_data = await self._send_sse_request(request)
else:
async with self._session.post(self.server_url, json=request) as response:
response_data = await response.json()
if "error" in response_data:
error = response_data["error"]
raise MCPConnectionError(f"Tool call failed: {error.get('message', 'unknown error')}")
return response_data.get("result", {})
def _get_request_id(self) -> int:
"""Generate a request ID"""
self._request_id += 1
return self._request_id
async def _handle_websocket_messages(self):
"""Handle WebSocket messages"""
try:
async for message in self._websocket:
data = json.loads(message)
request_id = data.get("id")
if request_id and request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(data)
except ConnectionClosed:
logger.info("WebSocket connection closed")
response = await self._send_request(request_data, timeout)
if response.get("error", None) is not None:
error = response["error"]
raise MCPProtocolError(f"Failed to get the tool list: {error.get('message', 'unknown error')}")
result = response.get("result", {})
return result.get("tools", [])
except asyncio.TimeoutError:
raise MCPProtocolError("Timed out while getting the tool list")
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
"""Send a request and wait for the response
Args:
request_data: Request data
timeout: Timeout in seconds
Returns:
Response data
"""
if self.connection_type == "websocket":
request_id = str(request_data["id"])
return await self._send_websocket_request(request_data, request_id, timeout)
else:
return await self._send_http_request(request_data, timeout)
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
"""Send a WebSocket request"""
if not self._websocket or self._websocket.closed:
raise MCPConnectionError("WebSocket connection is closed")
# Create a Future to wait for the response
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
# Send the request
await self._websocket.send(json.dumps(request_data))
# Wait for the response
response = await asyncio.wait_for(future, timeout=timeout)
return response
except asyncio.TimeoutError:
# dict.pop is synchronous, so it must not be awaited
self._pending_requests.pop(request_id, None)
raise
except Exception as e:
logger.error(f"WebSocket message handling error: {e}")
self._pending_requests.pop(request_id, None)
raise MCPConnectionError(f"Failed to send the WebSocket request: {e}")
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
"""Send an HTTP request"""
if not self._session:
raise MCPConnectionError("HTTP session is not established")
try:
url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
async with self._session.post(
url,
json=request_data,
timeout=aiohttp.ClientTimeout(total=timeout)
) as response:
if response.status == 200:
return await response.json()
else:
# Fall back to posting to the server URL itself
async with self._session.post(
self.server_url,
json=request_data,
timeout=aiohttp.ClientTimeout(total=timeout)
) as root_response:
if root_response.status != 200:
error_text = await root_response.text()
raise MCPConnectionError(f"HTTP request failed {root_response.status}: {error_text}")
return await root_response.json()
except aiohttp.ClientError as e:
raise MCPConnectionError(f"HTTP request failed: {e}")
async def health_check(self) -> Dict[str, Any]:
"""Run a health check
Returns:
Health status information
"""
try:
if not self._connected:
return {
"healthy": False,
"error": "Not connected",
"timestamp": time.time()
}
# Send a ping request
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "ping"
}
start_time = time.time()
response = await self._send_request(request_data, timeout=5)
response_time = round((time.time() - start_time) * 1000)
self._last_health_check = round(time.time() * 1000)
return {
"healthy": True,
"response_time": response_time,
"timestamp": self._last_health_check,
"server_info": response.get("result", {})
}
except Exception as e:
return {
"healthy": False,
"error": str(e),
"timestamp": time.time()
}
async def _start_health_check(self):
"""Start the health check task"""
if self.health_check_interval > 0:
self._health_check_task = asyncio.create_task(self._health_check_loop())
async def _stop_health_check(self):
"""Stop the health check task"""
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
self._health_check_task = None
async def _health_check_loop(self):
"""Health check loop"""
try:
while self._connected:
await asyncio.sleep(self.health_check_interval)
if self._connected:
health_status = await self.health_check()
if not health_status["healthy"]:
logger.warning(f"MCP server health check failed: {health_status.get('error')}")
# Reconnect logic could be implemented here
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"Health check loop error: {e}")
def _get_next_request_id(self) -> str:
"""Get the next request ID"""
self._request_id += 1
return f"req_{self._request_id}_{int(time.time() * 1000)}"
# Event callback management
def on_connect(self, callback: Callable):
"""Register a connect callback"""
self._on_connect_callbacks.append(callback)
def on_disconnect(self, callback: Callable):
"""Register a disconnect callback"""
self._on_disconnect_callbacks.append(callback)
def on_error(self, callback: Callable):
"""Register an error callback"""
self._on_error_callbacks.append(callback)
async def _notify_connect_callbacks(self):
"""Notify the connect callbacks"""
for callback in self._on_connect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"Connect callback failed: {e}")
async def _notify_disconnect_callbacks(self):
"""Notify the disconnect callbacks"""
for callback in self._on_disconnect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"Disconnect callback failed: {e}")
async def _notify_error_callbacks(self, error: Exception):
"""Notify the error callbacks"""
for callback in self._on_error_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(error)
else:
callback(error)
except Exception as e:
logger.error(f"Error callback failed: {e}")
@property
def is_connected(self) -> bool:
"""Whether the client is connected"""
return self._connected
@property
def last_health_check(self) -> Optional[float]:
"""Time of the last health check"""
return self._last_health_check
def get_connection_info(self) -> Dict[str, Any]:
"""Get connection information"""
return {
"server_url": self.server_url,
"connection_type": self.connection_type,
"connected": self._connected,
"last_health_check": self._last_health_check,
"pending_requests": len(self._pending_requests),
"config": self.connection_config
}
async def __aenter__(self):
"""Async context manager entry"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
await self.disconnect()
class MCPConnectionPool:
"""MCP connection pool - manages multiple MCP client connections"""
def __init__(self, max_connections: int = 10):
"""Initialize the connection pool
Args:
max_connections: Maximum number of connections
"""
self.max_connections = max_connections
self._clients: Dict[str, MCPClient] = {}
self._lock = asyncio.Lock()
async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
"""Get or create an MCP client
Args:
server_url: Server URL
connection_config: Connection configuration
Returns:
MCP client instance
"""
async with self._lock:
if server_url in self._clients:
client = self._clients[server_url]
if client.is_connected:
return client
else:
# Try to reconnect
if await client.connect():
return client
else:
# Remove the stale client
del self._clients[server_url]
# Enforce the connection limit
if len(self._clients) >= self.max_connections:
# Evict the oldest connection
oldest_url = next(iter(self._clients))
await self._clients[oldest_url].disconnect()
del self._clients[oldest_url]
# Create a new client
client = MCPClient(server_url, connection_config)
if await client.connect():
self._clients[server_url] = client
return client
else:
raise MCPConnectionError(f"Unable to connect to the MCP server: {server_url}")
async def disconnect_all(self):
"""Disconnect all connections"""
async with self._lock:
for client in self._clients.values():
await client.disconnect()
self._clients.clear()
def get_pool_status(self) -> Dict[str, Any]:
"""Get the connection pool status"""
return {
"total_connections": len(self._clients),
"max_connections": self.max_connections,
"connections": {
url: client.get_connection_info()
for url, client in self._clients.items()
}
}
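`MCPClient` correlates WebSocket responses with requests by storing an `asyncio.Future` per request ID in `_pending_requests` and resolving it when a message with the same `id` arrives. The sketch below isolates that pattern without any network I/O. `PendingRequests`, `register`, `resolve`, and `wait` are hypothetical names, and the in-memory delivery via `call_later` stands in for the real socket handler.

```python
import asyncio
from typing import Any, Dict


class PendingRequests:
    """Match JSON-RPC style responses to the requests awaiting them."""

    def __init__(self) -> None:
        self._futures: Dict[str, asyncio.Future] = {}

    def register(self, request_id: str) -> asyncio.Future:
        # Create the future on the running loop and remember it by request ID.
        future = asyncio.get_running_loop().create_future()
        self._futures[request_id] = future
        return future

    def resolve(self, message: Dict[str, Any]) -> bool:
        # Deliver a response to its waiter; return False for unknown IDs.
        future = self._futures.pop(str(message.get("id")), None)
        if future is None or future.done():
            return False
        future.set_result(message)
        return True

    async def wait(self, request_id: str, timeout: float) -> Dict[str, Any]:
        future = self.register(request_id)
        try:
            return await asyncio.wait_for(future, timeout=timeout)
        finally:
            # dict.pop is synchronous, so it is called without await.
            self._futures.pop(request_id, None)


async def demo() -> Dict[str, Any]:
    pending = PendingRequests()
    loop = asyncio.get_running_loop()
    # Simulate the server answering shortly after the request is sent.
    loop.call_later(0.01, pending.resolve, {"id": "req_1", "result": {"ok": True}})
    return await pending.wait("req_1", timeout=1.0)


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Cleaning up in `finally` removes the entry on success, timeout, and cancellation alike, so the pending map does not grow when a server never answers.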

Some files were not shown because too many files have changed in this diff