feat(memory, model): update multi-modal memory write and model list API
- Adjust multi-modal memory write behavior for text and visual data
- Mask API keys in model list response to prevent exposure
- Add capability-based filtering to the model list API
This commit is contained in:
@@ -42,6 +42,7 @@ def get_model_strategies():
|
|||||||
@router.get("", response_model=ApiResponse)
|
@router.get("", response_model=ApiResponse)
|
||||||
def get_model_list(
|
def get_model_list(
|
||||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||||
|
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"),
|
||||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||||
@@ -74,10 +75,21 @@ def get_model_list(
|
|||||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||||
|
|
||||||
|
capability_list = []
|
||||||
|
if capability is not None:
|
||||||
|
flat_capability = []
|
||||||
|
for item in capability:
|
||||||
|
split_items = [c.strip() for c in item.split(', ') if c.strip()]
|
||||||
|
flat_capability.extend(split_items)
|
||||||
|
|
||||||
|
unique_flat_capability = list(dict.fromkeys(flat_capability))
|
||||||
|
capability_list = unique_flat_capability
|
||||||
|
|
||||||
api_logger.error(f"获取模型type_list: {type_list}")
|
api_logger.error(f"获取模型type_list: {type_list}")
|
||||||
query = model_schema.ModelConfigQuery(
|
query = model_schema.ModelConfigQuery(
|
||||||
type=type_list,
|
type=type_list,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
|
capability=capability_list,
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
is_public=is_public,
|
is_public=is_public,
|
||||||
search=search,
|
search=search,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import datetime
|
import datetime
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from fastapi import APIRouter, Depends,Header
|
from fastapi import APIRouter, Depends, Header
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
@@ -19,7 +19,7 @@ from app.services.user_memory_service import (
|
|||||||
analytics_graph_data,
|
analytics_graph_data,
|
||||||
analytics_community_graph_data,
|
analytics_community_graph_data,
|
||||||
)
|
)
|
||||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||||
from app.repositories.workspace_repository import WorkspaceRepository
|
from app.repositories.workspace_repository import WorkspaceRepository
|
||||||
@@ -45,9 +45,9 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||||
async def get_memory_insight_report_api(
|
async def get_memory_insight_report_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取缓存的记忆洞察报告
|
获取缓存的记忆洞察报告
|
||||||
@@ -73,10 +73,10 @@ async def get_memory_insight_report_api(
|
|||||||
|
|
||||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||||
async def get_user_summary_api(
|
async def get_user_summary_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取缓存的用户摘要
|
获取缓存的用户摘要
|
||||||
@@ -90,7 +90,7 @@ async def get_user_summary_api(
|
|||||||
"""
|
"""
|
||||||
# 使用集中化的语言校验
|
# 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
workspace_repo = WorkspaceRepository(db)
|
workspace_repo = WorkspaceRepository(db)
|
||||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||||
@@ -102,7 +102,7 @@ async def get_user_summary_api(
|
|||||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||||
try:
|
try:
|
||||||
# 调用服务层获取缓存数据
|
# 调用服务层获取缓存数据
|
||||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
|
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
|
||||||
|
|
||||||
if result["is_cached"]:
|
if result["is_cached"]:
|
||||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||||
@@ -117,10 +117,10 @@ async def get_user_summary_api(
|
|||||||
|
|
||||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||||
async def generate_cache_api(
|
async def generate_cache_api(
|
||||||
request: GenerateCacheRequest,
|
request: GenerateCacheRequest,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
手动触发缓存生成
|
手动触发缓存生成
|
||||||
@@ -134,7 +134,7 @@ async def generate_cache_api(
|
|||||||
"""
|
"""
|
||||||
# 使用集中化的语言校验
|
# 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
@@ -155,10 +155,12 @@ async def generate_cache_api(
|
|||||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||||
|
|
||||||
# 生成记忆洞察
|
# 生成记忆洞察
|
||||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
|
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
|
||||||
|
language=language)
|
||||||
|
|
||||||
# 生成用户摘要
|
# 生成用户摘要
|
||||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
|
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
|
||||||
|
language=language)
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
result = {
|
result = {
|
||||||
@@ -209,9 +211,9 @@ async def generate_cache_api(
|
|||||||
|
|
||||||
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
||||||
async def get_node_statistics_api(
|
async def get_node_statistics_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -220,7 +222,8 @@ async def get_node_statistics_api(
|
|||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
api_logger.info(
|
||||||
|
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用新的记忆类型统计函数
|
# 调用新的记忆类型统计函数
|
||||||
@@ -228,21 +231,23 @@ async def get_node_statistics_api(
|
|||||||
|
|
||||||
# 计算总数用于日志
|
# 计算总数用于日志
|
||||||
total_count = sum(item["count"] for item in result)
|
total_count = sum(item["count"] for item in result)
|
||||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
api_logger.info(
|
||||||
|
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="查询成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
||||||
async def get_graph_data_api(
|
async def get_graph_data_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
node_types: Optional[str] = None,
|
node_types: Optional[str] = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
depth: int = 1,
|
depth: int = 1,
|
||||||
center_node_id: Optional[str] = None,
|
center_node_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -298,9 +303,9 @@ async def get_graph_data_api(
|
|||||||
|
|
||||||
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
||||||
async def get_community_graph_data_api(
|
async def get_community_graph_data_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -334,9 +339,9 @@ async def get_community_graph_data_api(
|
|||||||
|
|
||||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||||
async def get_end_user_profile(
|
async def get_end_user_profile(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
workspace_repo = WorkspaceRepository(db)
|
workspace_repo = WorkspaceRepository(db)
|
||||||
@@ -385,9 +390,9 @@ async def get_end_user_profile(
|
|||||||
|
|
||||||
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
||||||
async def update_end_user_profile(
|
async def update_end_user_profile(
|
||||||
profile_update: EndUserProfileUpdate,
|
profile_update: EndUserProfileUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
更新终端用户的基本信息
|
更新终端用户的基本信息
|
||||||
@@ -417,7 +422,7 @@ async def update_end_user_profile(
|
|||||||
else:
|
else:
|
||||||
error_msg = result["error"]
|
error_msg = result["error"]
|
||||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||||
|
|
||||||
# 根据错误类型映射到合适的业务错误码
|
# 根据错误类型映射到合适的业务错误码
|
||||||
if error_msg == "终端用户不存在":
|
if error_msg == "终端用户不存在":
|
||||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
||||||
@@ -427,15 +432,18 @@ async def update_end_user_profile(
|
|||||||
# 只有未预期的错误才使用 INTERNAL_ERROR
|
# 只有未预期的错误才使用 INTERNAL_ERROR
|
||||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
|
async def memory_space_timeline_of_shared_memories(
|
||||||
current_user: User = Depends(get_current_user),
|
id: str, label: str,
|
||||||
db: Session = Depends(get_db),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
):
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
# 使用集中化的语言校验
|
# 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
workspace_id=current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
workspace_repo = WorkspaceRepository(db)
|
workspace_repo = WorkspaceRepository(db)
|
||||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||||
|
|
||||||
@@ -447,11 +455,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
|
|||||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||||
|
|
||||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||||
async def memory_space_relationship_evolution(id: str, label: str,
|
async def memory_space_relationship_evolution(id: str, label: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
|
|||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
end_user_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
messages: list = None,
|
messages: list = None,
|
||||||
ref_id: str = "wyl_20251027",
|
ref_id: str = "",
|
||||||
config_id: str = None
|
config_id: str = None
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||||
@@ -40,12 +40,13 @@ async def get_chunked_dialogs(
|
|||||||
|
|
||||||
role = msg['role']
|
role = msg['role']
|
||||||
content = msg['content']
|
content = msg['content']
|
||||||
|
files = msg.get("file_content", [])
|
||||||
|
|
||||||
if role not in ['user', 'assistant']:
|
if role not in ['user', 'assistant']:
|
||||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||||
|
|
||||||
if content.strip():
|
if content.strip():
|
||||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
|
||||||
|
|
||||||
if not conversation_messages:
|
if not conversation_messages:
|
||||||
raise ValueError("Message list cannot be empty after filtering")
|
raise ValueError("Message list cannot be empty after filtering")
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ This module provides the main write function for executing the knowledge extract
|
|||||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@@ -19,10 +19,8 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.mem
|
|||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.core.memory.utils.log.logging_utils import log_time
|
from app.core.memory.utils.log.logging_utils import log_time
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models import MemoryPerceptualModel
|
|
||||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \
|
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||||
add_perceptual_dialogue_edges
|
|
||||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
@@ -36,7 +34,6 @@ async def write(
|
|||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
memory_config: MemoryConfig,
|
memory_config: MemoryConfig,
|
||||||
messages: list,
|
messages: list,
|
||||||
file_content: list[MemoryPerceptualModel],
|
|
||||||
ref_id: str = "",
|
ref_id: str = "",
|
||||||
language: str = "zh",
|
language: str = "zh",
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -47,7 +44,6 @@ async def write(
|
|||||||
end_user_id: Group identifier
|
end_user_id: Group identifier
|
||||||
memory_config: MemoryConfig object containing all configuration
|
memory_config: MemoryConfig object containing all configuration
|
||||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||||
file_content: mutilmodal message list
|
|
||||||
ref_id: Reference ID, defaults to ""
|
ref_id: Reference ID, defaults to ""
|
||||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||||
"""
|
"""
|
||||||
@@ -142,9 +138,11 @@ async def write(
|
|||||||
all_chunk_nodes,
|
all_chunk_nodes,
|
||||||
all_statement_nodes,
|
all_statement_nodes,
|
||||||
all_entity_nodes,
|
all_entity_nodes,
|
||||||
|
all_perceptual_nodes,
|
||||||
all_statement_chunk_edges,
|
all_statement_chunk_edges,
|
||||||
all_statement_entity_edges,
|
all_statement_entity_edges,
|
||||||
all_entity_entity_edges,
|
all_entity_entity_edges,
|
||||||
|
all_perceptual_edges,
|
||||||
all_dedup_details,
|
all_dedup_details,
|
||||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||||
|
|
||||||
@@ -169,9 +167,11 @@ async def write(
|
|||||||
chunk_nodes=all_chunk_nodes,
|
chunk_nodes=all_chunk_nodes,
|
||||||
statement_nodes=all_statement_nodes,
|
statement_nodes=all_statement_nodes,
|
||||||
entity_nodes=all_entity_nodes,
|
entity_nodes=all_entity_nodes,
|
||||||
|
perceptual_nodes=all_perceptual_nodes,
|
||||||
statement_chunk_edges=all_statement_chunk_edges,
|
statement_chunk_edges=all_statement_chunk_edges,
|
||||||
statement_entity_edges=all_statement_entity_edges,
|
statement_entity_edges=all_statement_entity_edges,
|
||||||
entity_edges=all_entity_entity_edges,
|
entity_edges=all_entity_entity_edges,
|
||||||
|
perceptual_edges=all_perceptual_edges,
|
||||||
connector=neo4j_connector,
|
connector=neo4j_connector,
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
@@ -230,34 +230,6 @@ async def write(
|
|||||||
finally:
|
finally:
|
||||||
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
||||||
|
|
||||||
# Step 5: Save perceptual memory to Neo4j
|
|
||||||
step_start = time.time()
|
|
||||||
if file_content:
|
|
||||||
try:
|
|
||||||
pc_connector = Neo4jConnector()
|
|
||||||
try:
|
|
||||||
created_ids = await add_perceptual_nodes(
|
|
||||||
perceptuals=file_content,
|
|
||||||
connector=pc_connector,
|
|
||||||
embedder_client=embedder_client,
|
|
||||||
)
|
|
||||||
# 如果有 ref_id,建立感知记忆与对话的关联
|
|
||||||
if ref_id and created_ids:
|
|
||||||
await add_perceptual_dialogue_edges(
|
|
||||||
perceptuals=file_content,
|
|
||||||
dialog_id=ref_id,
|
|
||||||
connector=pc_connector,
|
|
||||||
)
|
|
||||||
logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j")
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await pc_connector.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True)
|
|
||||||
log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file)
|
|
||||||
|
|
||||||
# Log total pipeline time
|
# Log total pipeline time
|
||||||
total_time = time.time() - pipeline_start
|
total_time = time.time() - pipeline_start
|
||||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
from typing import Any, List
|
|
||||||
import re
|
|
||||||
import os
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# Fix tokenizer parallelism warning
|
# Fix tokenizer parallelism warning
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -246,6 +246,7 @@ class ChunkerClient:
|
|||||||
"total_sub_chunks": len(sub_chunks),
|
"total_sub_chunks": len(sub_chunks),
|
||||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||||
},
|
},
|
||||||
|
files=msg.files
|
||||||
)
|
)
|
||||||
dialogue.chunks.append(chunk)
|
dialogue.chunks.append(chunk)
|
||||||
else:
|
else:
|
||||||
@@ -258,6 +259,7 @@ class ChunkerClient:
|
|||||||
"message_role": msg.role,
|
"message_role": msg.role,
|
||||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||||
},
|
},
|
||||||
|
files=msg.files
|
||||||
)
|
)
|
||||||
dialogue.chunks.append(chunk)
|
dialogue.chunks.append(chunk)
|
||||||
|
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ class Edge(BaseModel):
|
|||||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
|
||||||
|
|
||||||
|
|
||||||
class ChunkEdge(Edge):
|
class ChunkEdge(Edge):
|
||||||
@@ -175,6 +175,12 @@ class EntityEntityEdge(Edge):
|
|||||||
return parse_historical_datetime(v)
|
return parse_historical_datetime(v)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualEdge(Edge):
    """Edge connecting a perceptual node to its source chunk.

    Declares no fields of its own; every attribute comes from ``Edge``.
    """
|
||||||
|
|
||||||
|
|
||||||
class Node(BaseModel):
|
class Node(BaseModel):
|
||||||
"""Base class for all graph nodes in the knowledge graph.
|
"""Base class for all graph nodes in the knowledge graph.
|
||||||
|
|
||||||
@@ -555,19 +561,16 @@ class MemorySummaryNode(Node):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MutlimodalNode(Node):
|
class PerceptualNode(Node):
|
||||||
"""Node representing a multimodal message in the knowledge graph.
|
"""Node representing a multimodal message in the knowledge graph.
|
||||||
|
|
||||||
Attributes:
|
|
||||||
dialog_id: ID of the parent dialog
|
|
||||||
message_id: ID of the message
|
|
||||||
metadata: Additional message metadata
|
|
||||||
embedding: Optional embedding vector for the message
|
|
||||||
"""
|
"""
|
||||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
perceptual_type: int
|
||||||
message_id: str = Field(..., description="ID of the message")
|
file_path: str
|
||||||
summary: str = Field(..., description="The text content of the message")
|
file_name: str
|
||||||
file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')")
|
file_ext: str
|
||||||
file_path: List[str] = Field(..., description="List of file paths for multimodal content")
|
summary: str
|
||||||
metadata: dict = Field(default_factory=dict, description="Additional message metadata")
|
keywords: list[str]
|
||||||
embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message")
|
topic: str
|
||||||
|
domain: str
|
||||||
|
file_type: str
|
||||||
|
summary_embedding: list[float] | None
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
|
|||||||
"""
|
"""
|
||||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||||
msg: str = Field(..., description="The text content of the message.")
|
msg: str = Field(..., description="The text content of the message.")
|
||||||
|
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
|
||||||
|
|
||||||
|
|
||||||
class TemporalValidityRange(BaseModel):
|
class TemporalValidityRange(BaseModel):
|
||||||
@@ -130,7 +131,8 @@ class Chunk(BaseModel):
|
|||||||
content: str = Field(..., description="The content of the chunk as a string.")
|
content: str = Field(..., description="The content of the chunk as a string.")
|
||||||
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
||||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
|
||||||
|
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -31,7 +31,9 @@ from app.core.memory.models.graph_models import (
|
|||||||
ExtractedEntityNode,
|
ExtractedEntityNode,
|
||||||
StatementChunkEdge,
|
StatementChunkEdge,
|
||||||
StatementEntityEdge,
|
StatementEntityEdge,
|
||||||
StatementNode
|
StatementNode,
|
||||||
|
PerceptualEdge,
|
||||||
|
PerceptualNode
|
||||||
)
|
)
|
||||||
from app.core.memory.models.message_models import DialogData
|
from app.core.memory.models.message_models import DialogData
|
||||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||||
@@ -170,9 +172,11 @@ class ExtractionOrchestrator:
|
|||||||
list[ChunkNode],
|
list[ChunkNode],
|
||||||
list[StatementNode],
|
list[StatementNode],
|
||||||
list[ExtractedEntityNode],
|
list[ExtractedEntityNode],
|
||||||
|
list[PerceptualNode],
|
||||||
list[StatementChunkEdge],
|
list[StatementChunkEdge],
|
||||||
list[StatementEntityEdge],
|
list[StatementEntityEdge],
|
||||||
list[EntityEntityEdge],
|
list[EntityEntityEdge],
|
||||||
|
list[PerceptualEdge],
|
||||||
dict
|
dict
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
@@ -259,9 +263,11 @@ class ExtractionOrchestrator:
|
|||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
entity_nodes,
|
entity_nodes,
|
||||||
|
perceptual_nodes,
|
||||||
statement_chunk_edges,
|
statement_chunk_edges,
|
||||||
statement_entity_edges,
|
statement_entity_edges,
|
||||||
entity_entity_edges,
|
entity_entity_edges,
|
||||||
|
perceptual_edges
|
||||||
) = await self._create_nodes_and_edges(dialog_data_list)
|
) = await self._create_nodes_and_edges(dialog_data_list)
|
||||||
|
|
||||||
# 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总)
|
# 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总)
|
||||||
@@ -275,7 +281,16 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
# 注意:deduplication 消息已在创建节点和边完成后立即发送
|
# 注意:deduplication 消息已在创建节点和边完成后立即发送
|
||||||
|
|
||||||
result = await self._run_dedup_and_write_summary(
|
(
|
||||||
|
dialogue_nodes,
|
||||||
|
chunk_nodes,
|
||||||
|
statement_nodes,
|
||||||
|
entity_nodes,
|
||||||
|
statement_chunk_edges,
|
||||||
|
statement_entity_edges,
|
||||||
|
entity_entity_edges,
|
||||||
|
dialog_data_list,
|
||||||
|
) = await self._run_dedup_and_write_summary(
|
||||||
dialogue_nodes,
|
dialogue_nodes,
|
||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
@@ -287,7 +302,18 @@ class ExtractionOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||||
return result
|
return (
|
||||||
|
dialogue_nodes,
|
||||||
|
chunk_nodes,
|
||||||
|
statement_nodes,
|
||||||
|
entity_nodes,
|
||||||
|
perceptual_nodes,
|
||||||
|
statement_chunk_edges,
|
||||||
|
statement_entity_edges,
|
||||||
|
entity_entity_edges,
|
||||||
|
perceptual_edges,
|
||||||
|
dialog_data_list,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"知识提取流水线运行失败: {e}", exc_info=True)
|
logger.error(f"知识提取流水线运行失败: {e}", exc_info=True)
|
||||||
@@ -1000,9 +1026,11 @@ class ExtractionOrchestrator:
|
|||||||
List[ChunkNode],
|
List[ChunkNode],
|
||||||
List[StatementNode],
|
List[StatementNode],
|
||||||
List[ExtractedEntityNode],
|
List[ExtractedEntityNode],
|
||||||
|
List[PerceptualNode],
|
||||||
List[StatementChunkEdge],
|
List[StatementChunkEdge],
|
||||||
List[StatementEntityEdge],
|
List[StatementEntityEdge],
|
||||||
List[EntityEntityEdge]
|
List[EntityEntityEdge],
|
||||||
|
List[PerceptualEdge]
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
创建图数据库节点和边
|
创建图数据库节点和边
|
||||||
@@ -1026,6 +1054,8 @@ class ExtractionOrchestrator:
|
|||||||
statement_chunk_edges = []
|
statement_chunk_edges = []
|
||||||
statement_entity_edges = []
|
statement_entity_edges = []
|
||||||
entity_entity_edges = []
|
entity_entity_edges = []
|
||||||
|
perceptual_nodes = []
|
||||||
|
perceptual_edges = []
|
||||||
|
|
||||||
# 用于去重的集合
|
# 用于去重的集合
|
||||||
entity_id_set = set()
|
entity_id_set = set()
|
||||||
@@ -1069,6 +1099,46 @@ class ExtractionOrchestrator:
|
|||||||
metadata=chunk.metadata,
|
metadata=chunk.metadata,
|
||||||
)
|
)
|
||||||
chunk_nodes.append(chunk_node)
|
chunk_nodes.append(chunk_node)
|
||||||
|
logger.error(f"chunk file: {chunk.files}")
|
||||||
|
|
||||||
|
for p, file_type in chunk.files:
|
||||||
|
|
||||||
|
meta = p.meta_data or {}
|
||||||
|
content_meta = meta.get("content", {})
|
||||||
|
|
||||||
|
# 生成 summary embedding(如果有 embedder_client)
|
||||||
|
summary_embedding = None
|
||||||
|
if self.embedder_client and p.summary:
|
||||||
|
try:
|
||||||
|
summary_embedding = (await self.embedder_client.response([p.summary]))[0]
|
||||||
|
except Exception as emb_err:
|
||||||
|
print(f"Failed to embed perceptual summary: {emb_err}")
|
||||||
|
|
||||||
|
perceptual = PerceptualNode(
|
||||||
|
name=f"Perceptual_{p.id}",
|
||||||
|
**{
|
||||||
|
"id": str(p.id),
|
||||||
|
"end_user_id": str(p.end_user_id),
|
||||||
|
"perceptual_type": p.perceptual_type,
|
||||||
|
"file_path": p.file_path or "",
|
||||||
|
"file_name": p.file_name or "",
|
||||||
|
"file_ext": p.file_ext or "",
|
||||||
|
"summary": p.summary or "",
|
||||||
|
"keywords": content_meta.get("keywords", []),
|
||||||
|
"topic": content_meta.get("topic", ""),
|
||||||
|
"domain": content_meta.get("domain", ""),
|
||||||
|
"created_at": p.created_time.isoformat() if p.created_time else None,
|
||||||
|
"file_type": file_type,
|
||||||
|
"summary_embedding": summary_embedding,
|
||||||
|
})
|
||||||
|
perceptual_nodes.append(perceptual)
|
||||||
|
perceptual_edges.append(PerceptualEdge(
|
||||||
|
source=perceptual.id,
|
||||||
|
target=chunk.id,
|
||||||
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
run_id=dialog_data.run_id,
|
||||||
|
created_at=dialog_data.created_at,
|
||||||
|
))
|
||||||
|
|
||||||
# 处理每个陈述句
|
# 处理每个陈述句
|
||||||
for statement in chunk.statements:
|
for statement in chunk.statements:
|
||||||
@@ -1248,9 +1318,11 @@ class ExtractionOrchestrator:
|
|||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
entity_nodes,
|
entity_nodes,
|
||||||
|
perceptual_nodes,
|
||||||
statement_chunk_edges,
|
statement_chunk_edges,
|
||||||
statement_entity_edges,
|
statement_entity_edges,
|
||||||
entity_entity_edges,
|
entity_entity_edges,
|
||||||
|
perceptual_edges
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _run_dedup_and_write_summary(
|
async def _run_dedup_and_write_summary(
|
||||||
|
|||||||
@@ -72,7 +72,6 @@ class MemoryWriteNode(BaseNode):
|
|||||||
if not end_user_id:
|
if not end_user_id:
|
||||||
raise RuntimeError("End user id is required")
|
raise RuntimeError("End user id is required")
|
||||||
messages = []
|
messages = []
|
||||||
multimodal_memories = []
|
|
||||||
if self.typed_config.message:
|
if self.typed_config.message:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
@@ -104,19 +103,15 @@ class MemoryWriteNode(BaseNode):
|
|||||||
url=file_instence.value.url,
|
url=file_instence.value.url,
|
||||||
file_type=file_instence.value.origin_file_type
|
file_type=file_instence.value.origin_file_type
|
||||||
).model_dump())
|
).model_dump())
|
||||||
multimodal_memories.append({
|
|
||||||
"role": message.role,
|
|
||||||
"files": file_info
|
|
||||||
})
|
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": message.role,
|
"role": message.role,
|
||||||
"content": self._render_template(content, variable_pool)
|
"content": self._render_template(content, variable_pool),
|
||||||
|
"files": file_info
|
||||||
})
|
})
|
||||||
|
|
||||||
write_message_task.delay(
|
write_message_task.delay(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
message=messages,
|
message=messages,
|
||||||
file_messages=multimodal_memories,
|
|
||||||
config_id=str(self.typed_config.config_id),
|
config_id=str(self.typed_config.config_id),
|
||||||
storage_type=state["memory_storage_type"],
|
storage_type=state["memory_storage_type"],
|
||||||
user_rag_memory_id=state["user_rag_memory_id"]
|
user_rag_memory_id=state["user_rag_memory_id"]
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ import datetime
|
|||||||
import uuid
|
import uuid
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text
|
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
from app.db import Base
|
from app.db import Base
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -408,6 +408,9 @@ class MemoryConfigRepository:
|
|||||||
"llm_id": db_config.llm_id,
|
"llm_id": db_config.llm_id,
|
||||||
"embedding_id": db_config.embedding_id,
|
"embedding_id": db_config.embedding_id,
|
||||||
"rerank_id": db_config.rerank_id,
|
"rerank_id": db_config.rerank_id,
|
||||||
|
"vision_id": db_config.vision_id,
|
||||||
|
"audio_id": db_config.audio_id,
|
||||||
|
"video_id": db_config.video_id,
|
||||||
"enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise,
|
"enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise,
|
||||||
"enable_llm_disambiguation": db_config.enable_llm_disambiguation,
|
"enable_llm_disambiguation": db_config.enable_llm_disambiguation,
|
||||||
"deep_retrieval": db_config.deep_retrieval,
|
"deep_retrieval": db_config.deep_retrieval,
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
|
||||||
from sqlalchemy import and_, or_, func, desc, select
|
|
||||||
from typing import List, Optional, Dict, Any, Tuple
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import List, Optional, Dict, Any, Tuple
|
||||||
|
|
||||||
|
from sqlalchemy import and_, or_, func, desc
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from app.core.logging_config import get_db_logger
|
||||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
|
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
|
||||||
from app.schemas.model_schema import (
|
from app.schemas.model_schema import (
|
||||||
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||||
ModelConfigQuery, ModelConfigQueryNew
|
ModelConfigQuery, ModelConfigQueryNew
|
||||||
)
|
)
|
||||||
from app.core.logging_config import get_db_logger
|
|
||||||
|
|
||||||
# 获取数据库专用日志器
|
# 获取数据库专用日志器
|
||||||
db_logger = get_db_logger()
|
db_logger = get_db_logger()
|
||||||
@@ -137,6 +138,9 @@ class ModelConfigRepository:
|
|||||||
type_values.append(ModelType.LLM)
|
type_values.append(ModelType.LLM)
|
||||||
filters.append(ModelConfig.type.in_(type_values))
|
filters.append(ModelConfig.type.in_(type_values))
|
||||||
|
|
||||||
|
if query.capability:
|
||||||
|
filters.append(ModelConfig.capability.contains(query.capability))
|
||||||
|
|
||||||
if query.is_active is not None:
|
if query.is_active is not None:
|
||||||
filters.append(ModelConfig.is_active == query.is_active)
|
filters.append(ModelConfig.is_active == query.is_active)
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,19 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
|
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
|
||||||
MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE
|
MEMORY_SUMMARY_NODE_SAVE
|
||||||
# 使用新的仓储层
|
# 使用新的仓储层
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
    """Delete every node (and its relationships) belonging to one end user.

    Args:
        end_user_id: Value matched against each node's ``end_user_id`` property.
        connector: Neo4j connector used to execute the statement.

    Returns:
        Whatever ``connector.execute_query`` returns for the delete statement.
    """
    # Bind the id as a query parameter rather than interpolating it into the
    # Cypher text, so a crafted id cannot change the statement being run.
    result = await connector.execute_query(
        "MATCH (n {end_user_id: $end_user_id}) DETACH DELETE n",
        end_user_id=end_user_id,
    )
    logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully")
    return result
|
||||||
|
|
||||||
|
|
||||||
@@ -25,7 +28,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
|||||||
List of created node UUIDs or None if failed
|
List of created node UUIDs or None if failed
|
||||||
"""
|
"""
|
||||||
if not dialogues:
|
if not dialogues:
|
||||||
print("No dialogues to save")
|
logger.info("No dialogues to save")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -50,11 +53,11 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
|||||||
)
|
)
|
||||||
|
|
||||||
created_uuids = [record["uuid"] for record in result]
|
created_uuids = [record["uuid"] for record in result]
|
||||||
print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
|
logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
|
||||||
return created_uuids
|
return created_uuids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating dialogue nodes: {e}")
|
logger.info(f"Error creating dialogue nodes: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -69,7 +72,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
|||||||
List of created node UUIDs or None if failed
|
List of created node UUIDs or None if failed
|
||||||
"""
|
"""
|
||||||
if not statements:
|
if not statements:
|
||||||
print("No statements to save")
|
logger.info("No statements to save")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -122,11 +125,11 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
|||||||
)
|
)
|
||||||
|
|
||||||
created_uuids = [record["uuid"] for record in result]
|
created_uuids = [record["uuid"] for record in result]
|
||||||
print(f"Successfully created {len(created_uuids)} statement nodes")
|
logger.info(f"Successfully created {len(created_uuids)} statement nodes")
|
||||||
return created_uuids
|
return created_uuids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating statement nodes: {e}")
|
logger.info(f"Error creating statement nodes: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -141,7 +144,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
|||||||
List of created chunk UUIDs or None if failed
|
List of created chunk UUIDs or None if failed
|
||||||
"""
|
"""
|
||||||
if not chunks:
|
if not chunks:
|
||||||
print("No chunk nodes to add")
|
logger.info("No chunk nodes to add")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -174,16 +177,18 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
|||||||
)
|
)
|
||||||
|
|
||||||
created_uuids = [record["uuid"] for record in result]
|
created_uuids = [record["uuid"] for record in result]
|
||||||
print(f"Successfully created {len(created_uuids)} chunk nodes")
|
logger.info(f"Successfully created {len(created_uuids)} chunk nodes")
|
||||||
return created_uuids
|
return created_uuids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating chunk nodes: {e}")
|
logger.info(f"Error creating chunk nodes: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[
|
async def add_memory_summary_nodes(
|
||||||
List[str]]:
|
summaries: List[MemorySummaryNode],
|
||||||
|
connector: Neo4jConnector
|
||||||
|
) -> Optional[List[str]]:
|
||||||
"""Add memory summary nodes to Neo4j in batch.
|
"""Add memory summary nodes to Neo4j in batch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -194,7 +199,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
|||||||
List of created summary node ids or None if failed
|
List of created summary node ids or None if failed
|
||||||
"""
|
"""
|
||||||
if not summaries:
|
if not summaries:
|
||||||
print("No memory summary nodes to add")
|
logger.info("No memory summary nodes to add")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -220,110 +225,8 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
|||||||
summaries=flattened
|
summaries=flattened
|
||||||
)
|
)
|
||||||
created_ids = [record.get("uuid") for record in result]
|
created_ids = [record.get("uuid") for record in result]
|
||||||
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
||||||
return created_ids
|
return created_ids
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
logger.info(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def add_perceptual_nodes(
|
|
||||||
perceptuals: list,
|
|
||||||
connector: Neo4jConnector,
|
|
||||||
embedder_client=None,
|
|
||||||
) -> Optional[List[str]]:
|
|
||||||
"""Add perceptual memory nodes to Neo4j in batch.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
perceptuals: List of MemoryPerceptualModel objects from PostgreSQL
|
|
||||||
connector: Neo4j connector instance
|
|
||||||
embedder_client: Optional embedder client for generating summary embeddings
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of created node UUIDs or None if failed
|
|
||||||
"""
|
|
||||||
if not perceptuals:
|
|
||||||
print("No perceptual nodes to add")
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
flattened = []
|
|
||||||
for p in perceptuals:
|
|
||||||
meta = p.meta_data or {}
|
|
||||||
content_meta = meta.get("content", {})
|
|
||||||
|
|
||||||
# 生成 summary embedding(如果有 embedder_client)
|
|
||||||
summary_embedding = None
|
|
||||||
if embedder_client and p.summary:
|
|
||||||
try:
|
|
||||||
summary_embedding = (await embedder_client.response([p.summary]))[0]
|
|
||||||
except Exception as emb_err:
|
|
||||||
print(f"Failed to embed perceptual summary: {emb_err}")
|
|
||||||
|
|
||||||
flattened.append({
|
|
||||||
"id": str(p.id),
|
|
||||||
"end_user_id": str(p.end_user_id),
|
|
||||||
"perceptual_type": p.perceptual_type,
|
|
||||||
"file_path": p.file_path or "",
|
|
||||||
"file_name": p.file_name or "",
|
|
||||||
"file_ext": p.file_ext or "",
|
|
||||||
"summary": p.summary or "",
|
|
||||||
"keywords": content_meta.get("keywords", []),
|
|
||||||
"topic": content_meta.get("topic", ""),
|
|
||||||
"domain": content_meta.get("domain", ""),
|
|
||||||
"created_at": p.created_time.isoformat() if p.created_time else None,
|
|
||||||
"summary_embedding": summary_embedding,
|
|
||||||
})
|
|
||||||
|
|
||||||
result = await connector.execute_query(
|
|
||||||
PERCEPTUAL_NODE_SAVE,
|
|
||||||
perceptuals=flattened,
|
|
||||||
)
|
|
||||||
created_uuids = [record.get("uuid") for record in result]
|
|
||||||
print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j")
|
|
||||||
return created_uuids
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to save Perceptual nodes to Neo4j: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def add_perceptual_dialogue_edges(
|
|
||||||
perceptuals: list,
|
|
||||||
dialog_id: str,
|
|
||||||
connector: Neo4jConnector,
|
|
||||||
) -> Optional[List[str]]:
|
|
||||||
"""Add edges between Perceptual nodes and Dialogue nodes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
perceptuals: List of MemoryPerceptualModel objects
|
|
||||||
dialog_id: The dialogue ID (or ref_id) to link to
|
|
||||||
connector: Neo4j connector instance
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of created edge element IDs or None if failed
|
|
||||||
"""
|
|
||||||
if not perceptuals or not dialog_id:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
edges = []
|
|
||||||
for p in perceptuals:
|
|
||||||
edges.append({
|
|
||||||
"perceptual_id": str(p.id),
|
|
||||||
"dialog_id": dialog_id,
|
|
||||||
"end_user_id": str(p.end_user_id),
|
|
||||||
"created_at": p.created_time.isoformat() if p.created_time else None,
|
|
||||||
})
|
|
||||||
|
|
||||||
result = await connector.execute_query(
|
|
||||||
PERCEPTUAL_DIALOGUE_EDGE_SAVE,
|
|
||||||
edges=edges,
|
|
||||||
)
|
|
||||||
created_ids = [record.get("uuid") for record in result]
|
|
||||||
print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j")
|
|
||||||
return created_ids
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to save Perceptual-Dialogue edges: {e}")
|
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1003,60 +1003,69 @@ RETURN DISTINCT
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Fetch the nodes of one end user for graph display, one UNION ALL branch per
# node label. Every branch returns the same columns (id, labels, properties)
# plus a constant `priority` that identifies which branch produced the row.
# Each branch carries its own LIMIT: Dialogue is capped at a single row and all
# other labels, including Perceptual, are capped at $limit so no branch can
# return an unbounded number of rows.
Graph_Node_query = """
MATCH (n:MemorySummary)
WHERE n.end_user_id = $end_user_id
RETURN
    elementId(n) AS id,
    labels(n) AS labels,
    properties(n) AS properties,
    0 AS priority
LIMIT $limit

UNION ALL

MATCH (n:Dialogue)
WHERE n.end_user_id = $end_user_id
RETURN
    elementId(n) AS id,
    labels(n) AS labels,
    properties(n) AS properties,
    1 AS priority
LIMIT 1

UNION ALL

MATCH (n:Statement)
WHERE n.end_user_id = $end_user_id
RETURN
    elementId(n) AS id,
    labels(n) AS labels,
    properties(n) AS properties,
    1 AS priority
LIMIT $limit

UNION ALL

MATCH (n:ExtractedEntity)
WHERE n.end_user_id = $end_user_id
RETURN
    elementId(n) AS id,
    labels(n) AS labels,
    properties(n) AS properties,
    2 AS priority
LIMIT $limit

UNION ALL

MATCH (n:Chunk)
WHERE n.end_user_id = $end_user_id
RETURN
    elementId(n) AS id,
    labels(n) AS labels,
    properties(n) AS properties,
    3 AS priority
LIMIT $limit

UNION ALL

MATCH (n:Perceptual)
WHERE n.end_user_id = $end_user_id
RETURN
    elementId(n) AS id,
    labels(n) AS labels,
    properties(n) AS properties,
    4 AS priority
LIMIT $limit
"""
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# Community 节点 & BELONGS_TO_COMMUNITY 边
|
# Community 节点 & BELONGS_TO_COMMUNITY 边
|
||||||
@@ -1340,19 +1349,19 @@ SET n += {
|
|||||||
topic: p.topic,
|
topic: p.topic,
|
||||||
domain: p.domain,
|
domain: p.domain,
|
||||||
created_at: p.created_at,
|
created_at: p.created_at,
|
||||||
|
file_type: p.file_type,
|
||||||
summary_embedding: p.summary_embedding
|
summary_embedding: p.summary_embedding
|
||||||
}
|
}
|
||||||
RETURN n.id AS uuid
|
RETURN n.id AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Edge linking a Chunk to a Perceptual (perceptual-memory) node.
# Both endpoints are matched by id within the same end_user_id, and the
# relationship properties are written only when MERGE creates the edge.
PERCEPTUAL_CHUNK_EDGE_SAVE = """
UNWIND $edges AS edge
MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
MATCH (c:Chunk {id: edge.chunk_id, end_user_id: edge.end_user_id})
MERGE (c)-[r:HAS_PERCEPTUAL]->(p)
ON CREATE SET r.end_user_id = edge.end_user_id,
r.created_at = edge.created_at
RETURN elementId(r) AS uuid
"""
|
||||||
|
|||||||
@@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import (
|
|||||||
StatementNode,
|
StatementNode,
|
||||||
ExtractedEntityNode,
|
ExtractedEntityNode,
|
||||||
EntityEntityEdge,
|
EntityEntityEdge,
|
||||||
|
PerceptualNode,
|
||||||
|
PerceptualEdge,
|
||||||
)
|
)
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def save_entities_and_relationships(
|
async def save_entities_and_relationships(
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector
|
||||||
):
|
):
|
||||||
"""Save entities and their relationships using graph models"""
|
"""Save entities and their relationships using graph models"""
|
||||||
all_entities = [entity.model_dump() for entity in entity_nodes]
|
all_entities = [entity.model_dump() for entity in entity_nodes]
|
||||||
@@ -73,8 +78,8 @@ async def save_entities_and_relationships(
|
|||||||
|
|
||||||
|
|
||||||
async def save_chunk_nodes(
|
async def save_chunk_nodes(
|
||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector
|
||||||
):
|
):
|
||||||
"""Save chunk nodes using graph models"""
|
"""Save chunk nodes using graph models"""
|
||||||
if not chunk_nodes:
|
if not chunk_nodes:
|
||||||
@@ -89,8 +94,8 @@ async def save_chunk_nodes(
|
|||||||
|
|
||||||
|
|
||||||
async def save_statement_chunk_edges(
|
async def save_statement_chunk_edges(
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector
|
||||||
):
|
):
|
||||||
"""Save statement-chunk edges using graph models"""
|
"""Save statement-chunk edges using graph models"""
|
||||||
if not statement_chunk_edges:
|
if not statement_chunk_edges:
|
||||||
@@ -118,8 +123,8 @@ async def save_statement_chunk_edges(
|
|||||||
|
|
||||||
|
|
||||||
async def save_statement_entity_edges(
|
async def save_statement_entity_edges(
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector
|
||||||
):
|
):
|
||||||
"""Save statement-entity edges using graph models"""
|
"""Save statement-entity edges using graph models"""
|
||||||
if not statement_entity_edges:
|
if not statement_entity_edges:
|
||||||
@@ -142,7 +147,7 @@ async def save_statement_entity_edges(
|
|||||||
if all_se_edges:
|
if all_se_edges:
|
||||||
try:
|
try:
|
||||||
await connector.execute_query(
|
await connector.execute_query(
|
||||||
STATEMENT_ENTITY_EDGE_SAVE,
|
STATEMENT_ENTITY_EDGE_SAVE,
|
||||||
relationships=all_se_edges
|
relationships=all_se_edges
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -154,9 +159,11 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
statement_nodes: List[StatementNode],
|
statement_nodes: List[StatementNode],
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
|
perceptual_nodes: List[PerceptualNode],
|
||||||
entity_edges: List[EntityEntityEdge],
|
entity_edges: List[EntityEntityEdge],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
|
perceptual_edges: List[PerceptualEdge],
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||||
@@ -169,9 +176,11 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
chunk_nodes: List of ChunkNode objects to save
|
chunk_nodes: List of ChunkNode objects to save
|
||||||
statement_nodes: List of StatementNode objects to save
|
statement_nodes: List of StatementNode objects to save
|
||||||
entity_nodes: List of ExtractedEntityNode objects to save
|
entity_nodes: List of ExtractedEntityNode objects to save
|
||||||
|
perceptual_nodes: List of PerceptualNode objects to save
|
||||||
entity_edges: List of EntityEntityEdge objects to save
|
entity_edges: List of EntityEntityEdge objects to save
|
||||||
statement_chunk_edges: List of StatementChunkEdge objects to save
|
statement_chunk_edges: List of StatementChunkEdge objects to save
|
||||||
statement_entity_edges: List of StatementEntityEdge objects to save
|
statement_entity_edges: List of StatementEntityEdge objects to save
|
||||||
|
perceptual_edges: List of PerceptualEdge objects to save
|
||||||
connector: Neo4j connector instance
|
connector: Neo4j connector instance
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
|
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
|
||||||
dialogue_uuids = [record["uuid"] async for record in result]
|
dialogue_uuids = [record["uuid"] async for record in result]
|
||||||
results['dialogues'] = dialogue_uuids
|
results['dialogues'] = dialogue_uuids
|
||||||
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
||||||
|
|
||||||
# 2. Save all chunk nodes in batch
|
# 2. Save all chunk nodes in batch
|
||||||
if chunk_nodes:
|
if chunk_nodes:
|
||||||
@@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
results['chunks'] = chunk_uuids
|
results['chunks'] = chunk_uuids
|
||||||
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
|
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
|
||||||
|
|
||||||
|
if perceptual_nodes:
|
||||||
|
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE
|
||||||
|
perceptual_data = [node.model_dump() for node in perceptual_nodes]
|
||||||
|
result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data)
|
||||||
|
perceptual_uuids = [record["uuid"] async for record in result]
|
||||||
|
results["perceptuals"] = perceptual_uuids
|
||||||
|
logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j")
|
||||||
|
|
||||||
# 3. Save all statement nodes in batch
|
# 3. Save all statement nodes in batch
|
||||||
if statement_nodes:
|
if statement_nodes:
|
||||||
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
|
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
|
||||||
@@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
results['statement_entity_edges'] = se_uuids
|
results['statement_entity_edges'] = se_uuids
|
||||||
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
|
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
|
||||||
|
|
||||||
|
if perceptual_edges:
|
||||||
|
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE
|
||||||
|
perceptual_edge_data = []
|
||||||
|
for edge in perceptual_edges:
|
||||||
|
print(edge.source, edge.target)
|
||||||
|
perceptual_edge_data.append({
|
||||||
|
"perceptual_id": edge.source,
|
||||||
|
"chunk_id": edge.target,
|
||||||
|
"end_user_id": edge.end_user_id,
|
||||||
|
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||||
|
})
|
||||||
|
result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data)
|
||||||
|
perceptual_edges_uuids = [record["uuid"] async for record in result]
|
||||||
|
results['perceptual_chunk_edges'] = perceptual_edges_uuids
|
||||||
|
logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -304,9 +337,9 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
|
|
||||||
|
|
||||||
def schedule_clustering_after_write(
|
def schedule_clustering_after_write(
|
||||||
entity_nodes: List,
|
entity_nodes: List,
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
embedding_model_id: Optional[str] = None,
|
embedding_model_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
写入 Neo4j 成功后,调度后台聚类任务。
|
写入 Neo4j 成功后,调度后台聚类任务。
|
||||||
@@ -325,14 +358,15 @@ def schedule_clustering_after_write(
|
|||||||
end_user_id = entity_nodes[0].end_user_id
|
end_user_id = entity_nodes[0].end_user_id
|
||||||
new_entity_ids = [e.id for e in entity_nodes]
|
new_entity_ids = [e.id for e in entity_nodes]
|
||||||
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||||
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
|
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id,
|
||||||
|
embedding_model_id=embedding_model_id))
|
||||||
|
|
||||||
|
|
||||||
async def _trigger_clustering(
|
async def _trigger_clustering(
|
||||||
new_entity_ids: List[str],
|
new_entity_ids: List[str],
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
embedding_model_id: Optional[str] = None,
|
embedding_model_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
聚类触发函数,自动判断全量初始化还是增量更新。
|
聚类触发函数,自动判断全量初始化还是增量更新。
|
||||||
|
|||||||
@@ -81,6 +81,12 @@ class ModelConfig(ModelConfigBase):
|
|||||||
updated_at: datetime.datetime
|
updated_at: datetime.datetime
|
||||||
api_keys: List["ModelApiKey"] = []
|
api_keys: List["ModelApiKey"] = []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mask_api_key(key: str, prefix: int = 4, suffix: int = 4) -> str:
    """Return ``key`` with its middle replaced by ``*`` so it is safe to display.

    Args:
        key: The secret to mask. ``None`` or an empty string yields ``""``.
        prefix: Number of leading characters left visible. Negative values
            are treated as ``0``.
        suffix: Number of trailing characters left visible. Negative values
            are treated as ``0``.

    Returns:
        A string of the same length as ``key``. When the key is not longer
        than ``prefix + suffix`` every character is masked, so short keys
        never expose any part of themselves.
    """
    if not key:
        # Covers both None and "", neither of which has anything to mask.
        return ""

    visible_head = max(prefix, 0)
    visible_tail = max(suffix, 0)
    if len(key) <= visible_head + visible_tail:
        return "*" * len(key)

    hidden = "*" * (len(key) - visible_head - visible_tail)
    # key[-0:] is the WHOLE string, so a zero-width tail must be special-cased
    # or the full secret would be appended to the masked value.
    tail = key[-visible_tail:] if visible_tail else ""
    return key[:visible_head] + hidden + tail
|
||||||
|
|
||||||
@field_validator("api_keys", mode="after")
|
@field_validator("api_keys", mode="after")
|
||||||
@classmethod
|
@classmethod
|
||||||
def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
|
def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
|
||||||
@@ -90,6 +96,15 @@ class ModelConfig(ModelConfigBase):
|
|||||||
def _serialize_created_at(self, dt: datetime.datetime | None):
|
def _serialize_created_at(self, dt: datetime.datetime | None):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
@field_serializer("api_keys", when_used="json")
|
||||||
|
def _serialize_api_keys(self, api_keys: List["ModelApiKey"]):
    """Serialize ``api_keys`` for JSON output with each secret masked.

    Each key is dumped to a dict and its ``api_key`` entry is replaced by the
    result of ``self.mask_api_key`` so the raw secret never reaches the
    response.

    Args:
        api_keys: The API key models attached to this model configuration.

    Returns:
        A list of dicts, one per key, in the original order.
    """
    masked_keys = []
    for api_key in api_keys:
        # Dump in JSON mode so serializers declared with when_used="json" on
        # the key model (e.g. its created_at/updated_at serializers) still
        # apply; a python-mode dump would skip them and change the output
        # format of those fields.
        data = api_key.model_dump(mode="json")
        data["api_key"] = self.mask_api_key(api_key.api_key)
        masked_keys.append(data)
    return masked_keys
|
||||||
|
|
||||||
@field_serializer("updated_at", when_used="json")
|
@field_serializer("updated_at", when_used="json")
|
||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
@@ -165,20 +180,20 @@ class ModelApiKey(ModelApiKeyBase):
|
|||||||
if hasattr(self.model_configs, '__iter__') and not isinstance(self.model_configs, dict):
|
if hasattr(self.model_configs, '__iter__') and not isinstance(self.model_configs, dict):
|
||||||
self.model_config_ids = [
|
self.model_config_ids = [
|
||||||
mc.id for mc in self.model_configs
|
mc.id for mc in self.model_configs
|
||||||
if hasattr(mc, 'id')
|
if hasattr(mc, 'id')
|
||||||
and not getattr(mc, 'is_composite', False)
|
and not getattr(mc, 'is_composite', False)
|
||||||
and getattr(mc, 'name', None) == self.model_name
|
and getattr(mc, 'name', None) == self.model_name
|
||||||
]
|
]
|
||||||
# 情况2:字典列表
|
# 情况2:字典列表
|
||||||
elif isinstance(self.model_configs, list):
|
elif isinstance(self.model_configs, list):
|
||||||
self.model_config_ids = [
|
self.model_config_ids = [
|
||||||
mc['id'] if isinstance(mc, dict) else mc.id
|
mc['id'] if isinstance(mc, dict) else mc.id
|
||||||
for mc in self.model_configs
|
for mc in self.model_configs
|
||||||
if ((isinstance(mc, dict)
|
if ((isinstance(mc, dict)
|
||||||
and 'id' in mc
|
and 'id' in mc
|
||||||
and not mc.get('is_composite', False)
|
and not mc.get('is_composite', False)
|
||||||
and mc.get('name') == self.model_name) or
|
and mc.get('name') == self.model_name) or
|
||||||
(hasattr(mc, 'id')
|
(hasattr(mc, 'id')
|
||||||
and not getattr(mc, 'is_composite', False)
|
and not getattr(mc, 'is_composite', False)
|
||||||
and getattr(mc, 'name', None) == self.model_name))
|
and getattr(mc, 'name', None) == self.model_name))
|
||||||
]
|
]
|
||||||
@@ -193,11 +208,10 @@ class ModelApiKey(ModelApiKeyBase):
|
|||||||
validate_assignment=True # 确保赋值触发校验
|
validate_assignment=True # 确保赋值触发校验
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
@field_serializer("updated_at", when_used="json")
|
@field_serializer("updated_at", when_used="json")
|
||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
@@ -211,6 +225,7 @@ class ModelConfigQuery(BaseModel):
|
|||||||
"""模型配置查询Schema"""
|
"""模型配置查询Schema"""
|
||||||
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
||||||
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
|
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
|
||||||
|
capability: Optional[List[str]] = Field(None, description="能力筛选(支持多个)")
|
||||||
is_active: Optional[bool] = Field(None, description="激活状态筛选")
|
is_active: Optional[bool] = Field(None, description="激活状态筛选")
|
||||||
is_public: Optional[bool] = Field(None, description="公开状态筛选")
|
is_public: Optional[bool] = Field(None, description="公开状态筛选")
|
||||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||||
@@ -228,6 +243,7 @@ class ModelConfigQueryNew(BaseModel):
|
|||||||
is_composite: Optional[bool] = Field(None, description="组合模型筛选")
|
is_composite: Optional[bool] = Field(None, description="组合模型筛选")
|
||||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||||
|
|
||||||
|
|
||||||
class ModelMarketplace(BaseModel):
|
class ModelMarketplace(BaseModel):
|
||||||
"""模型广场响应Schema"""
|
"""模型广场响应Schema"""
|
||||||
llm_models: List[ModelConfig] = []
|
llm_models: List[ModelConfig] = []
|
||||||
@@ -304,7 +320,7 @@ class ModelBaseUpdate(BaseModel):
|
|||||||
class ModelBase(BaseModel):
|
class ModelBase(BaseModel):
|
||||||
"""基础模型Schema"""
|
"""基础模型Schema"""
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
name: str
|
name: str
|
||||||
type: str
|
type: str
|
||||||
@@ -327,6 +343,7 @@ class ModelBaseQuery(BaseModel):
|
|||||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
||||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
"""模型信息Schema"""
|
"""模型信息Schema"""
|
||||||
model_name: str = Field(..., description="模型名称")
|
model_name: str = Field(..., description="模型名称")
|
||||||
@@ -336,4 +353,3 @@ class ModelInfo(BaseModel):
|
|||||||
is_omni: bool = Field(default=False, description="是否为omni模型")
|
is_omni: bool = Field(default=False, description="是否为omni模型")
|
||||||
model_type: ModelType = Field(..., description="模型类型")
|
model_type: ModelType = Field(..., description="模型类型")
|
||||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
||||||
|
|
||||||
|
|||||||
@@ -274,7 +274,6 @@ class MemoryAgentService:
|
|||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
messages: list[dict],
|
messages: list[dict],
|
||||||
file_messages: list[dict],
|
|
||||||
config_id: Optional[uuid.UUID] | int,
|
config_id: Optional[uuid.UUID] | int,
|
||||||
db: Session,
|
db: Session,
|
||||||
storage_type: str,
|
storage_type: str,
|
||||||
@@ -287,7 +286,6 @@ class MemoryAgentService:
|
|||||||
Args:
|
Args:
|
||||||
end_user_id: Group identifier (also used as end_user_id)
|
end_user_id: Group identifier (also used as end_user_id)
|
||||||
messages: Message to write
|
messages: Message to write
|
||||||
files: Files to write
|
|
||||||
config_id: Configuration ID from database
|
config_id: Configuration ID from database
|
||||||
db: SQLAlchemy database session
|
db: SQLAlchemy database session
|
||||||
storage_type: Storage type (neo4j or rag)
|
storage_type: Storage type (neo4j or rag)
|
||||||
@@ -348,15 +346,15 @@ class MemoryAgentService:
|
|||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
perceptual_serivce = MemoryPerceptualService(db)
|
perceptual_serivce = MemoryPerceptualService(db)
|
||||||
file_content = []
|
for message in messages:
|
||||||
for message in file_messages:
|
message["file_content"] = []
|
||||||
for file in message["files"]:
|
for file in message["files"]:
|
||||||
file_object = await perceptual_serivce.generate_perceptual_memory(
|
file_object = await perceptual_serivce.generate_perceptual_memory(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
file=FileInput(**file)
|
file=FileInput(**file)
|
||||||
)
|
)
|
||||||
file_content.append(file_object)
|
message["file_content"].append((file_object, file["type"]))
|
||||||
|
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
try:
|
try:
|
||||||
@@ -368,7 +366,6 @@ class MemoryAgentService:
|
|||||||
await write_neo4j(
|
await write_neo4j(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
file_content=file_content,
|
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
ref_id='',
|
ref_id='',
|
||||||
language=language
|
language=language
|
||||||
@@ -380,19 +377,23 @@ class MemoryAgentService:
|
|||||||
if deleted:
|
if deleted:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||||
return self.writer_messages_deal(
|
for message in messages:
|
||||||
"success",
|
message["file_content"] = [
|
||||||
start_time,
|
perceptual[0].file_path for perceptual in message["file_content"]
|
||||||
end_user_id,
|
]
|
||||||
config_id,
|
return self.writer_messages_deal(
|
||||||
message_text,
|
"success",
|
||||||
{
|
start_time,
|
||||||
"status": "success",
|
end_user_id,
|
||||||
"data": messages,
|
config_id,
|
||||||
"config_id": memory_config.config_id,
|
message_text,
|
||||||
"config_name": memory_config.config_name
|
{
|
||||||
}
|
"status": "success",
|
||||||
)
|
"data": messages,
|
||||||
|
"config_id": memory_config.config_id,
|
||||||
|
"config_name": memory_config.config_name
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Ensure proper error handling and logging
|
# Ensure proper error handling and logging
|
||||||
error_msg = f"Write operation failed: {str(e)}"
|
error_msg = f"Write operation failed: {str(e)}"
|
||||||
|
|||||||
@@ -317,11 +317,11 @@ class MemoryPerceptualService:
|
|||||||
stmt = select(FileMetadata).where(
|
stmt = select(FileMetadata).where(
|
||||||
FileMetadata.id == file_id
|
FileMetadata.id == file_id
|
||||||
)
|
)
|
||||||
file = self.db.execute(stmt).scalar_one_or_none()
|
file_obj = self.db.execute(stmt).scalar_one_or_none()
|
||||||
|
|
||||||
if file:
|
if file_obj:
|
||||||
filename = file.file_name
|
filename = file_obj.file_name
|
||||||
file_ext = file.file_ext
|
file_ext = file_obj.file_ext
|
||||||
except ValueError:
|
except ValueError:
|
||||||
business_logger.debug(f"Remote file, file_id={filename}")
|
business_logger.debug(f"Remote file, file_id={filename}")
|
||||||
if not file_ext:
|
if not file_ext:
|
||||||
|
|||||||
@@ -297,9 +297,12 @@ async def run_pilot_extraction(
|
|||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
entity_nodes,
|
entity_nodes,
|
||||||
|
_,
|
||||||
statement_chunk_edges,
|
statement_chunk_edges,
|
||||||
statement_entity_edges,
|
statement_entity_edges,
|
||||||
entity_edges,
|
entity_edges,
|
||||||
|
_,
|
||||||
|
_
|
||||||
) = extraction_result
|
) = extraction_result
|
||||||
|
|
||||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||||
|
|||||||
@@ -1888,7 +1888,8 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
|
|||||||
"Chunk": ["content", "created_at"],
|
"Chunk": ["content", "created_at"],
|
||||||
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
|
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
|
||||||
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
|
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
|
||||||
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
|
"MemorySummary": ["summary", "content", "created_at", "caption"], # 添加 content 字段
|
||||||
|
"Perceptual": ["file_name", "file_path", "file_type", "domain", "topic", "keywords", "summary"]
|
||||||
}
|
}
|
||||||
|
|
||||||
# 获取该节点类型的白名单字段
|
# 获取该节点类型的白名单字段
|
||||||
|
|||||||
@@ -1080,14 +1080,12 @@ def write_message_task(
|
|||||||
config_id: str | int,
|
config_id: str | int,
|
||||||
storage_type: str,
|
storage_type: str,
|
||||||
user_rag_memory_id: str,
|
user_rag_memory_id: str,
|
||||||
file_messages: list[dict] | None,
|
|
||||||
language: str = "zh"
|
language: str = "zh"
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Celery task to process a write message via MemoryAgentService.
|
"""Celery task to process a write message via MemoryAgentService.
|
||||||
Args:
|
Args:
|
||||||
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
||||||
message: Message to write
|
message: Message to write
|
||||||
file_messages: Files to write
|
|
||||||
config_id: Configuration ID (can be UUID string, integer, or config_id_old)
|
config_id: Configuration ID (can be UUID string, integer, or config_id_old)
|
||||||
storage_type: Storage type (neo4j or rag)
|
storage_type: Storage type (neo4j or rag)
|
||||||
user_rag_memory_id: User RAG memory ID
|
user_rag_memory_id: User RAG memory ID
|
||||||
@@ -1099,9 +1097,6 @@ def write_message_task(
|
|||||||
Raises:
|
Raises:
|
||||||
Exception on failure
|
Exception on failure
|
||||||
"""
|
"""
|
||||||
if file_messages is None:
|
|
||||||
file_messages = []
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
||||||
f"config_id={config_id} (type: {type(config_id).__name__}), "
|
f"config_id={config_id} (type: {type(config_id).__name__}), "
|
||||||
@@ -1146,7 +1141,7 @@ def write_message_task(
|
|||||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
||||||
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||||
service = MemoryAgentService()
|
service = MemoryAgentService()
|
||||||
result = await service.write_memory(end_user_id, message, file_messages, actual_config_id, db, storage_type,
|
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
|
||||||
user_rag_memory_id, language)
|
user_rag_memory_id, language)
|
||||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||||
return result
|
return result
|
||||||
@@ -2617,57 +2612,6 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(
|
|
||||||
name="app.tasks.write_perceptual_memory",
|
|
||||||
bind=True,
|
|
||||||
ignore_result=True,
|
|
||||||
max_retries=0,
|
|
||||||
acks_late=False,
|
|
||||||
time_limit=3600,
|
|
||||||
soft_time_limit=3300,
|
|
||||||
)
|
|
||||||
def write_perceptual_memory(
|
|
||||||
self,
|
|
||||||
end_user_id: str,
|
|
||||||
model_api_config: dict,
|
|
||||||
file_type: str,
|
|
||||||
file_url: str,
|
|
||||||
file_message: dict
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Write perceptual memory for a user into PostgreSQL and Neo4j.
|
|
||||||
|
|
||||||
This task generates or updates the user's perceptual memory
|
|
||||||
in the backend databases. It is intended to be executed asynchronously
|
|
||||||
via Celery.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id (uuid.UUID): The unique identifier of the end user.
|
|
||||||
model_api_config (ModelInfo): API configuration for the model
|
|
||||||
used to generate perceptual memory.
|
|
||||||
file_type (str): The file type
|
|
||||||
file_url (url): The url of file
|
|
||||||
file_message (dict): The file message containing details about the file
|
|
||||||
to be processed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
|
|
||||||
set_asyncio_event_loop()
|
|
||||||
with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
|
|
||||||
model_info = ModelInfo(**model_api_config)
|
|
||||||
with get_db_context() as db:
|
|
||||||
memory_perceptual_service = MemoryPerceptualService(db)
|
|
||||||
return asyncio.run(memory_perceptual_service.generate_perceptual_memory(
|
|
||||||
end_user_id,
|
|
||||||
model_info,
|
|
||||||
file_type,
|
|
||||||
file_url,
|
|
||||||
file_message,
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 社区聚类补全任务(触发型)
|
# 社区聚类补全任务(触发型)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user