feat(memory, model): update multi-modal memory write and model list API
- Adjust multi-modal memory write behavior for text and visual data - Mask API keys in model list response to prevent exposure - Add capability-based filtering to the model list API
This commit is contained in:
@@ -42,6 +42,7 @@ def get_model_strategies():
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
@@ -74,10 +75,21 @@ def get_model_list(
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
# Normalise the repeated / comma-separated ``capability`` query values into
# one ordered, de-duplicated list ([] when the filter was not supplied).
capability_list = []
if capability is not None:
    # Split on a bare comma and strip each piece so that both
    # "chat,embedding" and "chat, embedding" are accepted. Splitting on
    # ", " (comma + space) left "chat,embedding" as one bogus value that
    # could never match a stored capability.
    flat_capability = [
        piece.strip()
        for item in capability
        for piece in item.split(",")
        if piece.strip()
    ]
    # dict.fromkeys drops duplicates while preserving first-seen order.
    capability_list = list(dict.fromkeys(flat_capability))

# Routine request diagnostics, not a failure: keep it out of the ERROR stream.
api_logger.debug(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
capability=capability_list,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends,Header
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -19,7 +19,7 @@ from app.services.user_memory_service import (
|
||||
analytics_graph_data,
|
||||
analytics_community_graph_data,
|
||||
)
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
@@ -45,9 +45,9 @@ router = APIRouter(
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的记忆洞察报告
|
||||
@@ -73,10 +73,10 @@ async def get_memory_insight_report_api(
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的用户摘要
|
||||
@@ -90,7 +90,7 @@ async def get_user_summary_api(
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
@@ -102,7 +102,7 @@ async def get_user_summary_api(
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
@@ -117,10 +117,10 @@ async def get_user_summary_api(
|
||||
|
||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||
async def generate_cache_api(
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
手动触发缓存生成
|
||||
@@ -134,7 +134,7 @@ async def generate_cache_api(
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -155,10 +155,12 @@ async def generate_cache_api(
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
|
||||
language=language)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
|
||||
language=language)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
@@ -209,9 +211,9 @@ async def generate_cache_api(
|
||||
|
||||
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
||||
async def get_node_statistics_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -220,7 +222,8 @@ async def get_node_statistics_api(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
api_logger.info(
|
||||
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
|
||||
try:
|
||||
# 调用新的记忆类型统计函数
|
||||
@@ -228,21 +231,23 @@ async def get_node_statistics_api(
|
||||
|
||||
# 计算总数用于日志
|
||||
total_count = sum(item["count"] for item in result)
|
||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
api_logger.info(
|
||||
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
||||
async def get_graph_data_api(
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -298,9 +303,9 @@ async def get_graph_data_api(
|
||||
|
||||
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
||||
async def get_community_graph_data_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -334,9 +339,9 @@ async def get_community_graph_data_api(
|
||||
|
||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||
async def get_end_user_profile(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
@@ -385,9 +390,9 @@ async def get_end_user_profile(
|
||||
|
||||
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
||||
async def update_end_user_profile(
|
||||
profile_update: EndUserProfileUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
profile_update: EndUserProfileUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户的基本信息
|
||||
@@ -417,7 +422,7 @@ async def update_end_user_profile(
|
||||
else:
|
||||
error_msg = result["error"]
|
||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
|
||||
|
||||
# 根据错误类型映射到合适的业务错误码
|
||||
if error_msg == "终端用户不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
||||
@@ -427,15 +432,18 @@ async def update_end_user_profile(
|
||||
# 只有未预期的错误才使用 INTERNAL_ERROR
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
async def memory_space_timeline_of_shared_memories(
|
||||
id: str, label: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
@@ -447,11 +455,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||
|
||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
|
||||
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||
async def memory_space_relationship_evolution(id: str, label: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
end_user_id: str = "group_1",
|
||||
messages: list = None,
|
||||
ref_id: str = "wyl_20251027",
|
||||
ref_id: str = "",
|
||||
config_id: str = None
|
||||
) -> List[DialogData]:
|
||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||
@@ -40,12 +40,13 @@ async def get_chunked_dialogs(
|
||||
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
files = msg.get("file_content", [])
|
||||
|
||||
if role not in ['user', 'assistant']:
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
|
||||
@@ -5,8 +5,8 @@ This module provides the main write function for executing the knowledge extract
|
||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import uuid
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -19,10 +19,8 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.mem
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
from app.models import MemoryPerceptualModel
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \
|
||||
add_perceptual_dialogue_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -36,7 +34,6 @@ async def write(
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
file_content: list[MemoryPerceptualModel],
|
||||
ref_id: str = "",
|
||||
language: str = "zh",
|
||||
) -> None:
|
||||
@@ -47,7 +44,6 @@ async def write(
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
file_content: mutilmodal message list
|
||||
ref_id: Reference ID, defaults to ""
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
"""
|
||||
@@ -142,9 +138,11 @@ async def write(
|
||||
all_chunk_nodes,
|
||||
all_statement_nodes,
|
||||
all_entity_nodes,
|
||||
all_perceptual_nodes,
|
||||
all_statement_chunk_edges,
|
||||
all_statement_entity_edges,
|
||||
all_entity_entity_edges,
|
||||
all_perceptual_edges,
|
||||
all_dedup_details,
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
@@ -169,9 +167,11 @@ async def write(
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
perceptual_nodes=all_perceptual_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
perceptual_edges=all_perceptual_edges,
|
||||
connector=neo4j_connector,
|
||||
)
|
||||
if success:
|
||||
@@ -230,34 +230,6 @@ async def write(
|
||||
finally:
|
||||
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
||||
|
||||
# Step 5: Save perceptual memory to Neo4j
|
||||
step_start = time.time()
|
||||
if file_content:
|
||||
try:
|
||||
pc_connector = Neo4jConnector()
|
||||
try:
|
||||
created_ids = await add_perceptual_nodes(
|
||||
perceptuals=file_content,
|
||||
connector=pc_connector,
|
||||
embedder_client=embedder_client,
|
||||
)
|
||||
# 如果有 ref_id,建立感知记忆与对话的关联
|
||||
if ref_id and created_ids:
|
||||
await add_perceptual_dialogue_edges(
|
||||
perceptuals=file_content,
|
||||
dialog_id=ref_id,
|
||||
connector=pc_connector,
|
||||
)
|
||||
logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j")
|
||||
finally:
|
||||
try:
|
||||
await pc_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True)
|
||||
log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file)
|
||||
|
||||
# Log total pipeline time
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Any, List
|
||||
import re
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Fix tokenizer parallelism warning
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -246,6 +246,7 @@ class ChunkerClient:
|
||||
"total_sub_chunks": len(sub_chunks),
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
files=msg.files
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
else:
|
||||
@@ -258,6 +259,7 @@ class ChunkerClient:
|
||||
"message_role": msg.role,
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
files=msg.files
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ class Edge(BaseModel):
|
||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
|
||||
|
||||
|
||||
class ChunkEdge(Edge):
|
||||
@@ -175,6 +175,12 @@ class EntityEntityEdge(Edge):
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
|
||||
class PerceptualEdge(Edge):
    """Edge that links a perceptual node to its source chunk.

    Adds no fields of its own; everything is inherited from ``Edge``.
    """
|
||||
|
||||
|
||||
class Node(BaseModel):
|
||||
"""Base class for all graph nodes in the knowledge graph.
|
||||
|
||||
@@ -555,19 +561,16 @@ class MemorySummaryNode(Node):
|
||||
)
|
||||
|
||||
|
||||
class MutlimodalNode(Node):
|
||||
class PerceptualNode(Node):
|
||||
"""Node representing a multimodal message in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
dialog_id: ID of the parent dialog
|
||||
message_id: ID of the message
|
||||
metadata: Additional message metadata
|
||||
embedding: Optional embedding vector for the message
|
||||
"""
|
||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||
message_id: str = Field(..., description="ID of the message")
|
||||
summary: str = Field(..., description="The text content of the message")
|
||||
file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')")
|
||||
file_path: List[str] = Field(..., description="List of file paths for multimodal content")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional message metadata")
|
||||
embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message")
|
||||
perceptual_type: int
|
||||
file_path: str
|
||||
file_name: str
|
||||
file_ext: str
|
||||
summary: str
|
||||
keywords: list[str]
|
||||
topic: str
|
||||
domain: str
|
||||
file_type: str
|
||||
summary_embedding: list[float] | None
|
||||
|
||||
@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
|
||||
"""
|
||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||
msg: str = Field(..., description="The text content of the message.")
|
||||
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
|
||||
|
||||
|
||||
class TemporalValidityRange(BaseModel):
|
||||
@@ -130,7 +131,8 @@ class Chunk(BaseModel):
|
||||
content: str = Field(..., description="The content of the chunk as a string.")
|
||||
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
||||
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -31,7 +31,9 @@ from app.core.memory.models.graph_models import (
|
||||
ExtractedEntityNode,
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
StatementNode
|
||||
StatementNode,
|
||||
PerceptualEdge,
|
||||
PerceptualNode
|
||||
)
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
@@ -170,9 +172,11 @@ class ExtractionOrchestrator:
|
||||
list[ChunkNode],
|
||||
list[StatementNode],
|
||||
list[ExtractedEntityNode],
|
||||
list[PerceptualNode],
|
||||
list[StatementChunkEdge],
|
||||
list[StatementEntityEdge],
|
||||
list[EntityEntityEdge],
|
||||
list[PerceptualEdge],
|
||||
dict
|
||||
]:
|
||||
"""
|
||||
@@ -259,9 +263,11 @@ class ExtractionOrchestrator:
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
perceptual_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
perceptual_edges
|
||||
) = await self._create_nodes_and_edges(dialog_data_list)
|
||||
|
||||
# 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总)
|
||||
@@ -275,7 +281,16 @@ class ExtractionOrchestrator:
|
||||
|
||||
# 注意:deduplication 消息已在创建节点和边完成后立即发送
|
||||
|
||||
result = await self._run_dedup_and_write_summary(
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
dialog_data_list,
|
||||
) = await self._run_dedup_and_write_summary(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
@@ -287,7 +302,18 @@ class ExtractionOrchestrator:
|
||||
)
|
||||
|
||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||
return result
|
||||
return (
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
perceptual_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
perceptual_edges,
|
||||
dialog_data_list,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"知识提取流水线运行失败: {e}", exc_info=True)
|
||||
@@ -1000,9 +1026,11 @@ class ExtractionOrchestrator:
|
||||
List[ChunkNode],
|
||||
List[StatementNode],
|
||||
List[ExtractedEntityNode],
|
||||
List[PerceptualNode],
|
||||
List[StatementChunkEdge],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge]
|
||||
List[EntityEntityEdge],
|
||||
List[PerceptualEdge]
|
||||
]:
|
||||
"""
|
||||
创建图数据库节点和边
|
||||
@@ -1026,6 +1054,8 @@ class ExtractionOrchestrator:
|
||||
statement_chunk_edges = []
|
||||
statement_entity_edges = []
|
||||
entity_entity_edges = []
|
||||
perceptual_nodes = []
|
||||
perceptual_edges = []
|
||||
|
||||
# 用于去重的集合
|
||||
entity_id_set = set()
|
||||
@@ -1069,6 +1099,46 @@ class ExtractionOrchestrator:
|
||||
metadata=chunk.metadata,
|
||||
)
|
||||
chunk_nodes.append(chunk_node)
|
||||
logger.error(f"chunk file: {chunk.files}")
|
||||
|
||||
for p, file_type in chunk.files:
|
||||
|
||||
meta = p.meta_data or {}
|
||||
content_meta = meta.get("content", {})
|
||||
|
||||
# 生成 summary embedding(如果有 embedder_client)
|
||||
summary_embedding = None
|
||||
if self.embedder_client and p.summary:
|
||||
try:
|
||||
summary_embedding = (await self.embedder_client.response([p.summary]))[0]
|
||||
except Exception as emb_err:
|
||||
print(f"Failed to embed perceptual summary: {emb_err}")
|
||||
|
||||
perceptual = PerceptualNode(
|
||||
name=f"Perceptual_{p.id}",
|
||||
**{
|
||||
"id": str(p.id),
|
||||
"end_user_id": str(p.end_user_id),
|
||||
"perceptual_type": p.perceptual_type,
|
||||
"file_path": p.file_path or "",
|
||||
"file_name": p.file_name or "",
|
||||
"file_ext": p.file_ext or "",
|
||||
"summary": p.summary or "",
|
||||
"keywords": content_meta.get("keywords", []),
|
||||
"topic": content_meta.get("topic", ""),
|
||||
"domain": content_meta.get("domain", ""),
|
||||
"created_at": p.created_time.isoformat() if p.created_time else None,
|
||||
"file_type": file_type,
|
||||
"summary_embedding": summary_embedding,
|
||||
})
|
||||
perceptual_nodes.append(perceptual)
|
||||
perceptual_edges.append(PerceptualEdge(
|
||||
source=perceptual.id,
|
||||
target=chunk.id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id,
|
||||
created_at=dialog_data.created_at,
|
||||
))
|
||||
|
||||
# 处理每个陈述句
|
||||
for statement in chunk.statements:
|
||||
@@ -1248,9 +1318,11 @@ class ExtractionOrchestrator:
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
perceptual_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
perceptual_edges
|
||||
)
|
||||
|
||||
async def _run_dedup_and_write_summary(
|
||||
|
||||
@@ -72,7 +72,6 @@ class MemoryWriteNode(BaseNode):
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
messages = []
|
||||
multimodal_memories = []
|
||||
if self.typed_config.message:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
@@ -104,19 +103,15 @@ class MemoryWriteNode(BaseNode):
|
||||
url=file_instence.value.url,
|
||||
file_type=file_instence.value.origin_file_type
|
||||
).model_dump())
|
||||
multimodal_memories.append({
|
||||
"role": message.role,
|
||||
"files": file_info
|
||||
})
|
||||
messages.append({
|
||||
"role": message.role,
|
||||
"content": self._render_template(content, variable_pool)
|
||||
"content": self._render_template(content, variable_pool),
|
||||
"files": file_info
|
||||
})
|
||||
|
||||
write_message_task.delay(
|
||||
end_user_id=end_user_id,
|
||||
message=messages,
|
||||
file_messages=multimodal_memories,
|
||||
config_id=str(self.typed_config.config_id),
|
||||
storage_type=state["memory_storage_type"],
|
||||
user_rag_memory_id=state["user_rag_memory_id"]
|
||||
|
||||
@@ -2,10 +2,11 @@ import datetime
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db import Base
|
||||
|
||||
|
||||
|
||||
@@ -408,6 +408,9 @@ class MemoryConfigRepository:
|
||||
"llm_id": db_config.llm_id,
|
||||
"embedding_id": db_config.embedding_id,
|
||||
"rerank_id": db_config.rerank_id,
|
||||
"vision_id": db_config.vision_id,
|
||||
"audio_id": db_config.audio_id,
|
||||
"video_id": db_config.video_id,
|
||||
"enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise,
|
||||
"enable_llm_disambiguation": db_config.enable_llm_disambiguation,
|
||||
"deep_retrieval": db_config.deep_retrieval,
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
||||
from sqlalchemy import and_, or_, func, desc, select
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
import uuid
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
|
||||
from sqlalchemy import and_, or_, func, desc
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
|
||||
from app.schemas.model_schema import (
|
||||
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigQuery, ModelConfigQueryNew
|
||||
)
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
@@ -137,6 +138,9 @@ class ModelConfigRepository:
|
||||
type_values.append(ModelType.LLM)
|
||||
filters.append(ModelConfig.type.in_(type_values))
|
||||
|
||||
if query.capability:
|
||||
filters.append(ModelConfig.capability.contains(query.capability))
|
||||
|
||||
if query.is_active is not None:
|
||||
filters.append(ModelConfig.is_active == query.is_active)
|
||||
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
|
||||
MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE
|
||||
MEMORY_SUMMARY_NODE_SAVE
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
    """Delete every node (and its relationships) belonging to one end user.

    Only nodes whose ``end_user_id`` property equals the given value are
    removed; ``DETACH DELETE`` also removes the relationships attached to
    them.

    Args:
        end_user_id: Value matched against each node's ``end_user_id`` property.
        connector: Connector used to run the Cypher statement.

    Returns:
        Whatever ``connector.execute_query`` returns for the delete statement.
    """
    # Bind the id as a query parameter instead of interpolating it into the
    # Cypher text: a quote inside end_user_id could previously break out of
    # the string literal and change what gets deleted (Cypher injection).
    # NOTE(review): assumes Neo4jConnector.execute_query forwards keyword
    # arguments as Cypher parameters, like the Neo4j driver does — confirm.
    result = await connector.execute_query(
        "MATCH (n {end_user_id: $end_user_id}) DETACH DELETE n",
        end_user_id=end_user_id,
    )
    logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully")
    return result
|
||||
|
||||
|
||||
@@ -25,7 +28,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
||||
List of created node UUIDs or None if failed
|
||||
"""
|
||||
if not dialogues:
|
||||
print("No dialogues to save")
|
||||
logger.info("No dialogues to save")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -50,11 +53,11 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
||||
)
|
||||
|
||||
created_uuids = [record["uuid"] for record in result]
|
||||
print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
|
||||
logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating dialogue nodes: {e}")
|
||||
logger.info(f"Error creating dialogue nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -69,7 +72,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
List of created node UUIDs or None if failed
|
||||
"""
|
||||
if not statements:
|
||||
print("No statements to save")
|
||||
logger.info("No statements to save")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -122,11 +125,11 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
)
|
||||
|
||||
created_uuids = [record["uuid"] for record in result]
|
||||
print(f"Successfully created {len(created_uuids)} statement nodes")
|
||||
logger.info(f"Successfully created {len(created_uuids)} statement nodes")
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating statement nodes: {e}")
|
||||
logger.info(f"Error creating statement nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -141,7 +144,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
||||
List of created chunk UUIDs or None if failed
|
||||
"""
|
||||
if not chunks:
|
||||
print("No chunk nodes to add")
|
||||
logger.info("No chunk nodes to add")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -174,16 +177,18 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
||||
)
|
||||
|
||||
created_uuids = [record["uuid"] for record in result]
|
||||
print(f"Successfully created {len(created_uuids)} chunk nodes")
|
||||
logger.info(f"Successfully created {len(created_uuids)} chunk nodes")
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating chunk nodes: {e}")
|
||||
logger.info(f"Error creating chunk nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[
|
||||
List[str]]:
|
||||
async def add_memory_summary_nodes(
|
||||
summaries: List[MemorySummaryNode],
|
||||
connector: Neo4jConnector
|
||||
) -> Optional[List[str]]:
|
||||
"""Add memory summary nodes to Neo4j in batch.
|
||||
|
||||
Args:
|
||||
@@ -194,7 +199,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
List of created summary node ids or None if failed
|
||||
"""
|
||||
if not summaries:
|
||||
print("No memory summary nodes to add")
|
||||
logger.info("No memory summary nodes to add")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -220,110 +225,8 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
summaries=flattened
|
||||
)
|
||||
created_ids = [record.get("uuid") for record in result]
|
||||
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
||||
logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
||||
return created_ids
|
||||
except Exception as e:
|
||||
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def add_perceptual_nodes(
|
||||
perceptuals: list,
|
||||
connector: Neo4jConnector,
|
||||
embedder_client=None,
|
||||
) -> Optional[List[str]]:
|
||||
"""Add perceptual memory nodes to Neo4j in batch.
|
||||
|
||||
Args:
|
||||
perceptuals: List of MemoryPerceptualModel objects from PostgreSQL
|
||||
connector: Neo4j connector instance
|
||||
embedder_client: Optional embedder client for generating summary embeddings
|
||||
|
||||
Returns:
|
||||
List of created node UUIDs or None if failed
|
||||
"""
|
||||
if not perceptuals:
|
||||
print("No perceptual nodes to add")
|
||||
return []
|
||||
|
||||
try:
|
||||
flattened = []
|
||||
for p in perceptuals:
|
||||
meta = p.meta_data or {}
|
||||
content_meta = meta.get("content", {})
|
||||
|
||||
# 生成 summary embedding(如果有 embedder_client)
|
||||
summary_embedding = None
|
||||
if embedder_client and p.summary:
|
||||
try:
|
||||
summary_embedding = (await embedder_client.response([p.summary]))[0]
|
||||
except Exception as emb_err:
|
||||
print(f"Failed to embed perceptual summary: {emb_err}")
|
||||
|
||||
flattened.append({
|
||||
"id": str(p.id),
|
||||
"end_user_id": str(p.end_user_id),
|
||||
"perceptual_type": p.perceptual_type,
|
||||
"file_path": p.file_path or "",
|
||||
"file_name": p.file_name or "",
|
||||
"file_ext": p.file_ext or "",
|
||||
"summary": p.summary or "",
|
||||
"keywords": content_meta.get("keywords", []),
|
||||
"topic": content_meta.get("topic", ""),
|
||||
"domain": content_meta.get("domain", ""),
|
||||
"created_at": p.created_time.isoformat() if p.created_time else None,
|
||||
"summary_embedding": summary_embedding,
|
||||
})
|
||||
|
||||
result = await connector.execute_query(
|
||||
PERCEPTUAL_NODE_SAVE,
|
||||
perceptuals=flattened,
|
||||
)
|
||||
created_uuids = [record.get("uuid") for record in result]
|
||||
print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j")
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to save Perceptual nodes to Neo4j: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def add_perceptual_dialogue_edges(
|
||||
perceptuals: list,
|
||||
dialog_id: str,
|
||||
connector: Neo4jConnector,
|
||||
) -> Optional[List[str]]:
|
||||
"""Add edges between Perceptual nodes and Dialogue nodes.
|
||||
|
||||
Args:
|
||||
perceptuals: List of MemoryPerceptualModel objects
|
||||
dialog_id: The dialogue ID (or ref_id) to link to
|
||||
connector: Neo4j connector instance
|
||||
|
||||
Returns:
|
||||
List of created edge element IDs or None if failed
|
||||
"""
|
||||
if not perceptuals or not dialog_id:
|
||||
return []
|
||||
|
||||
try:
|
||||
edges = []
|
||||
for p in perceptuals:
|
||||
edges.append({
|
||||
"perceptual_id": str(p.id),
|
||||
"dialog_id": dialog_id,
|
||||
"end_user_id": str(p.end_user_id),
|
||||
"created_at": p.created_time.isoformat() if p.created_time else None,
|
||||
})
|
||||
|
||||
result = await connector.execute_query(
|
||||
PERCEPTUAL_DIALOGUE_EDGE_SAVE,
|
||||
edges=edges,
|
||||
)
|
||||
created_ids = [record.get("uuid") for record in result]
|
||||
print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j")
|
||||
return created_ids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to save Perceptual-Dialogue edges: {e}")
|
||||
logger.info(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||
return None
|
||||
|
||||
@@ -1003,60 +1003,69 @@ RETURN DISTINCT
|
||||
"""
|
||||
|
||||
Graph_Node_query = """
|
||||
MATCH (n:MemorySummary)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
0 AS priority
|
||||
LIMIT $limit
|
||||
MATCH (n:MemorySummary)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
0 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
UNION ALL
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:Dialogue)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT 1
|
||||
MATCH (n:Dialogue)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT 1
|
||||
|
||||
UNION ALL
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:Statement)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT $limit
|
||||
MATCH (n:Statement)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
UNION ALL
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:ExtractedEntity)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
2 AS priority
|
||||
LIMIT $limit
|
||||
MATCH (n:ExtractedEntity)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
2 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
UNION ALL
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:Chunk)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
3 AS priority
|
||||
LIMIT $limit
|
||||
MATCH (n:Chunk)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
3 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
"""
|
||||
UNION ALL
|
||||
MATCH (n:Perceptual)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
4 AS priority
|
||||
|
||||
"""
|
||||
|
||||
# ============================================================
|
||||
# Community 节点 & BELONGS_TO_COMMUNITY 边
|
||||
@@ -1340,19 +1349,19 @@ SET n += {
|
||||
topic: p.topic,
|
||||
domain: p.domain,
|
||||
created_at: p.created_at,
|
||||
file_type: p.file_type,
|
||||
summary_embedding: p.summary_embedding
|
||||
}
|
||||
RETURN n.id AS uuid
|
||||
"""
|
||||
|
||||
# Edge linking a Chunk node to its Perceptual (perceptual-memory) node
|
||||
PERCEPTUAL_DIALOGUE_EDGE_SAVE = """
|
||||
PERCEPTUAL_CHUNK_EDGE_SAVE = """
|
||||
UNWIND $edges AS edge
|
||||
MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
|
||||
MATCH (d:Dialogue {end_user_id: edge.end_user_id})
|
||||
WHERE d.id = edge.dialog_id OR d.ref_id = edge.dialog_id
|
||||
MERGE (d)-[r:HAS_PERCEPTUAL]->(p)
|
||||
SET r.end_user_id = edge.end_user_id,
|
||||
MATCH (c:Chunk {id: edge.chunk_id, end_user_id: edge.end_user_id})
|
||||
MERGE (c)-[r:HAS_PERCEPTUAL]->(p)
|
||||
ON CREATE SET r.end_user_id = edge.end_user_id,
|
||||
r.created_at = edge.created_at
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
@@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import (
|
||||
StatementNode,
|
||||
ExtractedEntityNode,
|
||||
EntityEntityEdge,
|
||||
PerceptualNode,
|
||||
PerceptualEdge,
|
||||
)
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def save_entities_and_relationships(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""Save entities and their relationships using graph models"""
|
||||
all_entities = [entity.model_dump() for entity in entity_nodes]
|
||||
@@ -73,8 +78,8 @@ async def save_entities_and_relationships(
|
||||
|
||||
|
||||
async def save_chunk_nodes(
|
||||
chunk_nodes: List[ChunkNode],
|
||||
connector: Neo4jConnector
|
||||
chunk_nodes: List[ChunkNode],
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""Save chunk nodes using graph models"""
|
||||
if not chunk_nodes:
|
||||
@@ -89,8 +94,8 @@ async def save_chunk_nodes(
|
||||
|
||||
|
||||
async def save_statement_chunk_edges(
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
connector: Neo4jConnector
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""Save statement-chunk edges using graph models"""
|
||||
if not statement_chunk_edges:
|
||||
@@ -118,8 +123,8 @@ async def save_statement_chunk_edges(
|
||||
|
||||
|
||||
async def save_statement_entity_edges(
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""Save statement-entity edges using graph models"""
|
||||
if not statement_entity_edges:
|
||||
@@ -142,7 +147,7 @@ async def save_statement_entity_edges(
|
||||
if all_se_edges:
|
||||
try:
|
||||
await connector.execute_query(
|
||||
STATEMENT_ENTITY_EDGE_SAVE,
|
||||
STATEMENT_ENTITY_EDGE_SAVE,
|
||||
relationships=all_se_edges
|
||||
)
|
||||
except Exception:
|
||||
@@ -154,9 +159,11 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
perceptual_nodes: List[PerceptualNode],
|
||||
entity_edges: List[EntityEntityEdge],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
perceptual_edges: List[PerceptualEdge],
|
||||
connector: Neo4jConnector,
|
||||
) -> bool:
|
||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||
@@ -169,9 +176,11 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
chunk_nodes: List of ChunkNode objects to save
|
||||
statement_nodes: List of StatementNode objects to save
|
||||
entity_nodes: List of ExtractedEntityNode objects to save
|
||||
perceptual_nodes: List of PerceptualNode objects to save
|
||||
entity_edges: List of EntityEntityEdge objects to save
|
||||
statement_chunk_edges: List of StatementChunkEdge objects to save
|
||||
statement_entity_edges: List of StatementEntityEdge objects to save
|
||||
perceptual_edges: List of PerceptualEdge objects to save
|
||||
connector: Neo4j connector instance
|
||||
|
||||
Returns:
|
||||
@@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
|
||||
dialogue_uuids = [record["uuid"] async for record in result]
|
||||
results['dialogues'] = dialogue_uuids
|
||||
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
||||
logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
||||
|
||||
# 2. Save all chunk nodes in batch
|
||||
if chunk_nodes:
|
||||
@@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
results['chunks'] = chunk_uuids
|
||||
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
|
||||
|
||||
if perceptual_nodes:
|
||||
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE
|
||||
perceptual_data = [node.model_dump() for node in perceptual_nodes]
|
||||
result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data)
|
||||
perceptual_uuids = [record["uuid"] async for record in result]
|
||||
results["perceptuals"] = perceptual_uuids
|
||||
logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j")
|
||||
|
||||
# 3. Save all statement nodes in batch
|
||||
if statement_nodes:
|
||||
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
|
||||
@@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
results['statement_entity_edges'] = se_uuids
|
||||
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
|
||||
|
||||
if perceptual_edges:
|
||||
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE
|
||||
perceptual_edge_data = []
|
||||
for edge in perceptual_edges:
|
||||
print(edge.source, edge.target)
|
||||
perceptual_edge_data.append({
|
||||
"perceptual_id": edge.source,
|
||||
"chunk_id": edge.target,
|
||||
"end_user_id": edge.end_user_id,
|
||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||
})
|
||||
result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data)
|
||||
perceptual_edges_uuids = [record["uuid"] async for record in result]
|
||||
results['perceptual_chunk_edges'] = perceptual_edges_uuids
|
||||
logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j")
|
||||
|
||||
return results
|
||||
|
||||
try:
|
||||
@@ -304,9 +337,9 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
|
||||
|
||||
def schedule_clustering_after_write(
|
||||
entity_nodes: List,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
entity_nodes: List,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
写入 Neo4j 成功后,调度后台聚类任务。
|
||||
@@ -325,14 +358,15 @@ def schedule_clustering_after_write(
|
||||
end_user_id = entity_nodes[0].end_user_id
|
||||
new_entity_ids = [e.id for e in entity_nodes]
|
||||
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
|
||||
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id,
|
||||
embedding_model_id=embedding_model_id))
|
||||
|
||||
|
||||
async def _trigger_clustering(
|
||||
new_entity_ids: List[str],
|
||||
end_user_id: str,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
new_entity_ids: List[str],
|
||||
end_user_id: str,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
聚类触发函数,自动判断全量初始化还是增量更新。
|
||||
|
||||
@@ -81,6 +81,12 @@ class ModelConfig(ModelConfigBase):
|
||||
updated_at: datetime.datetime
|
||||
api_keys: List["ModelApiKey"] = []
|
||||
|
||||
@staticmethod
|
||||
def mask_api_key(key: str, prefix: int = 4, suffix: int = 4) -> str:
|
||||
if not key or len(key) <= prefix + suffix:
|
||||
return "*" * len(key)
|
||||
return key[:prefix] + "*" * (len(key) - prefix - suffix) + key[-suffix:]
|
||||
|
||||
@field_validator("api_keys", mode="after")
|
||||
@classmethod
|
||||
def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
|
||||
@@ -90,6 +96,15 @@ class ModelConfig(ModelConfigBase):
|
||||
def _serialize_created_at(self, dt: datetime.datetime | None):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("api_keys", when_used="json")
def _serialize_api_keys(self, api_keys: List["ModelApiKey"]):
    """Dump each API key entry to a dict with its ``api_key`` value masked.

    Registered for JSON-mode serialization only (``when_used="json"``).
    Each entry is dumped with ``model_dump()`` and its ``"api_key"`` item
    is overwritten with the output of ``mask_api_key``.
    """
    return [
        {**entry.model_dump(), "api_key": self.mask_api_key(entry.api_key)}
        for entry in api_keys
    ]
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -165,20 +180,20 @@ class ModelApiKey(ModelApiKeyBase):
|
||||
if hasattr(self.model_configs, '__iter__') and not isinstance(self.model_configs, dict):
|
||||
self.model_config_ids = [
|
||||
mc.id for mc in self.model_configs
|
||||
if hasattr(mc, 'id')
|
||||
and not getattr(mc, 'is_composite', False)
|
||||
and getattr(mc, 'name', None) == self.model_name
|
||||
if hasattr(mc, 'id')
|
||||
and not getattr(mc, 'is_composite', False)
|
||||
and getattr(mc, 'name', None) == self.model_name
|
||||
]
|
||||
# 情况2:字典列表
|
||||
elif isinstance(self.model_configs, list):
|
||||
self.model_config_ids = [
|
||||
mc['id'] if isinstance(mc, dict) else mc.id
|
||||
for mc in self.model_configs
|
||||
if ((isinstance(mc, dict)
|
||||
and 'id' in mc
|
||||
if ((isinstance(mc, dict)
|
||||
and 'id' in mc
|
||||
and not mc.get('is_composite', False)
|
||||
and mc.get('name') == self.model_name) or
|
||||
(hasattr(mc, 'id')
|
||||
and mc.get('name') == self.model_name) or
|
||||
(hasattr(mc, 'id')
|
||||
and not getattr(mc, 'is_composite', False)
|
||||
and getattr(mc, 'name', None) == self.model_name))
|
||||
]
|
||||
@@ -193,11 +208,10 @@ class ModelApiKey(ModelApiKeyBase):
|
||||
validate_assignment=True # 确保赋值触发校验
|
||||
)
|
||||
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -211,6 +225,7 @@ class ModelConfigQuery(BaseModel):
|
||||
"""模型配置查询Schema"""
|
||||
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
||||
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
|
||||
capability: Optional[List[str]] = Field(None, description="能力筛选(支持多个)")
|
||||
is_active: Optional[bool] = Field(None, description="激活状态筛选")
|
||||
is_public: Optional[bool] = Field(None, description="公开状态筛选")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
@@ -228,6 +243,7 @@ class ModelConfigQueryNew(BaseModel):
|
||||
is_composite: Optional[bool] = Field(None, description="组合模型筛选")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
|
||||
|
||||
class ModelMarketplace(BaseModel):
|
||||
"""模型广场响应Schema"""
|
||||
llm_models: List[ModelConfig] = []
|
||||
@@ -304,7 +320,7 @@ class ModelBaseUpdate(BaseModel):
|
||||
class ModelBase(BaseModel):
|
||||
"""基础模型Schema"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
type: str
|
||||
@@ -327,6 +343,7 @@ class ModelBaseQuery(BaseModel):
|
||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""模型信息Schema"""
|
||||
model_name: str = Field(..., description="模型名称")
|
||||
@@ -336,4 +353,3 @@ class ModelInfo(BaseModel):
|
||||
is_omni: bool = Field(default=False, description="是否为omni模型")
|
||||
model_type: ModelType = Field(..., description="模型类型")
|
||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
||||
|
||||
|
||||
@@ -274,7 +274,6 @@ class MemoryAgentService:
|
||||
self,
|
||||
end_user_id: str,
|
||||
messages: list[dict],
|
||||
file_messages: list[dict],
|
||||
config_id: Optional[uuid.UUID] | int,
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
@@ -287,7 +286,6 @@ class MemoryAgentService:
|
||||
Args:
|
||||
end_user_id: Group identifier (also used as end_user_id)
|
||||
messages: Message to write
|
||||
files: Files to write
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
@@ -348,15 +346,15 @@ class MemoryAgentService:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
perceptual_serivce = MemoryPerceptualService(db)
|
||||
file_content = []
|
||||
for message in file_messages:
|
||||
for message in messages:
|
||||
message["file_content"] = []
|
||||
for file in message["files"]:
|
||||
file_object = await perceptual_serivce.generate_perceptual_memory(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
file=FileInput(**file)
|
||||
)
|
||||
file_content.append(file_object)
|
||||
message["file_content"].append((file_object, file["type"]))
|
||||
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
try:
|
||||
@@ -368,7 +366,6 @@ class MemoryAgentService:
|
||||
await write_neo4j(
|
||||
end_user_id=end_user_id,
|
||||
messages=messages,
|
||||
file_content=file_content,
|
||||
memory_config=memory_config,
|
||||
ref_id='',
|
||||
language=language
|
||||
@@ -380,19 +377,23 @@ class MemoryAgentService:
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||
return self.writer_messages_deal(
|
||||
"success",
|
||||
start_time,
|
||||
end_user_id,
|
||||
config_id,
|
||||
message_text,
|
||||
{
|
||||
"status": "success",
|
||||
"data": messages,
|
||||
"config_id": memory_config.config_id,
|
||||
"config_name": memory_config.config_name
|
||||
}
|
||||
)
|
||||
for message in messages:
|
||||
message["file_content"] = [
|
||||
perceptual[0].file_path for perceptual in message["file_content"]
|
||||
]
|
||||
return self.writer_messages_deal(
|
||||
"success",
|
||||
start_time,
|
||||
end_user_id,
|
||||
config_id,
|
||||
message_text,
|
||||
{
|
||||
"status": "success",
|
||||
"data": messages,
|
||||
"config_id": memory_config.config_id,
|
||||
"config_name": memory_config.config_name
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
|
||||
@@ -317,11 +317,11 @@ class MemoryPerceptualService:
|
||||
stmt = select(FileMetadata).where(
|
||||
FileMetadata.id == file_id
|
||||
)
|
||||
file = self.db.execute(stmt).scalar_one_or_none()
|
||||
file_obj = self.db.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if file:
|
||||
filename = file.file_name
|
||||
file_ext = file.file_ext
|
||||
if file_obj:
|
||||
filename = file_obj.file_name
|
||||
file_ext = file_obj.file_ext
|
||||
except ValueError:
|
||||
business_logger.debug(f"Remote file, file_id={filename}")
|
||||
if not file_ext:
|
||||
|
||||
@@ -297,9 +297,12 @@ async def run_pilot_extraction(
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
_,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
_,
|
||||
_
|
||||
) = extraction_result
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
@@ -1888,7 +1888,8 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
|
||||
"Chunk": ["content", "created_at"],
|
||||
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"],
|
||||
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"],
|
||||
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
|
||||
"MemorySummary": ["summary", "content", "created_at", "caption"], # 添加 content 字段
|
||||
"Perceptual": ["file_name", "file_path", "file_type", "domain", "topic", "keywords", "summary"]
|
||||
}
|
||||
|
||||
# 获取该节点类型的白名单字段
|
||||
|
||||
@@ -1080,14 +1080,12 @@ def write_message_task(
|
||||
config_id: str | int,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str,
|
||||
file_messages: list[dict] | None,
|
||||
language: str = "zh"
|
||||
) -> Dict[str, Any]:
|
||||
"""Celery task to process a write message via MemoryAgentService.
|
||||
Args:
|
||||
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
||||
message: Message to write
|
||||
file_messages: Files to write
|
||||
config_id: Configuration ID (can be UUID string, integer, or config_id_old)
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
user_rag_memory_id: User RAG memory ID
|
||||
@@ -1099,9 +1097,6 @@ def write_message_task(
|
||||
Raises:
|
||||
Exception on failure
|
||||
"""
|
||||
if file_messages is None:
|
||||
file_messages = []
|
||||
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
||||
f"config_id={config_id} (type: {type(config_id).__name__}), "
|
||||
@@ -1146,7 +1141,7 @@ def write_message_task(
|
||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
||||
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||
service = MemoryAgentService()
|
||||
result = await service.write_memory(end_user_id, message, file_messages, actual_config_id, db, storage_type,
|
||||
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
|
||||
user_rag_memory_id, language)
|
||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||
return result
|
||||
@@ -2617,57 +2612,6 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.write_perceptual_memory",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
)
|
||||
def write_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
model_api_config: dict,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict
|
||||
):
|
||||
"""
|
||||
Write perceptual memory for a user into PostgreSQL and Neo4j.
|
||||
|
||||
This task generates or updates the user's perceptual memory
|
||||
in the backend databases. It is intended to be executed asynchronously
|
||||
via Celery.
|
||||
|
||||
Args:
|
||||
end_user_id (uuid.UUID): The unique identifier of the end user.
|
||||
model_api_config (ModelInfo): API configuration for the model
|
||||
used to generate perceptual memory.
|
||||
file_type (str): The file type
|
||||
file_url (url): The url of file
|
||||
file_message (dict): The file message containing details about the file
|
||||
to be processed.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
|
||||
set_asyncio_event_loop()
|
||||
with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
|
||||
model_info = ModelInfo(**model_api_config)
|
||||
with get_db_context() as db:
|
||||
memory_perceptual_service = MemoryPerceptualService(db)
|
||||
return asyncio.run(memory_perceptual_service.generate_perceptual_memory(
|
||||
end_user_id,
|
||||
model_info,
|
||||
file_type,
|
||||
file_url,
|
||||
file_message,
|
||||
))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 社区聚类补全任务(触发型)
|
||||
# =============================================================================
|
||||
|
||||
Reference in New Issue
Block a user