feat(memory): support perception-aware memory writing in workflow and Neo4j nodes
This commit is contained in:
@@ -5,6 +5,7 @@ This module provides the main write function for executing the knowledge extract
|
||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import uuid
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
@@ -13,28 +14,31 @@ from dotenv import load_dotenv
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||
memory_summary_generation
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
from app.models import MemoryPerceptualModel
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \
|
||||
add_perceptual_dialogue_edges
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
language: str = "zh",
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
file_content: list[MemoryPerceptualModel],
|
||||
ref_id: str = "",
|
||||
language: str = "zh",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
@@ -43,9 +47,12 @@ async def write(
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
file_content: mutilmodal message list
|
||||
ref_id: Reference ID, defaults to ""
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
"""
|
||||
if not ref_id:
|
||||
ref_id = uuid.uuid4().hex
|
||||
# Extract config values
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
@@ -99,14 +106,14 @@ async def write(
|
||||
if memory_config.scene_id:
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
|
||||
|
||||
|
||||
with get_db_context() as db:
|
||||
ontology_types = load_ontology_types_for_scene(
|
||||
scene_id=memory_config.scene_id,
|
||||
workspace_id=memory_config.workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
if ontology_types:
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
@@ -173,7 +180,8 @@ async def write(
|
||||
schedule_clustering_after_write(
|
||||
all_entity_nodes,
|
||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
embedding_model_id=str(
|
||||
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
)
|
||||
break
|
||||
else:
|
||||
@@ -208,9 +216,8 @@ async def write(
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||
)
|
||||
|
||||
ms_connector = Neo4jConnector()
|
||||
try:
|
||||
ms_connector = Neo4jConnector()
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
@@ -223,6 +230,34 @@ async def write(
|
||||
finally:
|
||||
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
||||
|
||||
# Step 5: Save perceptual memory to Neo4j
|
||||
step_start = time.time()
|
||||
if file_content:
|
||||
try:
|
||||
pc_connector = Neo4jConnector()
|
||||
try:
|
||||
created_ids = await add_perceptual_nodes(
|
||||
perceptuals=file_content,
|
||||
connector=pc_connector,
|
||||
embedder_client=embedder_client,
|
||||
)
|
||||
# 如果有 ref_id,建立感知记忆与对话的关联
|
||||
if ref_id and created_ids:
|
||||
await add_perceptual_dialogue_edges(
|
||||
perceptuals=file_content,
|
||||
dialog_id=ref_id,
|
||||
connector=pc_connector,
|
||||
)
|
||||
logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j")
|
||||
finally:
|
||||
try:
|
||||
await pc_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True)
|
||||
log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file)
|
||||
|
||||
# Log total pipeline time
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||
@@ -251,4 +286,4 @@ async def write(
|
||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
|
||||
@@ -553,3 +553,21 @@ class MemorySummaryNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||
)
|
||||
|
||||
|
||||
class MutlimodalNode(Node):
|
||||
"""Node representing a multimodal message in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
dialog_id: ID of the parent dialog
|
||||
message_id: ID of the message
|
||||
metadata: Additional message metadata
|
||||
embedding: Optional embedding vector for the message
|
||||
"""
|
||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||
message_id: str = Field(..., description="ID of the message")
|
||||
summary: str = Field(..., description="The text content of the message")
|
||||
file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')")
|
||||
file_path: List[str] = Field(..., description="List of file paths for multimodal content")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional message metadata")
|
||||
embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message")
|
||||
|
||||
@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def dedup_layers_and_merge_and_return(
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client = None,
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client=None,
|
||||
) -> Tuple[
|
||||
List[DialogueNode],
|
||||
List[ChunkNode],
|
||||
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
List[StatementChunkEdge],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge],
|
||||
dict, # 新增:返回去重详情
|
||||
dict
|
||||
]:
|
||||
"""
|
||||
执行两层实体去重与融合:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -188,7 +188,6 @@ async def _process_chunk_summary(
|
||||
response_model=MemorySummaryResponse,
|
||||
)
|
||||
summary_text = structured.summary.strip()
|
||||
|
||||
# Generate title and type for the summary
|
||||
title = None
|
||||
episodic_type = None
|
||||
|
||||
@@ -374,7 +374,9 @@ class VariablePool:
|
||||
self.variables = deepcopy(pool.variables)
|
||||
|
||||
def is_file_variable(self, selector):
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
variable_struct = self.get_instance(selector, default=None, strict=False)
|
||||
if variable_struct is None:
|
||||
return False
|
||||
if isinstance(variable_struct, FileVariable):
|
||||
return True
|
||||
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
|
||||
|
||||
@@ -623,7 +623,6 @@ class BaseNode(ABC):
|
||||
async def process_message(
|
||||
api_config: ModelInfo,
|
||||
content: str | dict | FileObject,
|
||||
end_user_id: str,
|
||||
enable_file=False
|
||||
) -> list | str | None:
|
||||
provider = api_config.provider
|
||||
@@ -642,8 +641,8 @@ class BaseNode(ABC):
|
||||
return content
|
||||
|
||||
elif isinstance(content, FileObject):
|
||||
if content.content_cache.get(provider):
|
||||
return content.content_cache[provider]
|
||||
if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
|
||||
return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||
file_obj = FileInput(
|
||||
@@ -655,12 +654,11 @@ class BaseNode(ABC):
|
||||
)
|
||||
file_obj.set_content(content.get_content())
|
||||
message = await multimodel_service.process_files(
|
||||
end_user_id,
|
||||
[file_obj],
|
||||
)
|
||||
content.set_content(file_obj.get_content())
|
||||
if message:
|
||||
content.content_cache[provider] = message
|
||||
content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
|
||||
return message
|
||||
return None
|
||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||
|
||||
@@ -144,7 +144,6 @@ class LLMNode(BaseNode):
|
||||
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
@@ -153,7 +152,6 @@ class LLMNode(BaseNode):
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, variable_pool)
|
||||
content = self._render_template(content_template, variable_pool)
|
||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append({
|
||||
@@ -161,32 +159,31 @@ class LLMNode(BaseNode):
|
||||
"content": await self.process_message(
|
||||
model_info,
|
||||
content,
|
||||
user_id,
|
||||
self.typed_config.vision,
|
||||
)
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||
for file in files.value:
|
||||
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file.value, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
@@ -200,7 +197,7 @@ class LLMNode(BaseNode):
|
||||
if isinstance(message["content"], list):
|
||||
file_content = []
|
||||
for file in message["content"]:
|
||||
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
history_message.append(
|
||||
@@ -210,7 +207,6 @@ class LLMNode(BaseNode):
|
||||
message["content"] = await self.process_message(
|
||||
model_info,
|
||||
message["content"],
|
||||
user_id,
|
||||
self.typed_config.vision
|
||||
)
|
||||
history_message.append(message)
|
||||
|
||||
@@ -116,6 +116,7 @@ class MemoryWriteNode(BaseNode):
|
||||
write_message_task.delay(
|
||||
end_user_id=end_user_id,
|
||||
message=messages,
|
||||
file_messages=multimodal_memories,
|
||||
config_id=str(self.typed_config.config_id),
|
||||
storage_type=state["memory_storage_type"],
|
||||
user_rag_memory_id=state["user_rag_memory_id"]
|
||||
|
||||
@@ -30,6 +30,9 @@ class MemoryConfig(Base):
|
||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
||||
vision_id = Column(String, nullable=True, comment="视觉模型配置ID")
|
||||
audio_id = Column(String, nullable=True, comment="语音模型配置ID")
|
||||
video_id = Column(String, nullable=True, comment="视频模型配置ID")
|
||||
|
||||
# 记忆萃取引擎配置
|
||||
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
||||
|
||||
@@ -9,21 +9,22 @@ Classes:
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_config_logger, get_db_logger
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
)
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取数据库专用日志器
|
||||
@@ -157,7 +158,7 @@ class MemoryConfigRepository:
|
||||
return memory_config_obj
|
||||
|
||||
@staticmethod
|
||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
|
||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig:
|
||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
@@ -491,7 +492,10 @@ class MemoryConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]:
|
||||
def get_config_with_workspace(
|
||||
db: Session,
|
||||
config_id: uuid.UUID | int | str
|
||||
) -> Optional[tuple[MemoryConfig, Workspace]]:
|
||||
"""Get memory config and its associated workspace information
|
||||
|
||||
Args:
|
||||
@@ -506,8 +510,6 @@ class MemoryConfigRepository:
|
||||
"""
|
||||
import time
|
||||
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
start_time = time.time()
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
@@ -594,7 +596,7 @@ class MemoryConfigRepository:
|
||||
|
||||
db_logger.debug(
|
||||
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
||||
return (config, workspace)
|
||||
return config, workspace
|
||||
|
||||
except ValueError:
|
||||
# Re-raise known business exceptions
|
||||
@@ -630,7 +632,7 @@ class MemoryConfigRepository:
|
||||
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
|
||||
"""
|
||||
from app.models.ontology_scene import OntologyScene
|
||||
|
||||
|
||||
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
@@ -694,7 +696,7 @@ class MemoryConfigRepository:
|
||||
Optional[MemoryConfig]: 默认配置对象,不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"查询工作空间默认配置: workspace_id={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 优先查找显式标记为默认的配置
|
||||
stmt = (
|
||||
@@ -706,13 +708,13 @@ class MemoryConfigRepository:
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
|
||||
config = db.scalars(stmt).first()
|
||||
|
||||
|
||||
if config:
|
||||
db_logger.debug(f"找到默认配置: config_id={config.config_id}")
|
||||
return config
|
||||
|
||||
|
||||
# 回退:获取最早创建的活跃配置
|
||||
stmt = (
|
||||
select(MemoryConfig)
|
||||
@@ -723,25 +725,25 @@ class MemoryConfigRepository:
|
||||
.order_by(MemoryConfig.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
|
||||
config = db.scalars(stmt).first()
|
||||
|
||||
|
||||
if config:
|
||||
db_logger.debug(f"使用最早创建的配置作为默认: config_id={config.config_id}")
|
||||
else:
|
||||
db_logger.warning(f"工作空间没有活跃的记忆配置: workspace_id={workspace_id}")
|
||||
|
||||
|
||||
return config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询工作空间默认配置失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_with_fallback(
|
||||
db: Session,
|
||||
config_id: Optional[uuid.UUID],
|
||||
workspace_id: uuid.UUID
|
||||
db: Session,
|
||||
config_id: Optional[uuid.UUID],
|
||||
workspace_id: uuid.UUID
|
||||
) -> Optional[MemoryConfig]:
|
||||
"""获取记忆配置,支持回退到工作空间默认配置
|
||||
|
||||
@@ -756,19 +758,18 @@ class MemoryConfigRepository:
|
||||
Optional[MemoryConfig]: 配置对象,如果都不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"查询配置(支持回退): config_id={config_id}, workspace_id={workspace_id}")
|
||||
|
||||
|
||||
if not config_id:
|
||||
db_logger.debug("config_id 为空,使用工作空间默认配置")
|
||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||
|
||||
|
||||
config = db.get(MemoryConfig, config_id)
|
||||
|
||||
|
||||
if config:
|
||||
return config
|
||||
|
||||
|
||||
db_logger.warning(
|
||||
f"配置不存在,回退到工作空间默认配置: missing_config_id={config_id}, workspace_id={workspace_id}"
|
||||
)
|
||||
|
||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||
|
||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
|
||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
|
||||
MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
@@ -12,6 +13,7 @@ async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
||||
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
||||
return result
|
||||
|
||||
|
||||
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
"""Add dialogue nodes to Neo4j database.
|
||||
|
||||
@@ -127,6 +129,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
print(f"Error creating statement nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
"""Add chunk nodes to Neo4j in batch.
|
||||
|
||||
@@ -179,8 +182,8 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
||||
return None
|
||||
|
||||
|
||||
|
||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[
|
||||
List[str]]:
|
||||
"""Add memory summary nodes to Neo4j in batch.
|
||||
|
||||
Args:
|
||||
@@ -211,7 +214,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
"summary_embedding": s.summary_embedding if s.summary_embedding else None,
|
||||
"config_id": s.config_id, # 添加 config_id
|
||||
})
|
||||
|
||||
|
||||
result = await connector.execute_query(
|
||||
MEMORY_SUMMARY_NODE_SAVE,
|
||||
summaries=flattened
|
||||
@@ -224,3 +227,103 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
return None
|
||||
|
||||
|
||||
async def add_perceptual_nodes(
|
||||
perceptuals: list,
|
||||
connector: Neo4jConnector,
|
||||
embedder_client=None,
|
||||
) -> Optional[List[str]]:
|
||||
"""Add perceptual memory nodes to Neo4j in batch.
|
||||
|
||||
Args:
|
||||
perceptuals: List of MemoryPerceptualModel objects from PostgreSQL
|
||||
connector: Neo4j connector instance
|
||||
embedder_client: Optional embedder client for generating summary embeddings
|
||||
|
||||
Returns:
|
||||
List of created node UUIDs or None if failed
|
||||
"""
|
||||
if not perceptuals:
|
||||
print("No perceptual nodes to add")
|
||||
return []
|
||||
|
||||
try:
|
||||
flattened = []
|
||||
for p in perceptuals:
|
||||
meta = p.meta_data or {}
|
||||
content_meta = meta.get("content", {})
|
||||
|
||||
# 生成 summary embedding(如果有 embedder_client)
|
||||
summary_embedding = None
|
||||
if embedder_client and p.summary:
|
||||
try:
|
||||
summary_embedding = (await embedder_client.response([p.summary]))[0]
|
||||
except Exception as emb_err:
|
||||
print(f"Failed to embed perceptual summary: {emb_err}")
|
||||
|
||||
flattened.append({
|
||||
"id": str(p.id),
|
||||
"end_user_id": str(p.end_user_id),
|
||||
"perceptual_type": p.perceptual_type,
|
||||
"file_path": p.file_path or "",
|
||||
"file_name": p.file_name or "",
|
||||
"file_ext": p.file_ext or "",
|
||||
"summary": p.summary or "",
|
||||
"keywords": content_meta.get("keywords", []),
|
||||
"topic": content_meta.get("topic", ""),
|
||||
"domain": content_meta.get("domain", ""),
|
||||
"created_at": p.created_time.isoformat() if p.created_time else None,
|
||||
"summary_embedding": summary_embedding,
|
||||
})
|
||||
|
||||
result = await connector.execute_query(
|
||||
PERCEPTUAL_NODE_SAVE,
|
||||
perceptuals=flattened,
|
||||
)
|
||||
created_uuids = [record.get("uuid") for record in result]
|
||||
print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j")
|
||||
return created_uuids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to save Perceptual nodes to Neo4j: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def add_perceptual_dialogue_edges(
|
||||
perceptuals: list,
|
||||
dialog_id: str,
|
||||
connector: Neo4jConnector,
|
||||
) -> Optional[List[str]]:
|
||||
"""Add edges between Perceptual nodes and Dialogue nodes.
|
||||
|
||||
Args:
|
||||
perceptuals: List of MemoryPerceptualModel objects
|
||||
dialog_id: The dialogue ID (or ref_id) to link to
|
||||
connector: Neo4j connector instance
|
||||
|
||||
Returns:
|
||||
List of created edge element IDs or None if failed
|
||||
"""
|
||||
if not perceptuals or not dialog_id:
|
||||
return []
|
||||
|
||||
try:
|
||||
edges = []
|
||||
for p in perceptuals:
|
||||
edges.append({
|
||||
"perceptual_id": str(p.id),
|
||||
"dialog_id": dialog_id,
|
||||
"end_user_id": str(p.end_user_id),
|
||||
"created_at": p.created_time.isoformat() if p.created_time else None,
|
||||
})
|
||||
|
||||
result = await connector.execute_query(
|
||||
PERCEPTUAL_DIALOGUE_EDGE_SAVE,
|
||||
edges=edges,
|
||||
)
|
||||
created_ids = [record.get("uuid") for record in result]
|
||||
print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j")
|
||||
return created_ids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to save Perceptual-Dialogue edges: {e}")
|
||||
return None
|
||||
|
||||
@@ -1323,3 +1323,36 @@ RETURN s.statement AS statement,
|
||||
ORDER BY COALESCE(s.activation_value, 0) DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
# 感知记忆节点保存
|
||||
PERCEPTUAL_NODE_SAVE = """
|
||||
UNWIND $perceptuals AS p
|
||||
MERGE (n:Perceptual {id: p.id})
|
||||
SET n += {
|
||||
id: p.id,
|
||||
end_user_id: p.end_user_id,
|
||||
perceptual_type: p.perceptual_type,
|
||||
file_path: p.file_path,
|
||||
file_name: p.file_name,
|
||||
file_ext: p.file_ext,
|
||||
summary: p.summary,
|
||||
keywords: p.keywords,
|
||||
topic: p.topic,
|
||||
domain: p.domain,
|
||||
created_at: p.created_at,
|
||||
summary_embedding: p.summary_embedding
|
||||
}
|
||||
RETURN n.id AS uuid
|
||||
"""
|
||||
|
||||
# 感知记忆与对话的关联边
|
||||
PERCEPTUAL_DIALOGUE_EDGE_SAVE = """
|
||||
UNWIND $edges AS edge
|
||||
MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
|
||||
MATCH (d:Dialogue {end_user_id: edge.end_user_id})
|
||||
WHERE d.id = edge.dialog_id OR d.ref_id = edge.dialog_id
|
||||
MERGE (d)-[r:HAS_PERCEPTUAL]->(p)
|
||||
SET r.end_user_id = edge.end_user_id,
|
||||
r.created_at = edge.created_at
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
@@ -387,6 +387,12 @@ class MemoryConfig:
|
||||
|
||||
rerank_model_id: Optional[UUID] = None
|
||||
rerank_model_name: Optional[str] = None
|
||||
video_model_id: Optional[UUID] = None
|
||||
video_model_name: Optional[str] = None
|
||||
vision_model_id: Optional[UUID] = None
|
||||
vision_model_name: Optional[str] = None
|
||||
audio_model_id: Optional[UUID] = None
|
||||
audio_model_name: Optional[str] = None
|
||||
|
||||
llm_params: Dict[str, Any] = field(default_factory=dict)
|
||||
embedding_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -141,7 +141,7 @@ class AppChatService:
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 调用 Agent(支持多模态)
|
||||
@@ -339,7 +339,7 @@ class AppChatService:
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||
|
||||
@@ -600,7 +600,7 @@ class AgentRunService:
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -836,7 +836,7 @@ class AgentRunService:
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
|
||||
@@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.cache import InterestMemoryCache
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.messages_tools import (
|
||||
merge_multiple_search_results,
|
||||
reorder_output_results,
|
||||
)
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas import FileInput
|
||||
from app.schemas.memory_agent_schema import Write_UserInput
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
@@ -271,6 +274,7 @@ class MemoryAgentService:
|
||||
self,
|
||||
end_user_id: str,
|
||||
messages: list[dict],
|
||||
file_messages: list[dict],
|
||||
config_id: Optional[uuid.UUID] | int,
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
@@ -283,6 +287,7 @@ class MemoryAgentService:
|
||||
Args:
|
||||
end_user_id: Group identifier (also used as end_user_id)
|
||||
messages: Message to write
|
||||
files: Files to write
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
@@ -342,48 +347,52 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
perceptual_serivce = MemoryPerceptualService(db)
|
||||
file_content = []
|
||||
for message in file_messages:
|
||||
for file in message["files"]:
|
||||
file_object = await perceptual_serivce.generate_perceptual_memory(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
file=FileInput(**file)
|
||||
)
|
||||
file_content.append(file_object)
|
||||
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
try:
|
||||
if storage_type == "rag":
|
||||
# For RAG storage, convert messages to single string
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||
return "success"
|
||||
else:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# Convert structured messages to LangChain messages
|
||||
langchain_messages = []
|
||||
for msg in messages:
|
||||
if msg['role'] == 'user':
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
print(100 * '-')
|
||||
print(langchain_messages)
|
||||
print(100 * '-')
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
"end_user_id": end_user_id,
|
||||
"memory_config": memory_config,
|
||||
"language": language
|
||||
}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
|
||||
contents)
|
||||
await write_neo4j(
|
||||
end_user_id=end_user_id,
|
||||
messages=messages,
|
||||
file_content=file_content,
|
||||
memory_config=memory_config,
|
||||
ref_id='',
|
||||
language=language
|
||||
)
|
||||
for lang in ["zh", "en"]:
|
||||
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||
end_user_id, lang
|
||||
)
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||
return self.writer_messages_deal(
|
||||
"success",
|
||||
start_time,
|
||||
end_user_id,
|
||||
config_id,
|
||||
message_text,
|
||||
{
|
||||
"status": "success",
|
||||
"data": messages,
|
||||
"config_id": memory_config.config_id,
|
||||
"config_name": memory_config.config_name
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
|
||||
@@ -28,7 +28,7 @@ class MemoryAPIService:
|
||||
2. Maps end_user_id to end_user_id for memory operations
|
||||
3. Delegates to MemoryAgentService for actual memory read/write operations
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""Initialize MemoryAPIService.
|
||||
|
||||
@@ -36,11 +36,11 @@ class MemoryAPIService:
|
||||
db: SQLAlchemy database session
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
|
||||
def validate_end_user(
|
||||
self,
|
||||
end_user_id: str,
|
||||
workspace_id: uuid.UUID
|
||||
self,
|
||||
end_user_id: str,
|
||||
workspace_id: uuid.UUID
|
||||
) -> EndUser:
|
||||
"""Validate that end_user exists and belongs to the workspace.
|
||||
|
||||
@@ -56,7 +56,7 @@ class MemoryAPIService:
|
||||
BusinessException: If end_user not in authorized workspace
|
||||
"""
|
||||
logger.info(f"Validating end_user: {end_user_id} for workspace: {workspace_id}")
|
||||
|
||||
|
||||
# Query end_user by ID
|
||||
try:
|
||||
end_user_uuid = uuid.UUID(end_user_id)
|
||||
@@ -66,7 +66,7 @@ class MemoryAPIService:
|
||||
message=f"Invalid end_user_id format: {end_user_id}",
|
||||
code=BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
|
||||
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first()
|
||||
|
||||
if not end_user:
|
||||
@@ -75,13 +75,13 @@ class MemoryAPIService:
|
||||
resource_type="EndUser",
|
||||
resource_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
# Verify end_user belongs to the workspace via App relationship
|
||||
app = self.db.query(App).filter(
|
||||
App.id == end_user.app_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
|
||||
if not app:
|
||||
logger.warning(f"App not found for end_user: {end_user_id}")
|
||||
# raise ResourceNotFoundException(
|
||||
@@ -99,7 +99,7 @@ class MemoryAPIService:
|
||||
# message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}",
|
||||
# code=BizCode.FORBIDDEN
|
||||
# )
|
||||
|
||||
|
||||
logger.info(f"End user {end_user_id} validated successfully")
|
||||
return end_user
|
||||
|
||||
@@ -125,13 +125,14 @@ class MemoryAPIService:
|
||||
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
||||
|
||||
async def write_memory(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
config_id: str,
|
||||
storage_type: str = "neo4j",
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
config_id: str,
|
||||
storage_type: str = "neo4j",
|
||||
files: Optional[list]=None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Write memory with validation.
|
||||
|
||||
@@ -153,14 +154,16 @@ class MemoryAPIService:
|
||||
ResourceNotFoundException: If end_user not found
|
||||
BusinessException: If end_user not in authorized workspace or write fails
|
||||
"""
|
||||
if files is None:
|
||||
files = list()
|
||||
logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}")
|
||||
|
||||
|
||||
# Validate end_user exists and belongs to workspace
|
||||
self.validate_end_user(end_user_id, workspace_id)
|
||||
|
||||
|
||||
# Update end user's memory_config_id
|
||||
self._update_end_user_config(end_user_id, config_id)
|
||||
|
||||
|
||||
try:
|
||||
# Delegate to MemoryAgentService
|
||||
# Convert string message to list[dict] format expected by MemoryAgentService
|
||||
@@ -171,11 +174,12 @@ class MemoryAPIService:
|
||||
config_id=config_id,
|
||||
db=self.db,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id or ""
|
||||
user_rag_memory_id=user_rag_memory_id or "",
|
||||
files=files
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
||||
|
||||
|
||||
# result may be a string "success" or a dict with a "status" key
|
||||
# Preserve the full dict so callers don't silently lose extra fields
|
||||
# (e.g. error codes, metadata) returned by MemoryAgentService.
|
||||
@@ -189,7 +193,7 @@ class MemoryAPIService:
|
||||
"status": result if isinstance(result, str) else "success",
|
||||
"end_user_id": end_user_id,
|
||||
}
|
||||
|
||||
|
||||
except ConfigurationError as e:
|
||||
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
|
||||
raise BusinessException(
|
||||
@@ -204,16 +208,16 @@ class MemoryAPIService:
|
||||
message=f"Memory write failed: {str(e)}",
|
||||
code=BizCode.MEMORY_WRITE_FAILED
|
||||
)
|
||||
|
||||
|
||||
async def read_memory(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
search_switch: str = "0",
|
||||
config_id: str = "",
|
||||
storage_type: str = "neo4j",
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
search_switch: str = "0",
|
||||
config_id: str = "",
|
||||
storage_type: str = "neo4j",
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Read memory with validation.
|
||||
|
||||
@@ -237,14 +241,13 @@ class MemoryAPIService:
|
||||
BusinessException: If end_user not in authorized workspace or read fails
|
||||
"""
|
||||
logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}")
|
||||
|
||||
|
||||
# Validate end_user exists and belongs to workspace
|
||||
self.validate_end_user(end_user_id, workspace_id)
|
||||
|
||||
|
||||
# Update end user's memory_config_id
|
||||
self._update_end_user_config(end_user_id, config_id)
|
||||
|
||||
|
||||
try:
|
||||
# Delegate to MemoryAgentService
|
||||
result = await MemoryAgentService().read_memory(
|
||||
@@ -257,15 +260,15 @@ class MemoryAPIService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id or ""
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Memory read successful for end_user: {end_user_id}")
|
||||
|
||||
|
||||
return {
|
||||
"answer": result.get("answer", ""),
|
||||
"intermediate_outputs": result.get("intermediate_outputs", []),
|
||||
"end_user_id": end_user_id
|
||||
}
|
||||
|
||||
|
||||
except ConfigurationError as e:
|
||||
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
|
||||
raise BusinessException(
|
||||
@@ -282,8 +285,8 @@ class MemoryAPIService:
|
||||
)
|
||||
|
||||
def list_memory_configs(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
) -> Dict[str, Any]:
|
||||
"""List all memory configs for a workspace.
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
"""Validate configuration ID format (supports both UUID and integer)."""
|
||||
if isinstance(config_id, uuid.UUID):
|
||||
return config_id
|
||||
|
||||
|
||||
if config_id is None:
|
||||
raise InvalidConfigError(
|
||||
"Configuration ID cannot be None",
|
||||
@@ -60,18 +60,18 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
if result:
|
||||
logger.info(f"Found config_id {result.config_id} for user_id {config_id}")
|
||||
return result.config_id
|
||||
|
||||
|
||||
return config_id
|
||||
|
||||
if isinstance(config_id, str):
|
||||
config_id_stripped = config_id.strip()
|
||||
|
||||
|
||||
# Try parsing as UUID first
|
||||
try:
|
||||
return uuid.UUID(config_id_stripped)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
# Fall back to integer parsing
|
||||
try:
|
||||
parsed_id = int(config_id_stripped)
|
||||
@@ -81,17 +81,17 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
|
||||
# 如果提供了数据库会话,尝试通过 user_id 查询 config_id
|
||||
if db is not None:
|
||||
# 查询 user_id 匹配的记录
|
||||
stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id))
|
||||
result = db.execute(stmt).scalars().first()
|
||||
|
||||
|
||||
if result:
|
||||
logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}")
|
||||
return result.config_id
|
||||
|
||||
|
||||
return parsed_id
|
||||
except ValueError:
|
||||
raise InvalidConfigError(
|
||||
@@ -154,10 +154,10 @@ class MemoryConfigService:
|
||||
self.db = db
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
service_name: str = "MemoryConfigService",
|
||||
self,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database with optional fallback.
|
||||
@@ -194,14 +194,14 @@ class MemoryConfigService:
|
||||
try:
|
||||
# Use get_config_with_fallback if workspace_id is provided
|
||||
memory_config = None
|
||||
validated_config_id = None
|
||||
if workspace_id:
|
||||
validated_config_id = None
|
||||
if config_id:
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id, self.db)
|
||||
except Exception:
|
||||
validated_config_id = None
|
||||
|
||||
|
||||
memory_config = self.get_config_with_fallback(
|
||||
memory_config_id=validated_config_id,
|
||||
workspace_id=workspace_id
|
||||
@@ -210,7 +210,7 @@ class MemoryConfigService:
|
||||
validated_config_id = _validate_config_id(config_id, self.db)
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
memory_config = self.db.get(MemoryConfigModel, validated_config_id)
|
||||
|
||||
|
||||
if not memory_config:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
config_logger.error(
|
||||
@@ -233,7 +233,7 @@ class MemoryConfigService:
|
||||
result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id)
|
||||
db_query_time = time.time() - db_query_start
|
||||
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
||||
|
||||
|
||||
if not result:
|
||||
raise ConfigurationError(
|
||||
f"Workspace not found for config {memory_config.config_id}"
|
||||
@@ -243,10 +243,10 @@ class MemoryConfigService:
|
||||
|
||||
# Helper function to validate model with workspace fallback
|
||||
def _validate_model_with_fallback(
|
||||
model_id: str,
|
||||
model_type: str,
|
||||
workspace_default: str,
|
||||
required: bool = False
|
||||
model_id: str,
|
||||
model_type: str,
|
||||
workspace_default: str,
|
||||
required: bool = False
|
||||
) -> tuple:
|
||||
"""Validate model ID, falling back to workspace default if invalid.
|
||||
|
||||
@@ -275,7 +275,7 @@ class MemoryConfigService:
|
||||
logger.warning(
|
||||
f"{model_type} model validation failed, trying workspace default: {e}"
|
||||
)
|
||||
|
||||
|
||||
# Fallback to workspace default
|
||||
if workspace_default:
|
||||
try:
|
||||
@@ -297,7 +297,7 @@ class MemoryConfigService:
|
||||
logger.error(f"Workspace default {model_type} model also invalid: {e}")
|
||||
if required:
|
||||
raise
|
||||
|
||||
|
||||
if required:
|
||||
raise InvalidConfigError(
|
||||
f"{model_type.title()} model is required but not configured",
|
||||
@@ -306,7 +306,7 @@ class MemoryConfigService:
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id
|
||||
)
|
||||
|
||||
|
||||
return None, None
|
||||
|
||||
# Step 2: Validate embedding model with workspace fallback
|
||||
@@ -343,6 +343,35 @@ class MemoryConfigService:
|
||||
if memory_config.rerank_id or workspace.rerank:
|
||||
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
||||
|
||||
vision_uuid, vision_name = validate_and_resolve_model_id(
|
||||
memory_config.vision_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
audio_uuid, audio_name = validate_and_resolve_model_id(
|
||||
memory_config.audio_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
video_uuid, video_name = validate_and_resolve_model_id(
|
||||
memory_config.video_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
# Create immutable MemoryConfig object
|
||||
config = MemoryConfig(
|
||||
config_id=memory_config.config_id,
|
||||
@@ -356,6 +385,12 @@ class MemoryConfigService:
|
||||
embedding_model_name=embedding_name,
|
||||
rerank_model_id=rerank_uuid,
|
||||
rerank_model_name=rerank_name,
|
||||
video_model_id=video_uuid,
|
||||
video_model_name=video_name,
|
||||
vision_model_id=vision_uuid,
|
||||
vision_model_name=vision_name,
|
||||
audio_model_id=audio_uuid,
|
||||
audio_model_name=audio_name,
|
||||
storage_type=workspace.storage_type or "neo4j",
|
||||
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
||||
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
||||
@@ -364,24 +399,31 @@ class MemoryConfigService:
|
||||
reflexion_baseline=memory_config.baseline or "Time",
|
||||
loaded_at=datetime.now(),
|
||||
# Pipeline config: Deduplication
|
||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
enable_llm_dedup_blockwise=bool(
|
||||
memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(
|
||||
memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
||||
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
||||
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
||||
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
||||
# Pipeline config: Statement extraction
|
||||
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
statement_granularity=int(
|
||||
memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(
|
||||
memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(
|
||||
memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
# Pipeline config: Forgetting engine
|
||||
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
||||
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
||||
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
||||
# Pipeline config: Pruning
|
||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_enabled=bool(
|
||||
memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
pruning_threshold=float(
|
||||
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
# Ontology scene association
|
||||
scene_id=memory_config.scene_id,
|
||||
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
||||
@@ -448,9 +490,9 @@ class MemoryConfigService:
|
||||
if not config:
|
||||
logger.warning(f"Model ID {model_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
@@ -481,9 +523,9 @@ class MemoryConfigService:
|
||||
if not config:
|
||||
logger.warning(f"Embedding model ID {embedding_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
||||
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
@@ -571,25 +613,25 @@ class MemoryConfigService:
|
||||
"""
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
|
||||
if not memory_config.scene_id:
|
||||
logger.debug("No scene_id configured, skipping ontology type fetch")
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
ontology_repo = OntologyClassRepository(self.db)
|
||||
ontology_classes = ontology_repo.get_classes_by_scene(memory_config.scene_id)
|
||||
|
||||
|
||||
if not ontology_classes:
|
||||
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
|
||||
return None
|
||||
|
||||
|
||||
ontology_types = OntologyTypeList.from_db_models(ontology_classes)
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
)
|
||||
return ontology_types
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
|
||||
@@ -598,8 +640,8 @@ class MemoryConfigService:
|
||||
return None
|
||||
|
||||
def get_workspace_default_config(
|
||||
self,
|
||||
workspace_id: UUID
|
||||
self,
|
||||
workspace_id: UUID
|
||||
) -> Optional["MemoryConfigModel"]:
|
||||
"""Get workspace default memory config.
|
||||
|
||||
@@ -613,19 +655,19 @@ class MemoryConfigService:
|
||||
Optional[MemoryConfigModel]: Default config or None if no configs exist
|
||||
"""
|
||||
config = MemoryConfigRepository.get_workspace_default(self.db, workspace_id)
|
||||
|
||||
|
||||
if not config:
|
||||
logger.warning(
|
||||
"No active memory config found for workspace fallback",
|
||||
extra={"workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
|
||||
return config
|
||||
|
||||
def get_config_with_fallback(
|
||||
self,
|
||||
memory_config_id: Optional[UUID],
|
||||
workspace_id: UUID
|
||||
self,
|
||||
memory_config_id: Optional[UUID],
|
||||
workspace_id: UUID
|
||||
) -> Optional["MemoryConfigModel"]:
|
||||
"""Get memory config with fallback to workspace default.
|
||||
|
||||
@@ -644,13 +686,13 @@ class MemoryConfigService:
|
||||
"No memory config ID provided, using workspace default",
|
||||
extra={"workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
|
||||
config = MemoryConfigRepository.get_with_fallback(
|
||||
self.db,
|
||||
memory_config_id,
|
||||
workspace_id
|
||||
)
|
||||
|
||||
|
||||
if not config and memory_config_id:
|
||||
logger.warning(
|
||||
"Memory config not found, falling back to workspace default",
|
||||
@@ -659,13 +701,13 @@ class MemoryConfigService:
|
||||
"workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return config
|
||||
|
||||
def delete_config(
|
||||
self,
|
||||
config_id: UUID | int,
|
||||
force: bool = False
|
||||
self,
|
||||
config_id: UUID | int,
|
||||
force: bool = False
|
||||
) -> dict:
|
||||
"""Delete memory config with protection against in-use configs.
|
||||
|
||||
@@ -687,7 +729,7 @@ class MemoryConfigService:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
|
||||
# 处理旧格式 int 类型的 config_id
|
||||
if isinstance(config_id, int):
|
||||
logger.warning(
|
||||
@@ -699,11 +741,11 @@ class MemoryConfigService:
|
||||
"message": "旧格式配置ID不支持删除操作,请使用新版配置",
|
||||
"legacy_int_id": config_id
|
||||
}
|
||||
|
||||
|
||||
config = self.db.get(MemoryConfigModel, config_id)
|
||||
if not config:
|
||||
raise ResourceNotFoundException("MemoryConfig", str(config_id))
|
||||
|
||||
|
||||
# Check if this is the default config - default configs cannot be deleted
|
||||
if config.is_default:
|
||||
logger.warning(
|
||||
@@ -715,11 +757,11 @@ class MemoryConfigService:
|
||||
"message": "默认配置不允许删除",
|
||||
"is_default": True
|
||||
}
|
||||
|
||||
|
||||
# Use repository to count connected end users
|
||||
end_user_repo = EndUserRepository(self.db)
|
||||
connected_count = end_user_repo.count_by_memory_config_id(config_id)
|
||||
|
||||
|
||||
if connected_count > 0 and not force:
|
||||
logger.warning(
|
||||
"Attempted to delete memory config with connected end users",
|
||||
@@ -728,18 +770,18 @@ class MemoryConfigService:
|
||||
"connected_count": connected_count
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"status": "warning",
|
||||
"message": f"无法删除记忆配置:{connected_count} 个终端用户正在使用此配置",
|
||||
"connected_count": connected_count,
|
||||
"force_required": True
|
||||
}
|
||||
|
||||
|
||||
# Force delete: use repository to clear end user references first
|
||||
if connected_count > 0 and force:
|
||||
cleared_count = end_user_repo.clear_memory_config_id(config_id)
|
||||
|
||||
|
||||
logger.warning(
|
||||
"Force deleting memory config, clearing end user references",
|
||||
extra={
|
||||
@@ -747,11 +789,11 @@ class MemoryConfigService:
|
||||
"cleared_end_users": cleared_count
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
self.db.delete(config)
|
||||
self.db.commit()
|
||||
|
||||
|
||||
logger.info(
|
||||
"Memory config deleted",
|
||||
extra={
|
||||
@@ -760,16 +802,16 @@ class MemoryConfigService:
|
||||
"affected_users": connected_count
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "记忆配置删除成功",
|
||||
"affected_users": connected_count
|
||||
}
|
||||
|
||||
|
||||
except IntegrityError as e:
|
||||
self.db.rollback()
|
||||
|
||||
|
||||
# Handle foreign key violation gracefully
|
||||
error_str = str(e.orig) if e.orig else str(e)
|
||||
if "ForeignKeyViolation" in error_str or "foreign key constraint" in error_str.lower():
|
||||
@@ -785,7 +827,7 @@ class MemoryConfigService:
|
||||
"message": "无法删除记忆配置:仍有终端用户引用此配置,请使用 force=true 强制删除",
|
||||
"force_required": True
|
||||
}
|
||||
|
||||
|
||||
# Re-raise other integrity errors
|
||||
logger.error(
|
||||
"Delete failed due to integrity error",
|
||||
@@ -800,9 +842,9 @@ class MemoryConfigService:
|
||||
# ==================== 记忆配置提取方法 ====================
|
||||
|
||||
def extract_memory_config_id(
|
||||
self,
|
||||
app_type: str,
|
||||
config: dict
|
||||
self,
|
||||
app_type: str,
|
||||
config: dict
|
||||
) -> tuple[Optional[uuid.UUID], bool]:
|
||||
"""从发布配置中提取 memory_config_id(根据应用类型分发)
|
||||
|
||||
@@ -828,8 +870,8 @@ class MemoryConfigService:
|
||||
return None, False
|
||||
|
||||
def _extract_memory_config_id_from_agent(
|
||||
self,
|
||||
config: dict
|
||||
self,
|
||||
config: dict
|
||||
) -> tuple[Optional[uuid.UUID], bool]:
|
||||
"""从 Agent 应用配置中提取 memory_config_id
|
||||
|
||||
@@ -888,8 +930,8 @@ class MemoryConfigService:
|
||||
return None, False
|
||||
|
||||
def _extract_memory_config_id_from_workflow(
|
||||
self,
|
||||
config: dict
|
||||
self,
|
||||
config: dict
|
||||
) -> tuple[Optional[uuid.UUID], bool]:
|
||||
"""从 Workflow 应用配置中提取 memory_config_id
|
||||
|
||||
@@ -905,14 +947,14 @@ class MemoryConfigService:
|
||||
- is_legacy_int: 是否检测到旧格式 int 数据
|
||||
"""
|
||||
nodes = config.get("nodes", [])
|
||||
|
||||
|
||||
for node in nodes:
|
||||
node_type = node.get("type", "")
|
||||
|
||||
|
||||
# 检查是否为记忆节点 (support both formats: memory-read/memory-write and MemoryRead/MemoryWrite)
|
||||
if node_type.lower() in ["memoryread", "memorywrite", "memory-read", "memory-write"]:
|
||||
config_id = node.get("config", {}).get("config_id")
|
||||
|
||||
|
||||
if config_id:
|
||||
try:
|
||||
# 处理字符串、UUID 和 int(旧数据兼容)三种情况
|
||||
@@ -937,6 +979,6 @@ class MemoryConfigService:
|
||||
f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, "
|
||||
f"node_type={node_type}, error={str(e)}"
|
||||
)
|
||||
|
||||
|
||||
logger.debug("工作流配置中未找到记忆节点")
|
||||
return None, False
|
||||
|
||||
@@ -12,11 +12,12 @@ from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models import FileMetadata
|
||||
from app.models import FileMetadata, ModelApiKey, ModelType
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||
from app.schemas import FileType
|
||||
from app.schemas import FileType, FileInput
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from app.schemas.memory_perceptual_schema import (
|
||||
PerceptualQuerySchema,
|
||||
PerceptualTimelineResponse,
|
||||
@@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import (
|
||||
AudioModal, Content, VideoModal, TextModal
|
||||
)
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
@@ -195,21 +198,58 @@ class MemoryPerceptualService:
|
||||
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
||||
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
||||
|
||||
def _get_mutlimodal_client(
|
||||
self,
|
||||
file_type: FileType,
|
||||
config: MemoryConfig
|
||||
) -> tuple[RedBearLLM | None, ModelApiKey | None]:
|
||||
model_config = None
|
||||
if file_type == FileType.AUDIO:
|
||||
model_config = ModelApiKeyService.get_available_api_key(
|
||||
self.db,
|
||||
config.audio_model_id
|
||||
)
|
||||
elif file_type == FileType.VIDEO:
|
||||
model_config = ModelApiKeyService.get_available_api_key(
|
||||
self.db,
|
||||
config.video_model_id
|
||||
)
|
||||
elif file_type == FileType.DOCUMENT:
|
||||
model_config = ModelApiKeyService.get_available_api_key(
|
||||
self.db,
|
||||
config.llm_model_id
|
||||
)
|
||||
elif file_type == FileType.IMAGE:
|
||||
model_config = ModelApiKeyService.get_available_api_key(
|
||||
self.db,
|
||||
config.vision_model_id
|
||||
)
|
||||
llm = None
|
||||
if model_config:
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_config.model_name,
|
||||
provider=model_config.provider,
|
||||
api_key=model_config.api_key,
|
||||
base_url=model_config.api_base,
|
||||
is_omni=model_config.is_omni
|
||||
)
|
||||
)
|
||||
return llm, model_config
|
||||
|
||||
async def generate_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
model_config: ModelInfo,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict,
|
||||
memory_config: MemoryConfig,
|
||||
file: FileInput
|
||||
):
|
||||
memories = self.repository.get_by_url(file_url)
|
||||
memories = self.repository.get_by_url(file.url)
|
||||
if memories:
|
||||
business_logger.info(f"Perceptual memory already exists: {file_url}")
|
||||
business_logger.info(f"Perceptual memory already exists: {file.url}")
|
||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||
memory_cache = memories[0]
|
||||
self.repository.create_perceptual_memory(
|
||||
memory = self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||
file_path=memory_cache.file_path,
|
||||
@@ -219,20 +259,31 @@ class MemoryPerceptualService:
|
||||
meta_data=memory_cache.meta_data
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
return
|
||||
llm = RedBearLLM(RedBearModelConfig(
|
||||
return memory
|
||||
else:
|
||||
for memory in memories:
|
||||
if memory.end_user_id == uuid.UUID(end_user_id):
|
||||
return memory
|
||||
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
||||
multimodel_service = MultimodalService(self.db, ModelInfo(
|
||||
model_name=model_config.model_name,
|
||||
provider=model_config.provider,
|
||||
api_key=model_config.api_key,
|
||||
base_url=model_config.api_base,
|
||||
is_omni=model_config.is_omni
|
||||
), type=model_config.model_type)
|
||||
api_base=model_config.api_base,
|
||||
is_omni=model_config.is_omni,
|
||||
capability=model_config.capability,
|
||||
model_type=ModelType.LLM
|
||||
))
|
||||
file_message = await multimodel_service.process_files(
|
||||
files=[file]
|
||||
)
|
||||
if file_message:
|
||||
file_message = file_message[0]
|
||||
try:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
|
||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
|
||||
except FileNotFoundError:
|
||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||
messages = [
|
||||
@@ -242,8 +293,22 @@ class MemoryPerceptualService:
|
||||
]}
|
||||
]
|
||||
result = await llm.ainvoke(messages)
|
||||
content = json_repair.repair_json(result.content, return_objects=True)
|
||||
path = urlparse(file_url).path
|
||||
content = result.content
|
||||
final_output = ""
|
||||
if isinstance(content, list):
|
||||
for msg in content:
|
||||
if isinstance(msg, dict):
|
||||
final_output += msg.get("text", "")
|
||||
elif isinstance(msg, str):
|
||||
final_output += msg
|
||||
elif isinstance(content, dict):
|
||||
final_output += content.get("text", "")
|
||||
elif isinstance(content, str):
|
||||
final_output = content
|
||||
else:
|
||||
raise ValueError(f"Unexcept Model Output Type: {result.content}")
|
||||
content = json_repair.repair_json(final_output, return_objects=True)
|
||||
path = urlparse(file.url).path
|
||||
filename = os.path.basename(path)
|
||||
filename = unquote(filename)
|
||||
file_ext = os.path.splitext(filename)[1]
|
||||
@@ -260,13 +325,13 @@ class MemoryPerceptualService:
|
||||
except ValueError:
|
||||
business_logger.debug(f"Remote file, file_id={filename}")
|
||||
if not file_ext:
|
||||
if file_type == FileType.AUDIO:
|
||||
if file.type == FileType.AUDIO:
|
||||
file_ext = ".mp3"
|
||||
elif file_type == FileType.VIDEO:
|
||||
elif file.type == FileType.VIDEO:
|
||||
file_ext = ".mp4"
|
||||
elif file_type == FileType.DOCUMENT:
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
file_ext = ".txt"
|
||||
elif file_type == FileType.IMAGE:
|
||||
elif file.type == FileType.IMAGE:
|
||||
file_ext = ".jpg"
|
||||
filename += file_ext
|
||||
file_content = {
|
||||
@@ -274,11 +339,11 @@ class MemoryPerceptualService:
|
||||
"topic": content.get("topic"),
|
||||
"domain": content.get("domain")
|
||||
}
|
||||
if file_type in [FileType.IMAGE, FileType.VIDEO]:
|
||||
if file.type in [FileType.IMAGE, FileType.VIDEO]:
|
||||
file_modalities = {
|
||||
"scene": content.get("scene", [])
|
||||
}
|
||||
elif file_type in [FileType.DOCUMENT]:
|
||||
elif file.type in [FileType.DOCUMENT]:
|
||||
file_modalities = {
|
||||
"section_count": content.get("section_count", 0),
|
||||
"title": content.get("title", ""),
|
||||
@@ -288,10 +353,10 @@ class MemoryPerceptualService:
|
||||
file_modalities = {
|
||||
"speaker_count": content.get("speaker_count", 0)
|
||||
}
|
||||
self.repository.create_perceptual_memory(
|
||||
memory = self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType.trans_from_file_type(file_type),
|
||||
file_path=file_url,
|
||||
perceptual_type=PerceptualType.trans_from_file_type(file.type),
|
||||
file_path=file.url,
|
||||
file_name=filename,
|
||||
file_ext=file_ext,
|
||||
summary=content.get('summary', ""),
|
||||
@@ -301,3 +366,4 @@ class MemoryPerceptualService:
|
||||
}
|
||||
)
|
||||
self.db.commit()
|
||||
return memory
|
||||
|
||||
@@ -9,14 +9,12 @@
|
||||
- OpenAI: 支持 URL 和 base64 格式
|
||||
"""
|
||||
import base64
|
||||
import csv
|
||||
import io
|
||||
import uuid
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
import csv
|
||||
import json
|
||||
|
||||
import PyPDF2
|
||||
import httpx
|
||||
import magic
|
||||
@@ -33,7 +31,6 @@ from app.models.file_metadata_model import FileMetadata
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||
from app.tasks import write_perceptual_memory
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -342,15 +339,12 @@ class MultimodalService:
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
end_user_id: uuid.UUID | str,
|
||||
files: Optional[List[FileInput]],
|
||||
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理文件列表,返回 LLM 可用的格式
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
files: 文件输入列表
|
||||
|
||||
Returns:
|
||||
@@ -358,8 +352,6 @@ class MultimodalService:
|
||||
"""
|
||||
if not files:
|
||||
return []
|
||||
if isinstance(end_user_id, uuid.UUID):
|
||||
end_user_id = str(end_user_id)
|
||||
|
||||
# 获取对应的策略
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
||||
@@ -380,23 +372,15 @@ class MultimodalService:
|
||||
if file.type == FileType.IMAGE and "vision" in self.capability:
|
||||
is_support, content = await self._process_image(file, strategy)
|
||||
result.append(content)
|
||||
if is_support:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
is_support, content = await self._process_document(file, strategy)
|
||||
result.append(content)
|
||||
if is_support:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||
is_support, content = await self._process_audio(file, strategy)
|
||||
result.append(content)
|
||||
if is_support:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.VIDEO and "video" in self.capability:
|
||||
is_support, content = await self._process_video(file, strategy)
|
||||
result.append(content)
|
||||
if is_support:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {file.type}")
|
||||
except Exception as e:
|
||||
@@ -418,17 +402,6 @@ class MultimodalService:
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||
return result
|
||||
|
||||
def write_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict
|
||||
):
|
||||
"""写入感知记忆"""
|
||||
if end_user_id and self.api_config:
|
||||
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
|
||||
|
||||
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
||||
"""
|
||||
处理图片文件
|
||||
|
||||
@@ -1080,12 +1080,14 @@ def write_message_task(
|
||||
config_id: str | int,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str,
|
||||
file_messages: list[dict] | None,
|
||||
language: str = "zh"
|
||||
) -> Dict[str, Any]:
|
||||
"""Celery task to process a write message via MemoryAgentService.
|
||||
Args:
|
||||
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
||||
message: Message to write
|
||||
file_messages: Files to write
|
||||
config_id: Configuration ID (can be UUID string, integer, or config_id_old)
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
user_rag_memory_id: User RAG memory ID
|
||||
@@ -1097,6 +1099,8 @@ def write_message_task(
|
||||
Raises:
|
||||
Exception on failure
|
||||
"""
|
||||
if file_messages is None:
|
||||
file_messages = []
|
||||
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
||||
@@ -1142,7 +1146,7 @@ def write_message_task(
|
||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
||||
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||
service = MemoryAgentService()
|
||||
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
|
||||
result = await service.write_memory(end_user_id, message, file_messages, actual_config_id, db, storage_type,
|
||||
user_rag_memory_id, language)
|
||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user