feat(memory): support perception-aware memory writing in workflow and Neo4j nodes
This commit is contained in:
@@ -5,6 +5,7 @@ This module provides the main write function for executing the knowledge extract
|
|||||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import uuid
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -13,28 +14,31 @@ from dotenv import load_dotenv
|
|||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||||
|
memory_summary_generation
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.core.memory.utils.log.logging_utils import log_time
|
from app.core.memory.utils.log.logging_utils import log_time
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
|
from app.models import MemoryPerceptualModel
|
||||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \
|
||||||
|
add_perceptual_dialogue_edges
|
||||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def write(
|
async def write(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
memory_config: MemoryConfig,
|
memory_config: MemoryConfig,
|
||||||
messages: list,
|
messages: list,
|
||||||
ref_id: str = "wyl20251027",
|
file_content: list[MemoryPerceptualModel],
|
||||||
language: str = "zh",
|
ref_id: str = "",
|
||||||
|
language: str = "zh",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Execute the complete knowledge extraction pipeline.
|
Execute the complete knowledge extraction pipeline.
|
||||||
@@ -43,9 +47,12 @@ async def write(
|
|||||||
end_user_id: Group identifier
|
end_user_id: Group identifier
|
||||||
memory_config: MemoryConfig object containing all configuration
|
memory_config: MemoryConfig object containing all configuration
|
||||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||||
ref_id: Reference ID, defaults to "wyl20251027"
|
file_content: mutilmodal message list
|
||||||
|
ref_id: Reference ID, defaults to ""
|
||||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||||
"""
|
"""
|
||||||
|
if not ref_id:
|
||||||
|
ref_id = uuid.uuid4().hex
|
||||||
# Extract config values
|
# Extract config values
|
||||||
embedding_model_id = str(memory_config.embedding_model_id)
|
embedding_model_id = str(memory_config.embedding_model_id)
|
||||||
chunker_strategy = memory_config.chunker_strategy
|
chunker_strategy = memory_config.chunker_strategy
|
||||||
@@ -173,7 +180,8 @@ async def write(
|
|||||||
schedule_clustering_after_write(
|
schedule_clustering_after_write(
|
||||||
all_entity_nodes,
|
all_entity_nodes,
|
||||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
embedding_model_id=str(
|
||||||
|
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -208,9 +216,8 @@ async def write(
|
|||||||
summaries = await memory_summary_generation(
|
summaries = await memory_summary_generation(
|
||||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||||
)
|
)
|
||||||
|
ms_connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
ms_connector = Neo4jConnector()
|
|
||||||
await add_memory_summary_nodes(summaries, ms_connector)
|
await add_memory_summary_nodes(summaries, ms_connector)
|
||||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||||
finally:
|
finally:
|
||||||
@@ -223,6 +230,34 @@ async def write(
|
|||||||
finally:
|
finally:
|
||||||
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
||||||
|
|
||||||
|
# Step 5: Save perceptual memory to Neo4j
|
||||||
|
step_start = time.time()
|
||||||
|
if file_content:
|
||||||
|
try:
|
||||||
|
pc_connector = Neo4jConnector()
|
||||||
|
try:
|
||||||
|
created_ids = await add_perceptual_nodes(
|
||||||
|
perceptuals=file_content,
|
||||||
|
connector=pc_connector,
|
||||||
|
embedder_client=embedder_client,
|
||||||
|
)
|
||||||
|
# 如果有 ref_id,建立感知记忆与对话的关联
|
||||||
|
if ref_id and created_ids:
|
||||||
|
await add_perceptual_dialogue_edges(
|
||||||
|
perceptuals=file_content,
|
||||||
|
dialog_id=ref_id,
|
||||||
|
connector=pc_connector,
|
||||||
|
)
|
||||||
|
logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await pc_connector.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True)
|
||||||
|
log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file)
|
||||||
|
|
||||||
# Log total pipeline time
|
# Log total pipeline time
|
||||||
total_time = time.time() - pipeline_start
|
total_time = time.time() - pipeline_start
|
||||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||||
|
|||||||
@@ -553,3 +553,21 @@ class MemorySummaryNode(Node):
|
|||||||
ge=0,
|
ge=0,
|
||||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MutlimodalNode(Node):
|
||||||
|
"""Node representing a multimodal message in the knowledge graph.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
dialog_id: ID of the parent dialog
|
||||||
|
message_id: ID of the message
|
||||||
|
metadata: Additional message metadata
|
||||||
|
embedding: Optional embedding vector for the message
|
||||||
|
"""
|
||||||
|
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||||
|
message_id: str = Field(..., description="ID of the message")
|
||||||
|
summary: str = Field(..., description="The text content of the message")
|
||||||
|
file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')")
|
||||||
|
file_path: List[str] = Field(..., description="List of file paths for multimodal content")
|
||||||
|
metadata: dict = Field(default_factory=dict, description="Additional message metadata")
|
||||||
|
embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message")
|
||||||
|
|||||||
@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|||||||
|
|
||||||
|
|
||||||
async def dedup_layers_and_merge_and_return(
|
async def dedup_layers_and_merge_and_return(
|
||||||
dialogue_nodes: List[DialogueNode],
|
dialogue_nodes: List[DialogueNode],
|
||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
statement_nodes: List[StatementNode],
|
statement_nodes: List[StatementNode],
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
pipeline_config: ExtractionPipelineConfig,
|
pipeline_config: ExtractionPipelineConfig,
|
||||||
connector: Optional[Neo4jConnector] = None,
|
connector: Optional[Neo4jConnector] = None,
|
||||||
llm_client = None,
|
llm_client=None,
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
List[DialogueNode],
|
List[DialogueNode],
|
||||||
List[ChunkNode],
|
List[ChunkNode],
|
||||||
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
List[StatementChunkEdge],
|
List[StatementChunkEdge],
|
||||||
List[StatementEntityEdge],
|
List[StatementEntityEdge],
|
||||||
List[EntityEntityEdge],
|
List[EntityEntityEdge],
|
||||||
dict, # 新增:返回去重详情
|
dict
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
执行两层实体去重与融合:
|
执行两层实体去重与融合:
|
||||||
|
|||||||
@@ -31,11 +31,10 @@ from app.core.memory.models.graph_models import (
|
|||||||
ExtractedEntityNode,
|
ExtractedEntityNode,
|
||||||
StatementChunkEdge,
|
StatementChunkEdge,
|
||||||
StatementEntityEdge,
|
StatementEntityEdge,
|
||||||
StatementNode,
|
StatementNode
|
||||||
)
|
)
|
||||||
from app.core.memory.models.message_models import DialogData
|
from app.core.memory.models.message_models import DialogData
|
||||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
|
||||||
from app.core.memory.models.variate_config import (
|
from app.core.memory.models.variate_config import (
|
||||||
ExtractionPipelineConfig,
|
ExtractionPipelineConfig,
|
||||||
)
|
)
|
||||||
@@ -46,7 +45,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.emb
|
|||||||
embedding_generation,
|
embedding_generation,
|
||||||
generate_entity_embeddings_from_triplets,
|
generate_entity_embeddings_from_triplets,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 导入各个提取模块
|
# 导入各个提取模块
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
|
||||||
StatementExtractor,
|
StatementExtractor,
|
||||||
@@ -90,16 +88,16 @@ class ExtractionOrchestrator:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
embedder_client: OpenAIEmbedderClient,
|
embedder_client: OpenAIEmbedderClient,
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
config: Optional[ExtractionPipelineConfig] = None,
|
config: Optional[ExtractionPipelineConfig] = None,
|
||||||
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
||||||
embedding_id: Optional[str] = None,
|
embedding_id: Optional[str] = None,
|
||||||
ontology_types: Optional[OntologyTypeList] = None,
|
ontology_types: Optional[OntologyTypeList] = None,
|
||||||
enable_general_types: bool = True,
|
enable_general_types: bool = True,
|
||||||
language: str = "zh",
|
language: str = "zh",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化流水线编排器
|
初始化流水线编排器
|
||||||
@@ -157,19 +155,25 @@ class ExtractionOrchestrator:
|
|||||||
llm_client=llm_client,
|
llm_client=llm_client,
|
||||||
config=self.config.statement_extraction,
|
config=self.config.statement_extraction,
|
||||||
)
|
)
|
||||||
self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language)
|
self.triplet_extractor = TripletExtractor(llm_client=llm_client, ontology_types=self.ontology_types,
|
||||||
|
language=language)
|
||||||
self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
|
self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
|
||||||
|
|
||||||
logger.info("ExtractionOrchestrator 初始化完成")
|
logger.info("ExtractionOrchestrator 初始化完成")
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
is_pilot_run: bool = False,
|
is_pilot_run: bool = False,
|
||||||
) -> Tuple[
|
) -> tuple[
|
||||||
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]],
|
list[DialogueNode],
|
||||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
list[ChunkNode],
|
||||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
list[StatementNode],
|
||||||
|
list[ExtractedEntityNode],
|
||||||
|
list[StatementChunkEdge],
|
||||||
|
list[StatementEntityEdge],
|
||||||
|
list[EntityEntityEdge],
|
||||||
|
dict
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
运行完整的知识提取流水线(优化版:并行执行)
|
运行完整的知识提取流水线(优化版:并行执行)
|
||||||
@@ -208,7 +212,6 @@ class ExtractionOrchestrator:
|
|||||||
for dialog in dialog_data_list:
|
for dialog in dialog_data_list:
|
||||||
for chunk in dialog.chunks:
|
for chunk in dialog.chunks:
|
||||||
all_statements_list.extend(chunk.statements)
|
all_statements_list.extend(chunk.statements)
|
||||||
len(all_statements_list)
|
|
||||||
|
|
||||||
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
||||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
|
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
|
||||||
@@ -230,10 +233,6 @@ class ExtractionOrchestrator:
|
|||||||
all_entities_list.extend(triplet_info.entities)
|
all_entities_list.extend(triplet_info.entities)
|
||||||
all_triplets_list.extend(triplet_info.triplets)
|
all_triplets_list.extend(triplet_info.triplets)
|
||||||
|
|
||||||
len(all_entities_list)
|
|
||||||
len(all_triplets_list)
|
|
||||||
sum(len(temporal_map) for temporal_map in temporal_maps)
|
|
||||||
|
|
||||||
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
||||||
logger.info("步骤 3/6: 生成实体嵌入")
|
logger.info("步骤 3/6: 生成实体嵌入")
|
||||||
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
||||||
@@ -287,8 +286,6 @@ class ExtractionOrchestrator:
|
|||||||
dialog_data_list,
|
dialog_data_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -297,7 +294,7 @@ class ExtractionOrchestrator:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def _extract_statements(
|
async def _extract_statements(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""
|
"""
|
||||||
从对话中提取陈述句(流式输出版本:边提取边发送进度)
|
从对话中提取陈述句(流式输出版本:边提取边发送进度)
|
||||||
@@ -395,7 +392,7 @@ class ExtractionOrchestrator:
|
|||||||
return dialog_data_list
|
return dialog_data_list
|
||||||
|
|
||||||
async def _extract_triplets(
|
async def _extract_triplets(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
从对话中提取三元组(流式输出版本:边提取边发送进度)
|
从对话中提取三元组(流式输出版本:边提取边发送进度)
|
||||||
@@ -478,7 +475,7 @@ class ExtractionOrchestrator:
|
|||||||
return triplet_maps
|
return triplet_maps
|
||||||
|
|
||||||
async def _extract_temporal(
|
async def _extract_temporal(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
从对话中提取时间信息(流式输出版本:边提取边发送进度)
|
从对话中提取时间信息(流式输出版本:边提取边发送进度)
|
||||||
@@ -585,7 +582,7 @@ class ExtractionOrchestrator:
|
|||||||
return temporal_maps
|
return temporal_maps
|
||||||
|
|
||||||
async def _extract_emotions(
|
async def _extract_emotions(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
|
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
|
||||||
@@ -706,7 +703,7 @@ class ExtractionOrchestrator:
|
|||||||
return emotion_maps
|
return emotion_maps
|
||||||
|
|
||||||
async def _parallel_extract_and_embed(
|
async def _parallel_extract_and_embed(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
List[Dict[str, Any]],
|
List[Dict[str, Any]],
|
||||||
List[Dict[str, Any]],
|
List[Dict[str, Any]],
|
||||||
@@ -777,7 +774,7 @@ class ExtractionOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _generate_basic_embeddings(
|
async def _generate_basic_embeddings(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]:
|
) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]:
|
||||||
"""
|
"""
|
||||||
生成基础嵌入向量(陈述句、分块、对话)
|
生成基础嵌入向量(陈述句、分块、对话)
|
||||||
@@ -836,7 +833,7 @@ class ExtractionOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _generate_entity_embeddings(
|
async def _generate_entity_embeddings(
|
||||||
self, triplet_maps: List[Dict[str, Any]]
|
self, triplet_maps: List[Dict[str, Any]]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
生成实体嵌入向量
|
生成实体嵌入向量
|
||||||
@@ -874,17 +871,15 @@ class ExtractionOrchestrator:
|
|||||||
logger.error(f"实体嵌入生成失败: {e}", exc_info=True)
|
logger.error(f"实体嵌入生成失败: {e}", exc_info=True)
|
||||||
return triplet_maps
|
return triplet_maps
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _assign_extracted_data(
|
async def _assign_extracted_data(
|
||||||
self,
|
self,
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
temporal_maps: List[Dict[str, Any]],
|
temporal_maps: List[Dict[str, Any]],
|
||||||
triplet_maps: List[Dict[str, Any]],
|
triplet_maps: List[Dict[str, Any]],
|
||||||
emotion_maps: List[Dict[str, Any]],
|
emotion_maps: List[Dict[str, Any]],
|
||||||
statement_embedding_maps: List[Dict[str, List[float]]],
|
statement_embedding_maps: List[Dict[str, List[float]]],
|
||||||
chunk_embedding_maps: List[Dict[str, List[float]]],
|
chunk_embedding_maps: List[Dict[str, List[float]]],
|
||||||
dialog_embeddings: List[List[float]],
|
dialog_embeddings: List[List[float]],
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""
|
"""
|
||||||
将提取的数据赋值到语句
|
将提取的数据赋值到语句
|
||||||
@@ -906,12 +901,12 @@ class ExtractionOrchestrator:
|
|||||||
# 确保列表长度匹配
|
# 确保列表长度匹配
|
||||||
expected_length = len(dialog_data_list)
|
expected_length = len(dialog_data_list)
|
||||||
if (
|
if (
|
||||||
len(temporal_maps) != expected_length
|
len(temporal_maps) != expected_length
|
||||||
or len(triplet_maps) != expected_length
|
or len(triplet_maps) != expected_length
|
||||||
or len(emotion_maps) != expected_length
|
or len(emotion_maps) != expected_length
|
||||||
or len(statement_embedding_maps) != expected_length
|
or len(statement_embedding_maps) != expected_length
|
||||||
or len(chunk_embedding_maps) != expected_length
|
or len(chunk_embedding_maps) != expected_length
|
||||||
or len(dialog_embeddings) != expected_length
|
or len(dialog_embeddings) != expected_length
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
|
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
|
||||||
@@ -999,7 +994,7 @@ class ExtractionOrchestrator:
|
|||||||
return dialog_data_list
|
return dialog_data_list
|
||||||
|
|
||||||
async def _create_nodes_and_edges(
|
async def _create_nodes_and_edges(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
List[DialogueNode],
|
List[DialogueNode],
|
||||||
List[ChunkNode],
|
List[ChunkNode],
|
||||||
@@ -1007,7 +1002,7 @@ class ExtractionOrchestrator:
|
|||||||
List[ExtractedEntityNode],
|
List[ExtractedEntityNode],
|
||||||
List[StatementChunkEdge],
|
List[StatementChunkEdge],
|
||||||
List[StatementEntityEdge],
|
List[StatementEntityEdge],
|
||||||
List[EntityEntityEdge],
|
List[EntityEntityEdge]
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
创建图数据库节点和边
|
创建图数据库节点和边
|
||||||
@@ -1083,15 +1078,19 @@ class ExtractionOrchestrator:
|
|||||||
name=f"Statement_{statement.id}", # 添加必需的 name 字段
|
name=f"Statement_{statement.id}", # 添加必需的 name 字段
|
||||||
chunk_id=chunk.id,
|
chunk_id=chunk.id,
|
||||||
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
||||||
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
|
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL),
|
||||||
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
# 添加必需的 temporal_info 字段
|
||||||
|
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong',
|
||||||
|
# 添加必需的 connect_strength 字段
|
||||||
end_user_id=dialog_data.end_user_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
statement=statement.statement,
|
statement=statement.statement,
|
||||||
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
||||||
statement_embedding=statement.statement_embedding,
|
statement_embedding=statement.statement_embedding,
|
||||||
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
valid_at=statement.temporal_validity.valid_at if hasattr(statement,
|
||||||
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
'temporal_validity') and statement.temporal_validity else None,
|
||||||
|
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement,
|
||||||
|
'temporal_validity') and statement.temporal_validity else None,
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
expired_at=dialog_data.expired_at,
|
||||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||||
@@ -1141,7 +1140,8 @@ class ExtractionOrchestrator:
|
|||||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
|
||||||
|
# 添加必需的 connect_strength 字段
|
||||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||||
name_embedding=getattr(entity, 'name_embedding', None),
|
name_embedding=getattr(entity, 'name_embedding', None),
|
||||||
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
||||||
@@ -1254,19 +1254,24 @@ class ExtractionOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _run_dedup_and_write_summary(
|
async def _run_dedup_and_write_summary(
|
||||||
self,
|
self,
|
||||||
dialogue_nodes: List[DialogueNode],
|
dialogue_nodes: List[DialogueNode],
|
||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
statement_nodes: List[StatementNode],
|
statement_nodes: List[StatementNode],
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
) -> Tuple[
|
) -> tuple[
|
||||||
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]],
|
list[DialogueNode],
|
||||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
list[ChunkNode],
|
||||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
list[StatementNode],
|
||||||
|
list[ExtractedEntityNode],
|
||||||
|
list[StatementChunkEdge],
|
||||||
|
list[StatementEntityEdge],
|
||||||
|
list[EntityEntityEdge],
|
||||||
|
dict
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
执行两阶段去重并写入汇总
|
执行两阶段去重并写入汇总
|
||||||
@@ -1415,7 +1420,6 @@ class ExtractionOrchestrator:
|
|||||||
len(entity_entity_edges), len(final_entity_entity_edges)
|
len(entity_entity_edges), len(final_entity_entity_edges)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 写入提取结果汇总(试运行和正式模式都需要生成)
|
# 写入提取结果汇总(试运行和正式模式都需要生成)
|
||||||
try:
|
try:
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
@@ -1436,10 +1440,10 @@ class ExtractionOrchestrator:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _save_dedup_details(
|
def _save_dedup_details(
|
||||||
self,
|
self,
|
||||||
dedup_details: Dict[str, Any],
|
dedup_details: Dict[str, Any],
|
||||||
original_entities: List[ExtractedEntityNode],
|
original_entities: List[ExtractedEntityNode],
|
||||||
final_entities: List[ExtractedEntityNode]
|
final_entities: List[ExtractedEntityNode]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
保存去重消歧的详细记录到实例变量(基于内存数据结构)
|
保存去重消歧的详细记录到实例变量(基于内存数据结构)
|
||||||
@@ -1537,15 +1541,16 @@ class ExtractionOrchestrator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
|
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
|
||||||
|
|
||||||
logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
|
logger.info(
|
||||||
|
f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
|
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
|
||||||
|
|
||||||
async def _analyze_entity_merges(
|
async def _analyze_entity_merges(
|
||||||
self,
|
self,
|
||||||
original_entities: List[ExtractedEntityNode],
|
original_entities: List[ExtractedEntityNode],
|
||||||
final_entities: List[ExtractedEntityNode]
|
final_entities: List[ExtractedEntityNode]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
|
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
|
||||||
@@ -1585,9 +1590,9 @@ class ExtractionOrchestrator:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
async def _analyze_entity_disambiguation(
|
async def _analyze_entity_disambiguation(
|
||||||
self,
|
self,
|
||||||
original_entities: List[ExtractedEntityNode],
|
original_entities: List[ExtractedEntityNode],
|
||||||
final_entities: List[ExtractedEntityNode]
|
final_entities: List[ExtractedEntityNode]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
|
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
|
||||||
@@ -1645,9 +1650,9 @@ class ExtractionOrchestrator:
|
|||||||
return type_mapping.get(entity_type, f"{entity_type}实体节点")
|
return type_mapping.get(entity_type, f"{entity_type}实体节点")
|
||||||
|
|
||||||
async def _output_relationship_creation_results(
|
async def _output_relationship_creation_results(
|
||||||
self,
|
self,
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
entity_nodes: List[ExtractedEntityNode]
|
entity_nodes: List[ExtractedEntityNode]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
输出关系创建结果
|
输出关系创建结果
|
||||||
@@ -1681,13 +1686,13 @@ class ExtractionOrchestrator:
|
|||||||
logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
|
logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
|
||||||
|
|
||||||
async def _send_dedup_progress_callback(
|
async def _send_dedup_progress_callback(
|
||||||
self,
|
self,
|
||||||
original_entities: int,
|
original_entities: int,
|
||||||
final_entities: int,
|
final_entities: int,
|
||||||
original_stmt_edges: int,
|
original_stmt_edges: int,
|
||||||
final_stmt_edges: int,
|
final_stmt_edges: int,
|
||||||
original_ent_edges: int,
|
original_ent_edges: int,
|
||||||
final_ent_edges: int,
|
final_ent_edges: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
|
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
|
||||||
@@ -1715,7 +1720,8 @@ class ExtractionOrchestrator:
|
|||||||
"original_count": original_entities,
|
"original_count": original_entities,
|
||||||
"final_count": final_entities,
|
"final_count": final_entities,
|
||||||
"reduced_count": entities_reduced,
|
"reduced_count": entities_reduced,
|
||||||
"reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0,
|
"reduction_rate": round(entities_reduced / original_entities * 100,
|
||||||
|
1) if original_entities > 0 else 0,
|
||||||
},
|
},
|
||||||
"statement_entity_edges": {
|
"statement_entity_edges": {
|
||||||
"original_count": original_stmt_edges,
|
"original_count": original_stmt_edges,
|
||||||
@@ -1790,7 +1796,8 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
disamb_examples.append({
|
disamb_examples.append({
|
||||||
"entity1_name": entity_name,
|
"entity1_name": entity_name,
|
||||||
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知",
|
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:",
|
||||||
|
"").strip() if "vs" in disamb_type else "未知",
|
||||||
"entity2_name": entity_name,
|
"entity2_name": entity_name,
|
||||||
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
|
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
|
||||||
"description": f"{entity_name},消歧区分成功"
|
"description": f"{entity_name},消歧区分成功"
|
||||||
@@ -1815,9 +1822,9 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
|
|
||||||
async def get_chunked_dialogs(
|
async def get_chunked_dialogs(
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
end_user_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
indices: Optional[List[int]] = None,
|
indices: Optional[List[int]] = None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""从测试数据生成分块对话
|
"""从测试数据生成分块对话
|
||||||
|
|
||||||
@@ -1924,10 +1931,10 @@ async def get_chunked_dialogs(
|
|||||||
|
|
||||||
|
|
||||||
def preprocess_data(
|
def preprocess_data(
|
||||||
input_path: Optional[str] = None,
|
input_path: Optional[str] = None,
|
||||||
output_path: Optional[str] = None,
|
output_path: Optional[str] = None,
|
||||||
skip_cleaning: bool = True,
|
skip_cleaning: bool = True,
|
||||||
indices: Optional[List[int]] = None
|
indices: Optional[List[int]] = None
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""数据预处理
|
"""数据预处理
|
||||||
|
|
||||||
@@ -1946,7 +1953,8 @@ def preprocess_data(
|
|||||||
)
|
)
|
||||||
preprocessor = DataPreprocessor()
|
preprocessor = DataPreprocessor()
|
||||||
try:
|
try:
|
||||||
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
|
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path,
|
||||||
|
skip_cleaning=skip_cleaning, indices=indices)
|
||||||
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
||||||
return cleaned_data
|
return cleaned_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1955,9 +1963,9 @@ def preprocess_data(
|
|||||||
|
|
||||||
|
|
||||||
async def get_chunked_dialogs_from_preprocessed(
|
async def get_chunked_dialogs_from_preprocessed(
|
||||||
data: List[DialogData],
|
data: List[DialogData],
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
llm_client: Optional[Any] = None,
|
llm_client: Optional[Any] = None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""从预处理后的数据中生成分块
|
"""从预处理后的数据中生成分块
|
||||||
|
|
||||||
@@ -1988,15 +1996,15 @@ async def get_chunked_dialogs_from_preprocessed(
|
|||||||
|
|
||||||
|
|
||||||
async def get_chunked_dialogs_with_preprocessing(
|
async def get_chunked_dialogs_with_preprocessing(
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
end_user_id: str = "default",
|
end_user_id: str = "default",
|
||||||
user_id: str = "default",
|
user_id: str = "default",
|
||||||
apply_id: str = "default",
|
apply_id: str = "default",
|
||||||
indices: Optional[List[int]] = None,
|
indices: Optional[List[int]] = None,
|
||||||
input_data_path: Optional[str] = None,
|
input_data_path: Optional[str] = None,
|
||||||
llm_client: Optional[Any] = None,
|
llm_client: Optional[Any] = None,
|
||||||
skip_cleaning: bool = True,
|
skip_cleaning: bool = True,
|
||||||
pruning_config: Optional[Dict] = None,
|
pruning_config: Optional[Dict] = None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""包含数据预处理步骤的完整分块流程
|
"""包含数据预处理步骤的完整分块流程
|
||||||
|
|
||||||
@@ -2046,7 +2054,8 @@ async def get_chunked_dialogs_with_preprocessing(
|
|||||||
if pruning_config:
|
if pruning_config:
|
||||||
# 使用传入的配置
|
# 使用传入的配置
|
||||||
config = PruningConfig(**pruning_config)
|
config = PruningConfig(**pruning_config)
|
||||||
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
logger.debug(
|
||||||
|
f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
||||||
else:
|
else:
|
||||||
# 使用默认配置(关闭剪枝)
|
# 使用默认配置(关闭剪枝)
|
||||||
config = None
|
config = None
|
||||||
|
|||||||
@@ -188,7 +188,6 @@ async def _process_chunk_summary(
|
|||||||
response_model=MemorySummaryResponse,
|
response_model=MemorySummaryResponse,
|
||||||
)
|
)
|
||||||
summary_text = structured.summary.strip()
|
summary_text = structured.summary.strip()
|
||||||
|
|
||||||
# Generate title and type for the summary
|
# Generate title and type for the summary
|
||||||
title = None
|
title = None
|
||||||
episodic_type = None
|
episodic_type = None
|
||||||
|
|||||||
@@ -374,7 +374,9 @@ class VariablePool:
|
|||||||
self.variables = deepcopy(pool.variables)
|
self.variables = deepcopy(pool.variables)
|
||||||
|
|
||||||
def is_file_variable(self, selector):
|
def is_file_variable(self, selector):
|
||||||
variable_struct = self._get_variable_struct(selector)
|
variable_struct = self.get_instance(selector, default=None, strict=False)
|
||||||
|
if variable_struct is None:
|
||||||
|
return False
|
||||||
if isinstance(variable_struct, FileVariable):
|
if isinstance(variable_struct, FileVariable):
|
||||||
return True
|
return True
|
||||||
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
|
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
|
||||||
|
|||||||
@@ -623,7 +623,6 @@ class BaseNode(ABC):
|
|||||||
async def process_message(
|
async def process_message(
|
||||||
api_config: ModelInfo,
|
api_config: ModelInfo,
|
||||||
content: str | dict | FileObject,
|
content: str | dict | FileObject,
|
||||||
end_user_id: str,
|
|
||||||
enable_file=False
|
enable_file=False
|
||||||
) -> list | str | None:
|
) -> list | str | None:
|
||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
@@ -642,8 +641,8 @@ class BaseNode(ABC):
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
elif isinstance(content, FileObject):
|
elif isinstance(content, FileObject):
|
||||||
if content.content_cache.get(provider):
|
if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
|
||||||
return content.content_cache[provider]
|
return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||||
file_obj = FileInput(
|
file_obj = FileInput(
|
||||||
@@ -655,12 +654,11 @@ class BaseNode(ABC):
|
|||||||
)
|
)
|
||||||
file_obj.set_content(content.get_content())
|
file_obj.set_content(content.get_content())
|
||||||
message = await multimodel_service.process_files(
|
message = await multimodel_service.process_files(
|
||||||
end_user_id,
|
|
||||||
[file_obj],
|
[file_obj],
|
||||||
)
|
)
|
||||||
content.set_content(file_obj.get_content())
|
content.set_content(file_obj.get_content())
|
||||||
if message:
|
if message:
|
||||||
content.content_cache[provider] = message
|
content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
|
||||||
return message
|
return message
|
||||||
return None
|
return None
|
||||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||||
|
|||||||
@@ -144,7 +144,6 @@ class LLMNode(BaseNode):
|
|||||||
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||||
|
|
||||||
messages_config = self.typed_config.messages
|
messages_config = self.typed_config.messages
|
||||||
|
|
||||||
if messages_config:
|
if messages_config:
|
||||||
# 使用 LangChain 消息格式
|
# 使用 LangChain 消息格式
|
||||||
messages = []
|
messages = []
|
||||||
@@ -153,7 +152,6 @@ class LLMNode(BaseNode):
|
|||||||
content_template = msg_config.content
|
content_template = msg_config.content
|
||||||
content_template = self._render_context(content_template, variable_pool)
|
content_template = self._render_context(content_template, variable_pool)
|
||||||
content = self._render_template(content_template, variable_pool)
|
content = self._render_template(content_template, variable_pool)
|
||||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
|
||||||
# 根据角色创建对应的消息对象
|
# 根据角色创建对应的消息对象
|
||||||
if role == "system":
|
if role == "system":
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -161,32 +159,31 @@ class LLMNode(BaseNode):
|
|||||||
"content": await self.process_message(
|
"content": await self.process_message(
|
||||||
model_info,
|
model_info,
|
||||||
content,
|
content,
|
||||||
user_id,
|
|
||||||
self.typed_config.vision,
|
self.typed_config.vision,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
elif role in ["user", "human"]:
|
elif role in ["user", "human"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
elif role in ["ai", "assistant"]:
|
elif role in ["ai", "assistant"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
|
|
||||||
if self.typed_config.vision_input and self.typed_config.vision:
|
if self.typed_config.vision_input and self.typed_config.vision:
|
||||||
file_content = []
|
file_content = []
|
||||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||||
for file in files.value:
|
for file in files.value:
|
||||||
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
content = await self.process_message(model_info, file.value, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.extend(content)
|
file_content.extend(content)
|
||||||
if messages and messages[-1]["role"] == 'user':
|
if messages and messages[-1]["role"] == 'user':
|
||||||
@@ -200,7 +197,7 @@ class LLMNode(BaseNode):
|
|||||||
if isinstance(message["content"], list):
|
if isinstance(message["content"], list):
|
||||||
file_content = []
|
file_content = []
|
||||||
for file in message["content"]:
|
for file in message["content"]:
|
||||||
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
content = await self.process_message(model_info, file, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.extend(content)
|
file_content.extend(content)
|
||||||
history_message.append(
|
history_message.append(
|
||||||
@@ -210,7 +207,6 @@ class LLMNode(BaseNode):
|
|||||||
message["content"] = await self.process_message(
|
message["content"] = await self.process_message(
|
||||||
model_info,
|
model_info,
|
||||||
message["content"],
|
message["content"],
|
||||||
user_id,
|
|
||||||
self.typed_config.vision
|
self.typed_config.vision
|
||||||
)
|
)
|
||||||
history_message.append(message)
|
history_message.append(message)
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ class MemoryWriteNode(BaseNode):
|
|||||||
write_message_task.delay(
|
write_message_task.delay(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
message=messages,
|
message=messages,
|
||||||
|
file_messages=multimodal_memories,
|
||||||
config_id=str(self.typed_config.config_id),
|
config_id=str(self.typed_config.config_id),
|
||||||
storage_type=state["memory_storage_type"],
|
storage_type=state["memory_storage_type"],
|
||||||
user_rag_memory_id=state["user_rag_memory_id"]
|
user_rag_memory_id=state["user_rag_memory_id"]
|
||||||
|
|||||||
@@ -30,6 +30,9 @@ class MemoryConfig(Base):
|
|||||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||||
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
||||||
|
vision_id = Column(String, nullable=True, comment="视觉模型配置ID")
|
||||||
|
audio_id = Column(String, nullable=True, comment="语音模型配置ID")
|
||||||
|
video_id = Column(String, nullable=True, comment="视频模型配置ID")
|
||||||
|
|
||||||
# 记忆萃取引擎配置
|
# 记忆萃取引擎配置
|
||||||
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
||||||
|
|||||||
@@ -9,21 +9,22 @@ Classes:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from uuid import UUID
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import desc, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_config_logger, get_db_logger
|
from app.core.logging_config import get_config_logger, get_db_logger
|
||||||
from app.models.memory_config_model import MemoryConfig
|
from app.models.memory_config_model import MemoryConfig
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
from app.schemas.memory_storage_schema import (
|
from app.schemas.memory_storage_schema import (
|
||||||
ConfigKey,
|
|
||||||
ConfigParamsCreate,
|
ConfigParamsCreate,
|
||||||
ConfigUpdate,
|
ConfigUpdate,
|
||||||
ConfigUpdateExtracted,
|
ConfigUpdateExtracted,
|
||||||
ConfigUpdateForget,
|
ConfigUpdateForget,
|
||||||
)
|
)
|
||||||
from sqlalchemy import desc, select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
# 获取数据库专用日志器
|
# 获取数据库专用日志器
|
||||||
@@ -157,7 +158,7 @@ class MemoryConfigRepository:
|
|||||||
return memory_config_obj
|
return memory_config_obj
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
|
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig:
|
||||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -491,7 +492,10 @@ class MemoryConfigRepository:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]:
|
def get_config_with_workspace(
|
||||||
|
db: Session,
|
||||||
|
config_id: uuid.UUID | int | str
|
||||||
|
) -> Optional[tuple[MemoryConfig, Workspace]]:
|
||||||
"""Get memory config and its associated workspace information
|
"""Get memory config and its associated workspace information
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -506,8 +510,6 @@ class MemoryConfigRepository:
|
|||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from app.models.workspace_model import Workspace
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
|
|
||||||
@@ -594,7 +596,7 @@ class MemoryConfigRepository:
|
|||||||
|
|
||||||
db_logger.debug(
|
db_logger.debug(
|
||||||
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
||||||
return (config, workspace)
|
return config, workspace
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# Re-raise known business exceptions
|
# Re-raise known business exceptions
|
||||||
@@ -739,9 +741,9 @@ class MemoryConfigRepository:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_with_fallback(
|
def get_with_fallback(
|
||||||
db: Session,
|
db: Session,
|
||||||
config_id: Optional[uuid.UUID],
|
config_id: Optional[uuid.UUID],
|
||||||
workspace_id: uuid.UUID
|
workspace_id: uuid.UUID
|
||||||
) -> Optional[MemoryConfig]:
|
) -> Optional[MemoryConfig]:
|
||||||
"""获取记忆配置,支持回退到工作空间默认配置
|
"""获取记忆配置,支持回退到工作空间默认配置
|
||||||
|
|
||||||
@@ -771,4 +773,3 @@ class MemoryConfigRepository:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
|
|
||||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||||
|
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
|
||||||
|
MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE
|
||||||
# 使用新的仓储层
|
# 使用新的仓储层
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
@@ -12,6 +13,7 @@ async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
|||||||
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||||
"""Add dialogue nodes to Neo4j database.
|
"""Add dialogue nodes to Neo4j database.
|
||||||
|
|
||||||
@@ -127,6 +129,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
|||||||
print(f"Error creating statement nodes: {e}")
|
print(f"Error creating statement nodes: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||||
"""Add chunk nodes to Neo4j in batch.
|
"""Add chunk nodes to Neo4j in batch.
|
||||||
|
|
||||||
@@ -179,8 +182,8 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[
|
||||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
List[str]]:
|
||||||
"""Add memory summary nodes to Neo4j in batch.
|
"""Add memory summary nodes to Neo4j in batch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -224,3 +227,103 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def add_perceptual_nodes(
|
||||||
|
perceptuals: list,
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
embedder_client=None,
|
||||||
|
) -> Optional[List[str]]:
|
||||||
|
"""Add perceptual memory nodes to Neo4j in batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
perceptuals: List of MemoryPerceptualModel objects from PostgreSQL
|
||||||
|
connector: Neo4j connector instance
|
||||||
|
embedder_client: Optional embedder client for generating summary embeddings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of created node UUIDs or None if failed
|
||||||
|
"""
|
||||||
|
if not perceptuals:
|
||||||
|
print("No perceptual nodes to add")
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
flattened = []
|
||||||
|
for p in perceptuals:
|
||||||
|
meta = p.meta_data or {}
|
||||||
|
content_meta = meta.get("content", {})
|
||||||
|
|
||||||
|
# 生成 summary embedding(如果有 embedder_client)
|
||||||
|
summary_embedding = None
|
||||||
|
if embedder_client and p.summary:
|
||||||
|
try:
|
||||||
|
summary_embedding = (await embedder_client.response([p.summary]))[0]
|
||||||
|
except Exception as emb_err:
|
||||||
|
print(f"Failed to embed perceptual summary: {emb_err}")
|
||||||
|
|
||||||
|
flattened.append({
|
||||||
|
"id": str(p.id),
|
||||||
|
"end_user_id": str(p.end_user_id),
|
||||||
|
"perceptual_type": p.perceptual_type,
|
||||||
|
"file_path": p.file_path or "",
|
||||||
|
"file_name": p.file_name or "",
|
||||||
|
"file_ext": p.file_ext or "",
|
||||||
|
"summary": p.summary or "",
|
||||||
|
"keywords": content_meta.get("keywords", []),
|
||||||
|
"topic": content_meta.get("topic", ""),
|
||||||
|
"domain": content_meta.get("domain", ""),
|
||||||
|
"created_at": p.created_time.isoformat() if p.created_time else None,
|
||||||
|
"summary_embedding": summary_embedding,
|
||||||
|
})
|
||||||
|
|
||||||
|
result = await connector.execute_query(
|
||||||
|
PERCEPTUAL_NODE_SAVE,
|
||||||
|
perceptuals=flattened,
|
||||||
|
)
|
||||||
|
created_uuids = [record.get("uuid") for record in result]
|
||||||
|
print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j")
|
||||||
|
return created_uuids
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to save Perceptual nodes to Neo4j: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def add_perceptual_dialogue_edges(
|
||||||
|
perceptuals: list,
|
||||||
|
dialog_id: str,
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
) -> Optional[List[str]]:
|
||||||
|
"""Add edges between Perceptual nodes and Dialogue nodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
perceptuals: List of MemoryPerceptualModel objects
|
||||||
|
dialog_id: The dialogue ID (or ref_id) to link to
|
||||||
|
connector: Neo4j connector instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of created edge element IDs or None if failed
|
||||||
|
"""
|
||||||
|
if not perceptuals or not dialog_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
edges = []
|
||||||
|
for p in perceptuals:
|
||||||
|
edges.append({
|
||||||
|
"perceptual_id": str(p.id),
|
||||||
|
"dialog_id": dialog_id,
|
||||||
|
"end_user_id": str(p.end_user_id),
|
||||||
|
"created_at": p.created_time.isoformat() if p.created_time else None,
|
||||||
|
})
|
||||||
|
|
||||||
|
result = await connector.execute_query(
|
||||||
|
PERCEPTUAL_DIALOGUE_EDGE_SAVE,
|
||||||
|
edges=edges,
|
||||||
|
)
|
||||||
|
created_ids = [record.get("uuid") for record in result]
|
||||||
|
print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j")
|
||||||
|
return created_ids
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to save Perceptual-Dialogue edges: {e}")
|
||||||
|
return None
|
||||||
|
|||||||
@@ -1323,3 +1323,36 @@ RETURN s.statement AS statement,
|
|||||||
ORDER BY COALESCE(s.activation_value, 0) DESC
|
ORDER BY COALESCE(s.activation_value, 0) DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# 感知记忆节点保存
|
||||||
|
PERCEPTUAL_NODE_SAVE = """
|
||||||
|
UNWIND $perceptuals AS p
|
||||||
|
MERGE (n:Perceptual {id: p.id})
|
||||||
|
SET n += {
|
||||||
|
id: p.id,
|
||||||
|
end_user_id: p.end_user_id,
|
||||||
|
perceptual_type: p.perceptual_type,
|
||||||
|
file_path: p.file_path,
|
||||||
|
file_name: p.file_name,
|
||||||
|
file_ext: p.file_ext,
|
||||||
|
summary: p.summary,
|
||||||
|
keywords: p.keywords,
|
||||||
|
topic: p.topic,
|
||||||
|
domain: p.domain,
|
||||||
|
created_at: p.created_at,
|
||||||
|
summary_embedding: p.summary_embedding
|
||||||
|
}
|
||||||
|
RETURN n.id AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 感知记忆与对话的关联边
|
||||||
|
PERCEPTUAL_DIALOGUE_EDGE_SAVE = """
|
||||||
|
UNWIND $edges AS edge
|
||||||
|
MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
|
||||||
|
MATCH (d:Dialogue {end_user_id: edge.end_user_id})
|
||||||
|
WHERE d.id = edge.dialog_id OR d.ref_id = edge.dialog_id
|
||||||
|
MERGE (d)-[r:HAS_PERCEPTUAL]->(p)
|
||||||
|
SET r.end_user_id = edge.end_user_id,
|
||||||
|
r.created_at = edge.created_at
|
||||||
|
RETURN elementId(r) AS uuid
|
||||||
|
"""
|
||||||
|
|||||||
@@ -387,6 +387,12 @@ class MemoryConfig:
|
|||||||
|
|
||||||
rerank_model_id: Optional[UUID] = None
|
rerank_model_id: Optional[UUID] = None
|
||||||
rerank_model_name: Optional[str] = None
|
rerank_model_name: Optional[str] = None
|
||||||
|
video_model_id: Optional[UUID] = None
|
||||||
|
video_model_name: Optional[str] = None
|
||||||
|
vision_model_id: Optional[UUID] = None
|
||||||
|
vision_model_name: Optional[str] = None
|
||||||
|
audio_model_id: Optional[UUID] = None
|
||||||
|
audio_model_name: Optional[str] = None
|
||||||
|
|
||||||
llm_params: Dict[str, Any] = field(default_factory=dict)
|
llm_params: Dict[str, Any] = field(default_factory=dict)
|
||||||
embedding_params: Dict[str, Any] = field(default_factory=dict)
|
embedding_params: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ class AppChatService:
|
|||||||
model_type=ModelType.LLM
|
model_type=ModelType.LLM
|
||||||
)
|
)
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
# 调用 Agent(支持多模态)
|
# 调用 Agent(支持多模态)
|
||||||
@@ -339,7 +339,7 @@ class AppChatService:
|
|||||||
model_type=ModelType.LLM
|
model_type=ModelType.LLM
|
||||||
)
|
)
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||||
|
|||||||
@@ -600,7 +600,7 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
provider = api_key_config.get("provider", "openai")
|
provider = api_key_config.get("provider", "openai")
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||||
|
|
||||||
# 7. 知识库检索
|
# 7. 知识库检索
|
||||||
@@ -836,7 +836,7 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
provider = api_key_config.get("provider", "openai")
|
provider = api_key_config.get("provider", "openai")
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||||
|
|
||||||
# 7. 知识库检索
|
# 7. 知识库检索
|
||||||
|
|||||||
@@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.cache import InterestMemoryCache
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logging_config import get_config_logger, get_logger
|
from app.core.logging_config import get_config_logger, get_logger
|
||||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
|
||||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||||
from app.core.memory.agent.utils.messages_tools import (
|
from app.core.memory.agent.utils.messages_tools import (
|
||||||
merge_multiple_search_results,
|
merge_multiple_search_results,
|
||||||
reorder_output_results,
|
reorder_output_results,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||||
|
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
from app.schemas import FileInput
|
||||||
from app.schemas.memory_agent_schema import Write_UserInput
|
from app.schemas.memory_agent_schema import Write_UserInput
|
||||||
from app.schemas.memory_config_schema import ConfigurationError
|
from app.schemas.memory_config_schema import ConfigurationError
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_konwledges_server import (
|
from app.services.memory_konwledges_server import (
|
||||||
write_rag,
|
write_rag,
|
||||||
)
|
)
|
||||||
|
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||||
@@ -271,6 +274,7 @@ class MemoryAgentService:
|
|||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
messages: list[dict],
|
messages: list[dict],
|
||||||
|
file_messages: list[dict],
|
||||||
config_id: Optional[uuid.UUID] | int,
|
config_id: Optional[uuid.UUID] | int,
|
||||||
db: Session,
|
db: Session,
|
||||||
storage_type: str,
|
storage_type: str,
|
||||||
@@ -283,6 +287,7 @@ class MemoryAgentService:
|
|||||||
Args:
|
Args:
|
||||||
end_user_id: Group identifier (also used as end_user_id)
|
end_user_id: Group identifier (also used as end_user_id)
|
||||||
messages: Message to write
|
messages: Message to write
|
||||||
|
files: Files to write
|
||||||
config_id: Configuration ID from database
|
config_id: Configuration ID from database
|
||||||
db: SQLAlchemy database session
|
db: SQLAlchemy database session
|
||||||
storage_type: Storage type (neo4j or rag)
|
storage_type: Storage type (neo4j or rag)
|
||||||
@@ -342,48 +347,52 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
perceptual_serivce = MemoryPerceptualService(db)
|
||||||
|
file_content = []
|
||||||
|
for message in file_messages:
|
||||||
|
for file in message["files"]:
|
||||||
|
file_object = await perceptual_serivce.generate_perceptual_memory(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
memory_config=memory_config,
|
||||||
|
file=FileInput(**file)
|
||||||
|
)
|
||||||
|
file_content.append(file_object)
|
||||||
|
|
||||||
|
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
try:
|
try:
|
||||||
if storage_type == "rag":
|
if storage_type == "rag":
|
||||||
# For RAG storage, convert messages to single string
|
# For RAG storage, convert messages to single string
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
|
||||||
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||||
return "success"
|
return "success"
|
||||||
else:
|
else:
|
||||||
async with make_write_graph() as graph:
|
await write_neo4j(
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
end_user_id=end_user_id,
|
||||||
# Convert structured messages to LangChain messages
|
messages=messages,
|
||||||
langchain_messages = []
|
file_content=file_content,
|
||||||
for msg in messages:
|
memory_config=memory_config,
|
||||||
if msg['role'] == 'user':
|
ref_id='',
|
||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
language=language
|
||||||
elif msg['role'] == 'assistant':
|
)
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
for lang in ["zh", "en"]:
|
||||||
print(100 * '-')
|
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||||
print(langchain_messages)
|
end_user_id, lang
|
||||||
print(100 * '-')
|
)
|
||||||
# 初始状态 - 包含所有必要字段
|
if deleted:
|
||||||
initial_state = {
|
logger.info(
|
||||||
"messages": langchain_messages,
|
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||||
"end_user_id": end_user_id,
|
return self.writer_messages_deal(
|
||||||
"memory_config": memory_config,
|
"success",
|
||||||
"language": language
|
start_time,
|
||||||
}
|
end_user_id,
|
||||||
|
config_id,
|
||||||
# 获取节点更新信息
|
message_text,
|
||||||
async for update_event in graph.astream(
|
{
|
||||||
initial_state,
|
"status": "success",
|
||||||
stream_mode="updates",
|
"data": messages,
|
||||||
config=config
|
"config_id": memory_config.config_id,
|
||||||
):
|
"config_name": memory_config.config_name
|
||||||
for node_name, node_data in update_event.items():
|
}
|
||||||
if 'save_neo4j' == node_name:
|
)
|
||||||
massages = node_data
|
|
||||||
massagesstatus = massages.get('write_result')['status']
|
|
||||||
contents = massages.get('write_result')
|
|
||||||
# Convert messages back to string for logging
|
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
|
||||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
|
|
||||||
contents)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Ensure proper error handling and logging
|
# Ensure proper error handling and logging
|
||||||
error_msg = f"Write operation failed: {str(e)}"
|
error_msg = f"Write operation failed: {str(e)}"
|
||||||
|
|||||||
@@ -38,9 +38,9 @@ class MemoryAPIService:
|
|||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
def validate_end_user(
|
def validate_end_user(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
workspace_id: uuid.UUID
|
workspace_id: uuid.UUID
|
||||||
) -> EndUser:
|
) -> EndUser:
|
||||||
"""Validate that end_user exists and belongs to the workspace.
|
"""Validate that end_user exists and belongs to the workspace.
|
||||||
|
|
||||||
@@ -125,13 +125,14 @@ class MemoryAPIService:
|
|||||||
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
||||||
|
|
||||||
async def write_memory(
|
async def write_memory(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
config_id: str,
|
config_id: str,
|
||||||
storage_type: str = "neo4j",
|
storage_type: str = "neo4j",
|
||||||
user_rag_memory_id: Optional[str] = None,
|
files: Optional[list]=None,
|
||||||
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Write memory with validation.
|
"""Write memory with validation.
|
||||||
|
|
||||||
@@ -153,6 +154,8 @@ class MemoryAPIService:
|
|||||||
ResourceNotFoundException: If end_user not found
|
ResourceNotFoundException: If end_user not found
|
||||||
BusinessException: If end_user not in authorized workspace or write fails
|
BusinessException: If end_user not in authorized workspace or write fails
|
||||||
"""
|
"""
|
||||||
|
if files is None:
|
||||||
|
files = list()
|
||||||
logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}")
|
logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}")
|
||||||
|
|
||||||
# Validate end_user exists and belongs to workspace
|
# Validate end_user exists and belongs to workspace
|
||||||
@@ -171,7 +174,8 @@ class MemoryAPIService:
|
|||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
db=self.db,
|
db=self.db,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id or ""
|
user_rag_memory_id=user_rag_memory_id or "",
|
||||||
|
files=files
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
||||||
@@ -206,14 +210,14 @@ class MemoryAPIService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def read_memory(
|
async def read_memory(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
search_switch: str = "0",
|
search_switch: str = "0",
|
||||||
config_id: str = "",
|
config_id: str = "",
|
||||||
storage_type: str = "neo4j",
|
storage_type: str = "neo4j",
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Read memory with validation.
|
"""Read memory with validation.
|
||||||
|
|
||||||
@@ -244,7 +248,6 @@ class MemoryAPIService:
|
|||||||
# Update end user's memory_config_id
|
# Update end user's memory_config_id
|
||||||
self._update_end_user_config(end_user_id, config_id)
|
self._update_end_user_config(end_user_id, config_id)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Delegate to MemoryAgentService
|
# Delegate to MemoryAgentService
|
||||||
result = await MemoryAgentService().read_memory(
|
result = await MemoryAgentService().read_memory(
|
||||||
@@ -282,8 +285,8 @@ class MemoryAPIService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def list_memory_configs(
|
def list_memory_configs(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""List all memory configs for a workspace.
|
"""List all memory configs for a workspace.
|
||||||
|
|
||||||
|
|||||||
@@ -154,10 +154,10 @@ class MemoryConfigService:
|
|||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
def load_memory_config(
|
def load_memory_config(
|
||||||
self,
|
self,
|
||||||
config_id: Optional[UUID] = None,
|
config_id: Optional[UUID] = None,
|
||||||
workspace_id: Optional[UUID] = None,
|
workspace_id: Optional[UUID] = None,
|
||||||
service_name: str = "MemoryConfigService",
|
service_name: str = "MemoryConfigService",
|
||||||
) -> MemoryConfig:
|
) -> MemoryConfig:
|
||||||
"""
|
"""
|
||||||
Load memory configuration from database with optional fallback.
|
Load memory configuration from database with optional fallback.
|
||||||
@@ -194,8 +194,8 @@ class MemoryConfigService:
|
|||||||
try:
|
try:
|
||||||
# Use get_config_with_fallback if workspace_id is provided
|
# Use get_config_with_fallback if workspace_id is provided
|
||||||
memory_config = None
|
memory_config = None
|
||||||
|
validated_config_id = None
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
validated_config_id = None
|
|
||||||
if config_id:
|
if config_id:
|
||||||
try:
|
try:
|
||||||
validated_config_id = _validate_config_id(config_id, self.db)
|
validated_config_id = _validate_config_id(config_id, self.db)
|
||||||
@@ -243,10 +243,10 @@ class MemoryConfigService:
|
|||||||
|
|
||||||
# Helper function to validate model with workspace fallback
|
# Helper function to validate model with workspace fallback
|
||||||
def _validate_model_with_fallback(
|
def _validate_model_with_fallback(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
model_type: str,
|
model_type: str,
|
||||||
workspace_default: str,
|
workspace_default: str,
|
||||||
required: bool = False
|
required: bool = False
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""Validate model ID, falling back to workspace default if invalid.
|
"""Validate model ID, falling back to workspace default if invalid.
|
||||||
|
|
||||||
@@ -343,6 +343,35 @@ class MemoryConfigService:
|
|||||||
if memory_config.rerank_id or workspace.rerank:
|
if memory_config.rerank_id or workspace.rerank:
|
||||||
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
||||||
|
|
||||||
|
vision_uuid, vision_name = validate_and_resolve_model_id(
|
||||||
|
memory_config.vision_id,
|
||||||
|
"llm",
|
||||||
|
self.db,
|
||||||
|
workspace.tenant_id,
|
||||||
|
required=False,
|
||||||
|
config_id=validated_config_id,
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_uuid, audio_name = validate_and_resolve_model_id(
|
||||||
|
memory_config.audio_id,
|
||||||
|
"llm",
|
||||||
|
self.db,
|
||||||
|
workspace.tenant_id,
|
||||||
|
required=False,
|
||||||
|
config_id=validated_config_id,
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
video_uuid, video_name = validate_and_resolve_model_id(
|
||||||
|
memory_config.video_id,
|
||||||
|
"llm",
|
||||||
|
self.db,
|
||||||
|
workspace.tenant_id,
|
||||||
|
required=False,
|
||||||
|
config_id=validated_config_id,
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
)
|
||||||
# Create immutable MemoryConfig object
|
# Create immutable MemoryConfig object
|
||||||
config = MemoryConfig(
|
config = MemoryConfig(
|
||||||
config_id=memory_config.config_id,
|
config_id=memory_config.config_id,
|
||||||
@@ -356,6 +385,12 @@ class MemoryConfigService:
|
|||||||
embedding_model_name=embedding_name,
|
embedding_model_name=embedding_name,
|
||||||
rerank_model_id=rerank_uuid,
|
rerank_model_id=rerank_uuid,
|
||||||
rerank_model_name=rerank_name,
|
rerank_model_name=rerank_name,
|
||||||
|
video_model_id=video_uuid,
|
||||||
|
video_model_name=video_name,
|
||||||
|
vision_model_id=vision_uuid,
|
||||||
|
vision_model_name=vision_name,
|
||||||
|
audio_model_id=audio_uuid,
|
||||||
|
audio_model_name=audio_name,
|
||||||
storage_type=workspace.storage_type or "neo4j",
|
storage_type=workspace.storage_type or "neo4j",
|
||||||
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
||||||
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
||||||
@@ -364,24 +399,31 @@ class MemoryConfigService:
|
|||||||
reflexion_baseline=memory_config.baseline or "Time",
|
reflexion_baseline=memory_config.baseline or "Time",
|
||||||
loaded_at=datetime.now(),
|
loaded_at=datetime.now(),
|
||||||
# Pipeline config: Deduplication
|
# Pipeline config: Deduplication
|
||||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
enable_llm_dedup_blockwise=bool(
|
||||||
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||||
|
enable_llm_disambiguation=bool(
|
||||||
|
memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||||
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
||||||
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
||||||
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
||||||
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
||||||
# Pipeline config: Statement extraction
|
# Pipeline config: Statement extraction
|
||||||
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
statement_granularity=int(
|
||||||
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||||
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
|
include_dialogue_context=bool(
|
||||||
|
memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||||
|
max_dialogue_context_chars=int(
|
||||||
|
memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||||
# Pipeline config: Forgetting engine
|
# Pipeline config: Forgetting engine
|
||||||
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
||||||
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
||||||
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
||||||
# Pipeline config: Pruning
|
# Pipeline config: Pruning
|
||||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
pruning_enabled=bool(
|
||||||
|
memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||||
pruning_scene=memory_config.pruning_scene or "education",
|
pruning_scene=memory_config.pruning_scene or "education",
|
||||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
pruning_threshold=float(
|
||||||
|
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||||
# Ontology scene association
|
# Ontology scene association
|
||||||
scene_id=memory_config.scene_id,
|
scene_id=memory_config.scene_id,
|
||||||
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
||||||
@@ -598,8 +640,8 @@ class MemoryConfigService:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_workspace_default_config(
|
def get_workspace_default_config(
|
||||||
self,
|
self,
|
||||||
workspace_id: UUID
|
workspace_id: UUID
|
||||||
) -> Optional["MemoryConfigModel"]:
|
) -> Optional["MemoryConfigModel"]:
|
||||||
"""Get workspace default memory config.
|
"""Get workspace default memory config.
|
||||||
|
|
||||||
@@ -623,9 +665,9 @@ class MemoryConfigService:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
def get_config_with_fallback(
|
def get_config_with_fallback(
|
||||||
self,
|
self,
|
||||||
memory_config_id: Optional[UUID],
|
memory_config_id: Optional[UUID],
|
||||||
workspace_id: UUID
|
workspace_id: UUID
|
||||||
) -> Optional["MemoryConfigModel"]:
|
) -> Optional["MemoryConfigModel"]:
|
||||||
"""Get memory config with fallback to workspace default.
|
"""Get memory config with fallback to workspace default.
|
||||||
|
|
||||||
@@ -663,9 +705,9 @@ class MemoryConfigService:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
def delete_config(
|
def delete_config(
|
||||||
self,
|
self,
|
||||||
config_id: UUID | int,
|
config_id: UUID | int,
|
||||||
force: bool = False
|
force: bool = False
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Delete memory config with protection against in-use configs.
|
"""Delete memory config with protection against in-use configs.
|
||||||
|
|
||||||
@@ -800,9 +842,9 @@ class MemoryConfigService:
|
|||||||
# ==================== 记忆配置提取方法 ====================
|
# ==================== 记忆配置提取方法 ====================
|
||||||
|
|
||||||
def extract_memory_config_id(
|
def extract_memory_config_id(
|
||||||
self,
|
self,
|
||||||
app_type: str,
|
app_type: str,
|
||||||
config: dict
|
config: dict
|
||||||
) -> tuple[Optional[uuid.UUID], bool]:
|
) -> tuple[Optional[uuid.UUID], bool]:
|
||||||
"""从发布配置中提取 memory_config_id(根据应用类型分发)
|
"""从发布配置中提取 memory_config_id(根据应用类型分发)
|
||||||
|
|
||||||
@@ -828,8 +870,8 @@ class MemoryConfigService:
|
|||||||
return None, False
|
return None, False
|
||||||
|
|
||||||
def _extract_memory_config_id_from_agent(
|
def _extract_memory_config_id_from_agent(
|
||||||
self,
|
self,
|
||||||
config: dict
|
config: dict
|
||||||
) -> tuple[Optional[uuid.UUID], bool]:
|
) -> tuple[Optional[uuid.UUID], bool]:
|
||||||
"""从 Agent 应用配置中提取 memory_config_id
|
"""从 Agent 应用配置中提取 memory_config_id
|
||||||
|
|
||||||
@@ -888,8 +930,8 @@ class MemoryConfigService:
|
|||||||
return None, False
|
return None, False
|
||||||
|
|
||||||
def _extract_memory_config_id_from_workflow(
|
def _extract_memory_config_id_from_workflow(
|
||||||
self,
|
self,
|
||||||
config: dict
|
config: dict
|
||||||
) -> tuple[Optional[uuid.UUID], bool]:
|
) -> tuple[Optional[uuid.UUID], bool]:
|
||||||
"""从 Workflow 应用配置中提取 memory_config_id
|
"""从 Workflow 应用配置中提取 memory_config_id
|
||||||
|
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ from app.core.error_codes import BizCode
|
|||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
from app.models import FileMetadata
|
from app.models import FileMetadata, ModelApiKey, ModelType
|
||||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||||
from app.models.prompt_optimizer_model import RoleType
|
from app.models.prompt_optimizer_model import RoleType
|
||||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||||
from app.schemas import FileType
|
from app.schemas import FileType, FileInput
|
||||||
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
from app.schemas.memory_perceptual_schema import (
|
from app.schemas.memory_perceptual_schema import (
|
||||||
PerceptualQuerySchema,
|
PerceptualQuerySchema,
|
||||||
PerceptualTimelineResponse,
|
PerceptualTimelineResponse,
|
||||||
@@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import (
|
|||||||
AudioModal, Content, VideoModal, TextModal
|
AudioModal, Content, VideoModal, TextModal
|
||||||
)
|
)
|
||||||
from app.schemas.model_schema import ModelInfo
|
from app.schemas.model_schema import ModelInfo
|
||||||
|
from app.services.model_service import ModelApiKeyService
|
||||||
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
|
|
||||||
@@ -195,21 +198,58 @@ class MemoryPerceptualService:
|
|||||||
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
||||||
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
||||||
|
|
||||||
|
def _get_mutlimodal_client(
|
||||||
|
self,
|
||||||
|
file_type: FileType,
|
||||||
|
config: MemoryConfig
|
||||||
|
) -> tuple[RedBearLLM | None, ModelApiKey | None]:
|
||||||
|
model_config = None
|
||||||
|
if file_type == FileType.AUDIO:
|
||||||
|
model_config = ModelApiKeyService.get_available_api_key(
|
||||||
|
self.db,
|
||||||
|
config.audio_model_id
|
||||||
|
)
|
||||||
|
elif file_type == FileType.VIDEO:
|
||||||
|
model_config = ModelApiKeyService.get_available_api_key(
|
||||||
|
self.db,
|
||||||
|
config.video_model_id
|
||||||
|
)
|
||||||
|
elif file_type == FileType.DOCUMENT:
|
||||||
|
model_config = ModelApiKeyService.get_available_api_key(
|
||||||
|
self.db,
|
||||||
|
config.llm_model_id
|
||||||
|
)
|
||||||
|
elif file_type == FileType.IMAGE:
|
||||||
|
model_config = ModelApiKeyService.get_available_api_key(
|
||||||
|
self.db,
|
||||||
|
config.vision_model_id
|
||||||
|
)
|
||||||
|
llm = None
|
||||||
|
if model_config:
|
||||||
|
llm = RedBearLLM(
|
||||||
|
RedBearModelConfig(
|
||||||
|
model_name=model_config.model_name,
|
||||||
|
provider=model_config.provider,
|
||||||
|
api_key=model_config.api_key,
|
||||||
|
base_url=model_config.api_base,
|
||||||
|
is_omni=model_config.is_omni
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return llm, model_config
|
||||||
|
|
||||||
async def generate_perceptual_memory(
|
async def generate_perceptual_memory(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
model_config: ModelInfo,
|
memory_config: MemoryConfig,
|
||||||
file_type: str,
|
file: FileInput
|
||||||
file_url: str,
|
|
||||||
file_message: dict,
|
|
||||||
):
|
):
|
||||||
memories = self.repository.get_by_url(file_url)
|
memories = self.repository.get_by_url(file.url)
|
||||||
if memories:
|
if memories:
|
||||||
business_logger.info(f"Perceptual memory already exists: {file_url}")
|
business_logger.info(f"Perceptual memory already exists: {file.url}")
|
||||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||||
memory_cache = memories[0]
|
memory_cache = memories[0]
|
||||||
self.repository.create_perceptual_memory(
|
memory = self.repository.create_perceptual_memory(
|
||||||
end_user_id=uuid.UUID(end_user_id),
|
end_user_id=uuid.UUID(end_user_id),
|
||||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||||
file_path=memory_cache.file_path,
|
file_path=memory_cache.file_path,
|
||||||
@@ -219,20 +259,31 @@ class MemoryPerceptualService:
|
|||||||
meta_data=memory_cache.meta_data
|
meta_data=memory_cache.meta_data
|
||||||
)
|
)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
return memory
|
||||||
return
|
else:
|
||||||
llm = RedBearLLM(RedBearModelConfig(
|
for memory in memories:
|
||||||
|
if memory.end_user_id == uuid.UUID(end_user_id):
|
||||||
|
return memory
|
||||||
|
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
||||||
|
multimodel_service = MultimodalService(self.db, ModelInfo(
|
||||||
model_name=model_config.model_name,
|
model_name=model_config.model_name,
|
||||||
provider=model_config.provider,
|
provider=model_config.provider,
|
||||||
api_key=model_config.api_key,
|
api_key=model_config.api_key,
|
||||||
base_url=model_config.api_base,
|
api_base=model_config.api_base,
|
||||||
is_omni=model_config.is_omni
|
is_omni=model_config.is_omni,
|
||||||
), type=model_config.model_type)
|
capability=model_config.capability,
|
||||||
|
model_type=ModelType.LLM
|
||||||
|
))
|
||||||
|
file_message = await multimodel_service.process_files(
|
||||||
|
files=[file]
|
||||||
|
)
|
||||||
|
if file_message:
|
||||||
|
file_message = file_message[0]
|
||||||
try:
|
try:
|
||||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||||
opt_system_prompt = f.read()
|
opt_system_prompt = f.read()
|
||||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
|
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||||
messages = [
|
messages = [
|
||||||
@@ -242,8 +293,22 @@ class MemoryPerceptualService:
|
|||||||
]}
|
]}
|
||||||
]
|
]
|
||||||
result = await llm.ainvoke(messages)
|
result = await llm.ainvoke(messages)
|
||||||
content = json_repair.repair_json(result.content, return_objects=True)
|
content = result.content
|
||||||
path = urlparse(file_url).path
|
final_output = ""
|
||||||
|
if isinstance(content, list):
|
||||||
|
for msg in content:
|
||||||
|
if isinstance(msg, dict):
|
||||||
|
final_output += msg.get("text", "")
|
||||||
|
elif isinstance(msg, str):
|
||||||
|
final_output += msg
|
||||||
|
elif isinstance(content, dict):
|
||||||
|
final_output += content.get("text", "")
|
||||||
|
elif isinstance(content, str):
|
||||||
|
final_output = content
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexcept Model Output Type: {result.content}")
|
||||||
|
content = json_repair.repair_json(final_output, return_objects=True)
|
||||||
|
path = urlparse(file.url).path
|
||||||
filename = os.path.basename(path)
|
filename = os.path.basename(path)
|
||||||
filename = unquote(filename)
|
filename = unquote(filename)
|
||||||
file_ext = os.path.splitext(filename)[1]
|
file_ext = os.path.splitext(filename)[1]
|
||||||
@@ -260,13 +325,13 @@ class MemoryPerceptualService:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
business_logger.debug(f"Remote file, file_id={filename}")
|
business_logger.debug(f"Remote file, file_id={filename}")
|
||||||
if not file_ext:
|
if not file_ext:
|
||||||
if file_type == FileType.AUDIO:
|
if file.type == FileType.AUDIO:
|
||||||
file_ext = ".mp3"
|
file_ext = ".mp3"
|
||||||
elif file_type == FileType.VIDEO:
|
elif file.type == FileType.VIDEO:
|
||||||
file_ext = ".mp4"
|
file_ext = ".mp4"
|
||||||
elif file_type == FileType.DOCUMENT:
|
elif file.type == FileType.DOCUMENT:
|
||||||
file_ext = ".txt"
|
file_ext = ".txt"
|
||||||
elif file_type == FileType.IMAGE:
|
elif file.type == FileType.IMAGE:
|
||||||
file_ext = ".jpg"
|
file_ext = ".jpg"
|
||||||
filename += file_ext
|
filename += file_ext
|
||||||
file_content = {
|
file_content = {
|
||||||
@@ -274,11 +339,11 @@ class MemoryPerceptualService:
|
|||||||
"topic": content.get("topic"),
|
"topic": content.get("topic"),
|
||||||
"domain": content.get("domain")
|
"domain": content.get("domain")
|
||||||
}
|
}
|
||||||
if file_type in [FileType.IMAGE, FileType.VIDEO]:
|
if file.type in [FileType.IMAGE, FileType.VIDEO]:
|
||||||
file_modalities = {
|
file_modalities = {
|
||||||
"scene": content.get("scene", [])
|
"scene": content.get("scene", [])
|
||||||
}
|
}
|
||||||
elif file_type in [FileType.DOCUMENT]:
|
elif file.type in [FileType.DOCUMENT]:
|
||||||
file_modalities = {
|
file_modalities = {
|
||||||
"section_count": content.get("section_count", 0),
|
"section_count": content.get("section_count", 0),
|
||||||
"title": content.get("title", ""),
|
"title": content.get("title", ""),
|
||||||
@@ -288,10 +353,10 @@ class MemoryPerceptualService:
|
|||||||
file_modalities = {
|
file_modalities = {
|
||||||
"speaker_count": content.get("speaker_count", 0)
|
"speaker_count": content.get("speaker_count", 0)
|
||||||
}
|
}
|
||||||
self.repository.create_perceptual_memory(
|
memory = self.repository.create_perceptual_memory(
|
||||||
end_user_id=uuid.UUID(end_user_id),
|
end_user_id=uuid.UUID(end_user_id),
|
||||||
perceptual_type=PerceptualType.trans_from_file_type(file_type),
|
perceptual_type=PerceptualType.trans_from_file_type(file.type),
|
||||||
file_path=file_url,
|
file_path=file.url,
|
||||||
file_name=filename,
|
file_name=filename,
|
||||||
file_ext=file_ext,
|
file_ext=file_ext,
|
||||||
summary=content.get('summary', ""),
|
summary=content.get('summary', ""),
|
||||||
@@ -301,3 +366,4 @@ class MemoryPerceptualService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
return memory
|
||||||
|
|||||||
@@ -9,14 +9,12 @@
|
|||||||
- OpenAI: 支持 URL 和 base64 格式
|
- OpenAI: 支持 URL 和 base64 格式
|
||||||
"""
|
"""
|
||||||
import base64
|
import base64
|
||||||
|
import csv
|
||||||
import io
|
import io
|
||||||
import uuid
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
import csv
|
|
||||||
import json
|
|
||||||
|
|
||||||
import PyPDF2
|
import PyPDF2
|
||||||
import httpx
|
import httpx
|
||||||
import magic
|
import magic
|
||||||
@@ -33,7 +31,6 @@ from app.models.file_metadata_model import FileMetadata
|
|||||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||||
from app.schemas.model_schema import ModelInfo
|
from app.schemas.model_schema import ModelInfo
|
||||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||||
from app.tasks import write_perceptual_memory
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -342,15 +339,12 @@ class MultimodalService:
|
|||||||
|
|
||||||
async def process_files(
|
async def process_files(
|
||||||
self,
|
self,
|
||||||
end_user_id: uuid.UUID | str,
|
|
||||||
files: Optional[List[FileInput]],
|
files: Optional[List[FileInput]],
|
||||||
|
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
处理文件列表,返回 LLM 可用的格式
|
处理文件列表,返回 LLM 可用的格式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 用户ID
|
|
||||||
files: 文件输入列表
|
files: 文件输入列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -358,8 +352,6 @@ class MultimodalService:
|
|||||||
"""
|
"""
|
||||||
if not files:
|
if not files:
|
||||||
return []
|
return []
|
||||||
if isinstance(end_user_id, uuid.UUID):
|
|
||||||
end_user_id = str(end_user_id)
|
|
||||||
|
|
||||||
# 获取对应的策略
|
# 获取对应的策略
|
||||||
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
||||||
@@ -380,23 +372,15 @@ class MultimodalService:
|
|||||||
if file.type == FileType.IMAGE and "vision" in self.capability:
|
if file.type == FileType.IMAGE and "vision" in self.capability:
|
||||||
is_support, content = await self._process_image(file, strategy)
|
is_support, content = await self._process_image(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if is_support:
|
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
|
||||||
elif file.type == FileType.DOCUMENT:
|
elif file.type == FileType.DOCUMENT:
|
||||||
is_support, content = await self._process_document(file, strategy)
|
is_support, content = await self._process_document(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if is_support:
|
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
|
||||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||||
is_support, content = await self._process_audio(file, strategy)
|
is_support, content = await self._process_audio(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if is_support:
|
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
|
||||||
elif file.type == FileType.VIDEO and "video" in self.capability:
|
elif file.type == FileType.VIDEO and "video" in self.capability:
|
||||||
is_support, content = await self._process_video(file, strategy)
|
is_support, content = await self._process_video(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if is_support:
|
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"不支持的文件类型: {file.type}")
|
logger.warning(f"不支持的文件类型: {file.type}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -418,17 +402,6 @@ class MultimodalService:
|
|||||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def write_perceptual_memory(
|
|
||||||
self,
|
|
||||||
end_user_id: str,
|
|
||||||
file_type: str,
|
|
||||||
file_url: str,
|
|
||||||
file_message: dict
|
|
||||||
):
|
|
||||||
"""写入感知记忆"""
|
|
||||||
if end_user_id and self.api_config:
|
|
||||||
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
|
|
||||||
|
|
||||||
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
处理图片文件
|
处理图片文件
|
||||||
|
|||||||
@@ -1080,12 +1080,14 @@ def write_message_task(
|
|||||||
config_id: str | int,
|
config_id: str | int,
|
||||||
storage_type: str,
|
storage_type: str,
|
||||||
user_rag_memory_id: str,
|
user_rag_memory_id: str,
|
||||||
|
file_messages: list[dict] | None,
|
||||||
language: str = "zh"
|
language: str = "zh"
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Celery task to process a write message via MemoryAgentService.
|
"""Celery task to process a write message via MemoryAgentService.
|
||||||
Args:
|
Args:
|
||||||
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
||||||
message: Message to write
|
message: Message to write
|
||||||
|
file_messages: Files to write
|
||||||
config_id: Configuration ID (can be UUID string, integer, or config_id_old)
|
config_id: Configuration ID (can be UUID string, integer, or config_id_old)
|
||||||
storage_type: Storage type (neo4j or rag)
|
storage_type: Storage type (neo4j or rag)
|
||||||
user_rag_memory_id: User RAG memory ID
|
user_rag_memory_id: User RAG memory ID
|
||||||
@@ -1097,6 +1099,8 @@ def write_message_task(
|
|||||||
Raises:
|
Raises:
|
||||||
Exception on failure
|
Exception on failure
|
||||||
"""
|
"""
|
||||||
|
if file_messages is None:
|
||||||
|
file_messages = []
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
||||||
@@ -1142,7 +1146,7 @@ def write_message_task(
|
|||||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
||||||
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||||
service = MemoryAgentService()
|
service = MemoryAgentService()
|
||||||
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
|
result = await service.write_memory(end_user_id, message, file_messages, actual_config_id, db, storage_type,
|
||||||
user_rag_memory_id, language)
|
user_rag_memory_id, language)
|
||||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||||
return result
|
return result
|
||||||
|
|||||||
Reference in New Issue
Block a user