feat(memory): support perception-aware memory writing in workflow and Neo4j nodes
This commit is contained in:
@@ -5,6 +5,7 @@ This module provides the main write function for executing the knowledge extract
|
|||||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import uuid
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -13,28 +14,31 @@ from dotenv import load_dotenv
|
|||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||||
|
memory_summary_generation
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.core.memory.utils.log.logging_utils import log_time
|
from app.core.memory.utils.log.logging_utils import log_time
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
|
from app.models import MemoryPerceptualModel
|
||||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \
|
||||||
|
add_perceptual_dialogue_edges
|
||||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def write(
|
async def write(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
memory_config: MemoryConfig,
|
memory_config: MemoryConfig,
|
||||||
messages: list,
|
messages: list,
|
||||||
ref_id: str = "wyl20251027",
|
file_content: list[MemoryPerceptualModel],
|
||||||
language: str = "zh",
|
ref_id: str = "",
|
||||||
|
language: str = "zh",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Execute the complete knowledge extraction pipeline.
|
Execute the complete knowledge extraction pipeline.
|
||||||
@@ -43,9 +47,12 @@ async def write(
|
|||||||
end_user_id: Group identifier
|
end_user_id: Group identifier
|
||||||
memory_config: MemoryConfig object containing all configuration
|
memory_config: MemoryConfig object containing all configuration
|
||||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||||
ref_id: Reference ID, defaults to "wyl20251027"
|
file_content: List of multimodal (perceptual) memory records (MemoryPerceptualModel)
|
||||||
|
ref_id: Reference ID; when empty (the default), a random UUID hex string is generated
|
||||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||||
"""
|
"""
|
||||||
|
if not ref_id:
|
||||||
|
ref_id = uuid.uuid4().hex
|
||||||
# Extract config values
|
# Extract config values
|
||||||
embedding_model_id = str(memory_config.embedding_model_id)
|
embedding_model_id = str(memory_config.embedding_model_id)
|
||||||
chunker_strategy = memory_config.chunker_strategy
|
chunker_strategy = memory_config.chunker_strategy
|
||||||
@@ -99,14 +106,14 @@ async def write(
|
|||||||
if memory_config.scene_id:
|
if memory_config.scene_id:
|
||||||
try:
|
try:
|
||||||
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
|
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
|
||||||
|
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
ontology_types = load_ontology_types_for_scene(
|
ontology_types = load_ontology_types_for_scene(
|
||||||
scene_id=memory_config.scene_id,
|
scene_id=memory_config.scene_id,
|
||||||
workspace_id=memory_config.workspace_id,
|
workspace_id=memory_config.workspace_id,
|
||||||
db=db
|
db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
if ontology_types:
|
if ontology_types:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||||
@@ -173,7 +180,8 @@ async def write(
|
|||||||
schedule_clustering_after_write(
|
schedule_clustering_after_write(
|
||||||
all_entity_nodes,
|
all_entity_nodes,
|
||||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
embedding_model_id=str(
|
||||||
|
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -208,9 +216,8 @@ async def write(
|
|||||||
summaries = await memory_summary_generation(
|
summaries = await memory_summary_generation(
|
||||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||||
)
|
)
|
||||||
|
ms_connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
ms_connector = Neo4jConnector()
|
|
||||||
await add_memory_summary_nodes(summaries, ms_connector)
|
await add_memory_summary_nodes(summaries, ms_connector)
|
||||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||||
finally:
|
finally:
|
||||||
@@ -223,6 +230,34 @@ async def write(
|
|||||||
finally:
|
finally:
|
||||||
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
||||||
|
|
||||||
|
# Step 5: Save perceptual memory to Neo4j
|
||||||
|
step_start = time.time()
|
||||||
|
if file_content:
|
||||||
|
try:
|
||||||
|
pc_connector = Neo4jConnector()
|
||||||
|
try:
|
||||||
|
created_ids = await add_perceptual_nodes(
|
||||||
|
perceptuals=file_content,
|
||||||
|
connector=pc_connector,
|
||||||
|
embedder_client=embedder_client,
|
||||||
|
)
|
||||||
|
# 如果有 ref_id,建立感知记忆与对话的关联
|
||||||
|
if ref_id and created_ids:
|
||||||
|
await add_perceptual_dialogue_edges(
|
||||||
|
perceptuals=file_content,
|
||||||
|
dialog_id=ref_id,
|
||||||
|
connector=pc_connector,
|
||||||
|
)
|
||||||
|
logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await pc_connector.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True)
|
||||||
|
log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file)
|
||||||
|
|
||||||
# Log total pipeline time
|
# Log total pipeline time
|
||||||
total_time = time.time() - pipeline_start
|
total_time = time.time() - pipeline_start
|
||||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||||
@@ -251,4 +286,4 @@ async def write(
|
|||||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||||
|
|
||||||
logger.info("=== Pipeline Complete ===")
|
logger.info("=== Pipeline Complete ===")
|
||||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||||
|
|||||||
@@ -553,3 +553,21 @@ class MemorySummaryNode(Node):
|
|||||||
ge=0,
|
ge=0,
|
||||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MutlimodalNode(Node):
|
||||||
|
"""Node representing a multimodal message in the knowledge graph.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
dialog_id: ID of the parent dialog
|
||||||
|
message_id: ID of the message
|
||||||
|
metadata: Additional message metadata
|
||||||
|
embedding: Optional embedding vector for the message
|
||||||
|
"""
|
||||||
|
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||||
|
message_id: str = Field(..., description="ID of the message")
|
||||||
|
summary: str = Field(..., description="The text content of the message")
|
||||||
|
file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')")
|
||||||
|
file_path: List[str] = Field(..., description="List of file paths for multimodal content")
|
||||||
|
metadata: dict = Field(default_factory=dict, description="Additional message metadata")
|
||||||
|
embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message")
|
||||||
|
|||||||
@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|||||||
|
|
||||||
|
|
||||||
async def dedup_layers_and_merge_and_return(
|
async def dedup_layers_and_merge_and_return(
|
||||||
dialogue_nodes: List[DialogueNode],
|
dialogue_nodes: List[DialogueNode],
|
||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
statement_nodes: List[StatementNode],
|
statement_nodes: List[StatementNode],
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
pipeline_config: ExtractionPipelineConfig,
|
pipeline_config: ExtractionPipelineConfig,
|
||||||
connector: Optional[Neo4jConnector] = None,
|
connector: Optional[Neo4jConnector] = None,
|
||||||
llm_client = None,
|
llm_client=None,
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
List[DialogueNode],
|
List[DialogueNode],
|
||||||
List[ChunkNode],
|
List[ChunkNode],
|
||||||
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
List[StatementChunkEdge],
|
List[StatementChunkEdge],
|
||||||
List[StatementEntityEdge],
|
List[StatementEntityEdge],
|
||||||
List[EntityEntityEdge],
|
List[EntityEntityEdge],
|
||||||
dict, # 新增:返回去重详情
|
dict
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
执行两层实体去重与融合:
|
执行两层实体去重与融合:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -188,7 +188,6 @@ async def _process_chunk_summary(
|
|||||||
response_model=MemorySummaryResponse,
|
response_model=MemorySummaryResponse,
|
||||||
)
|
)
|
||||||
summary_text = structured.summary.strip()
|
summary_text = structured.summary.strip()
|
||||||
|
|
||||||
# Generate title and type for the summary
|
# Generate title and type for the summary
|
||||||
title = None
|
title = None
|
||||||
episodic_type = None
|
episodic_type = None
|
||||||
|
|||||||
@@ -374,7 +374,9 @@ class VariablePool:
|
|||||||
self.variables = deepcopy(pool.variables)
|
self.variables = deepcopy(pool.variables)
|
||||||
|
|
||||||
def is_file_variable(self, selector):
|
def is_file_variable(self, selector):
|
||||||
variable_struct = self._get_variable_struct(selector)
|
variable_struct = self.get_instance(selector, default=None, strict=False)
|
||||||
|
if variable_struct is None:
|
||||||
|
return False
|
||||||
if isinstance(variable_struct, FileVariable):
|
if isinstance(variable_struct, FileVariable):
|
||||||
return True
|
return True
|
||||||
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
|
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
|
||||||
|
|||||||
@@ -623,7 +623,6 @@ class BaseNode(ABC):
|
|||||||
async def process_message(
|
async def process_message(
|
||||||
api_config: ModelInfo,
|
api_config: ModelInfo,
|
||||||
content: str | dict | FileObject,
|
content: str | dict | FileObject,
|
||||||
end_user_id: str,
|
|
||||||
enable_file=False
|
enable_file=False
|
||||||
) -> list | str | None:
|
) -> list | str | None:
|
||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
@@ -642,8 +641,8 @@ class BaseNode(ABC):
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
elif isinstance(content, FileObject):
|
elif isinstance(content, FileObject):
|
||||||
if content.content_cache.get(provider):
|
if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
|
||||||
return content.content_cache[provider]
|
return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||||
file_obj = FileInput(
|
file_obj = FileInput(
|
||||||
@@ -655,12 +654,11 @@ class BaseNode(ABC):
|
|||||||
)
|
)
|
||||||
file_obj.set_content(content.get_content())
|
file_obj.set_content(content.get_content())
|
||||||
message = await multimodel_service.process_files(
|
message = await multimodel_service.process_files(
|
||||||
end_user_id,
|
|
||||||
[file_obj],
|
[file_obj],
|
||||||
)
|
)
|
||||||
content.set_content(file_obj.get_content())
|
content.set_content(file_obj.get_content())
|
||||||
if message:
|
if message:
|
||||||
content.content_cache[provider] = message
|
content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
|
||||||
return message
|
return message
|
||||||
return None
|
return None
|
||||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||||
|
|||||||
@@ -144,7 +144,6 @@ class LLMNode(BaseNode):
|
|||||||
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||||
|
|
||||||
messages_config = self.typed_config.messages
|
messages_config = self.typed_config.messages
|
||||||
|
|
||||||
if messages_config:
|
if messages_config:
|
||||||
# 使用 LangChain 消息格式
|
# 使用 LangChain 消息格式
|
||||||
messages = []
|
messages = []
|
||||||
@@ -153,7 +152,6 @@ class LLMNode(BaseNode):
|
|||||||
content_template = msg_config.content
|
content_template = msg_config.content
|
||||||
content_template = self._render_context(content_template, variable_pool)
|
content_template = self._render_context(content_template, variable_pool)
|
||||||
content = self._render_template(content_template, variable_pool)
|
content = self._render_template(content_template, variable_pool)
|
||||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
|
||||||
# 根据角色创建对应的消息对象
|
# 根据角色创建对应的消息对象
|
||||||
if role == "system":
|
if role == "system":
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -161,32 +159,31 @@ class LLMNode(BaseNode):
|
|||||||
"content": await self.process_message(
|
"content": await self.process_message(
|
||||||
model_info,
|
model_info,
|
||||||
content,
|
content,
|
||||||
user_id,
|
|
||||||
self.typed_config.vision,
|
self.typed_config.vision,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
elif role in ["user", "human"]:
|
elif role in ["user", "human"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
elif role in ["ai", "assistant"]:
|
elif role in ["ai", "assistant"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
|
|
||||||
if self.typed_config.vision_input and self.typed_config.vision:
|
if self.typed_config.vision_input and self.typed_config.vision:
|
||||||
file_content = []
|
file_content = []
|
||||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||||
for file in files.value:
|
for file in files.value:
|
||||||
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
content = await self.process_message(model_info, file.value, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.extend(content)
|
file_content.extend(content)
|
||||||
if messages and messages[-1]["role"] == 'user':
|
if messages and messages[-1]["role"] == 'user':
|
||||||
@@ -200,7 +197,7 @@ class LLMNode(BaseNode):
|
|||||||
if isinstance(message["content"], list):
|
if isinstance(message["content"], list):
|
||||||
file_content = []
|
file_content = []
|
||||||
for file in message["content"]:
|
for file in message["content"]:
|
||||||
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
content = await self.process_message(model_info, file, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.extend(content)
|
file_content.extend(content)
|
||||||
history_message.append(
|
history_message.append(
|
||||||
@@ -210,7 +207,6 @@ class LLMNode(BaseNode):
|
|||||||
message["content"] = await self.process_message(
|
message["content"] = await self.process_message(
|
||||||
model_info,
|
model_info,
|
||||||
message["content"],
|
message["content"],
|
||||||
user_id,
|
|
||||||
self.typed_config.vision
|
self.typed_config.vision
|
||||||
)
|
)
|
||||||
history_message.append(message)
|
history_message.append(message)
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ class MemoryWriteNode(BaseNode):
|
|||||||
write_message_task.delay(
|
write_message_task.delay(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
message=messages,
|
message=messages,
|
||||||
|
file_messages=multimodal_memories,
|
||||||
config_id=str(self.typed_config.config_id),
|
config_id=str(self.typed_config.config_id),
|
||||||
storage_type=state["memory_storage_type"],
|
storage_type=state["memory_storage_type"],
|
||||||
user_rag_memory_id=state["user_rag_memory_id"]
|
user_rag_memory_id=state["user_rag_memory_id"]
|
||||||
|
|||||||
@@ -30,6 +30,9 @@ class MemoryConfig(Base):
|
|||||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||||
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
||||||
|
vision_id = Column(String, nullable=True, comment="视觉模型配置ID")
|
||||||
|
audio_id = Column(String, nullable=True, comment="语音模型配置ID")
|
||||||
|
video_id = Column(String, nullable=True, comment="视频模型配置ID")
|
||||||
|
|
||||||
# 记忆萃取引擎配置
|
# 记忆萃取引擎配置
|
||||||
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
||||||
|
|||||||
@@ -9,21 +9,22 @@ Classes:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from uuid import UUID
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import desc, select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_config_logger, get_db_logger
|
from app.core.logging_config import get_config_logger, get_db_logger
|
||||||
from app.models.memory_config_model import MemoryConfig
|
from app.models.memory_config_model import MemoryConfig
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
from app.schemas.memory_storage_schema import (
|
from app.schemas.memory_storage_schema import (
|
||||||
ConfigKey,
|
|
||||||
ConfigParamsCreate,
|
ConfigParamsCreate,
|
||||||
ConfigUpdate,
|
ConfigUpdate,
|
||||||
ConfigUpdateExtracted,
|
ConfigUpdateExtracted,
|
||||||
ConfigUpdateForget,
|
ConfigUpdateForget,
|
||||||
)
|
)
|
||||||
from sqlalchemy import desc, select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
# 获取数据库专用日志器
|
# 获取数据库专用日志器
|
||||||
@@ -157,7 +158,7 @@ class MemoryConfigRepository:
|
|||||||
return memory_config_obj
|
return memory_config_obj
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
|
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig:
|
||||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -491,7 +492,10 @@ class MemoryConfigRepository:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]:
|
def get_config_with_workspace(
|
||||||
|
db: Session,
|
||||||
|
config_id: uuid.UUID | int | str
|
||||||
|
) -> Optional[tuple[MemoryConfig, Workspace]]:
|
||||||
"""Get memory config and its associated workspace information
|
"""Get memory config and its associated workspace information
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -506,8 +510,6 @@ class MemoryConfigRepository:
|
|||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from app.models.workspace_model import Workspace
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
|
|
||||||
@@ -594,7 +596,7 @@ class MemoryConfigRepository:
|
|||||||
|
|
||||||
db_logger.debug(
|
db_logger.debug(
|
||||||
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
||||||
return (config, workspace)
|
return config, workspace
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# Re-raise known business exceptions
|
# Re-raise known business exceptions
|
||||||
@@ -630,7 +632,7 @@ class MemoryConfigRepository:
|
|||||||
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
|
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
|
||||||
"""
|
"""
|
||||||
from app.models.ontology_scene import OntologyScene
|
from app.models.ontology_scene import OntologyScene
|
||||||
|
|
||||||
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -694,7 +696,7 @@ class MemoryConfigRepository:
|
|||||||
Optional[MemoryConfig]: 默认配置对象,不存在则返回None
|
Optional[MemoryConfig]: 默认配置对象,不存在则返回None
|
||||||
"""
|
"""
|
||||||
db_logger.debug(f"查询工作空间默认配置: workspace_id={workspace_id}")
|
db_logger.debug(f"查询工作空间默认配置: workspace_id={workspace_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 优先查找显式标记为默认的配置
|
# 优先查找显式标记为默认的配置
|
||||||
stmt = (
|
stmt = (
|
||||||
@@ -706,13 +708,13 @@ class MemoryConfigRepository:
|
|||||||
)
|
)
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
config = db.scalars(stmt).first()
|
config = db.scalars(stmt).first()
|
||||||
|
|
||||||
if config:
|
if config:
|
||||||
db_logger.debug(f"找到默认配置: config_id={config.config_id}")
|
db_logger.debug(f"找到默认配置: config_id={config.config_id}")
|
||||||
return config
|
return config
|
||||||
|
|
||||||
# 回退:获取最早创建的活跃配置
|
# 回退:获取最早创建的活跃配置
|
||||||
stmt = (
|
stmt = (
|
||||||
select(MemoryConfig)
|
select(MemoryConfig)
|
||||||
@@ -723,25 +725,25 @@ class MemoryConfigRepository:
|
|||||||
.order_by(MemoryConfig.created_at.asc())
|
.order_by(MemoryConfig.created_at.asc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
config = db.scalars(stmt).first()
|
config = db.scalars(stmt).first()
|
||||||
|
|
||||||
if config:
|
if config:
|
||||||
db_logger.debug(f"使用最早创建的配置作为默认: config_id={config.config_id}")
|
db_logger.debug(f"使用最早创建的配置作为默认: config_id={config.config_id}")
|
||||||
else:
|
else:
|
||||||
db_logger.warning(f"工作空间没有活跃的记忆配置: workspace_id={workspace_id}")
|
db_logger.warning(f"工作空间没有活跃的记忆配置: workspace_id={workspace_id}")
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"查询工作空间默认配置失败: workspace_id={workspace_id} - {str(e)}")
|
db_logger.error(f"查询工作空间默认配置失败: workspace_id={workspace_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_with_fallback(
|
def get_with_fallback(
|
||||||
db: Session,
|
db: Session,
|
||||||
config_id: Optional[uuid.UUID],
|
config_id: Optional[uuid.UUID],
|
||||||
workspace_id: uuid.UUID
|
workspace_id: uuid.UUID
|
||||||
) -> Optional[MemoryConfig]:
|
) -> Optional[MemoryConfig]:
|
||||||
"""获取记忆配置,支持回退到工作空间默认配置
|
"""获取记忆配置,支持回退到工作空间默认配置
|
||||||
|
|
||||||
@@ -756,19 +758,18 @@ class MemoryConfigRepository:
|
|||||||
Optional[MemoryConfig]: 配置对象,如果都不存在则返回None
|
Optional[MemoryConfig]: 配置对象,如果都不存在则返回None
|
||||||
"""
|
"""
|
||||||
db_logger.debug(f"查询配置(支持回退): config_id={config_id}, workspace_id={workspace_id}")
|
db_logger.debug(f"查询配置(支持回退): config_id={config_id}, workspace_id={workspace_id}")
|
||||||
|
|
||||||
if not config_id:
|
if not config_id:
|
||||||
db_logger.debug("config_id 为空,使用工作空间默认配置")
|
db_logger.debug("config_id 为空,使用工作空间默认配置")
|
||||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||||
|
|
||||||
config = db.get(MemoryConfig, config_id)
|
config = db.get(MemoryConfig, config_id)
|
||||||
|
|
||||||
if config:
|
if config:
|
||||||
return config
|
return config
|
||||||
|
|
||||||
db_logger.warning(
|
db_logger.warning(
|
||||||
f"配置不存在,回退到工作空间默认配置: missing_config_id={config_id}, workspace_id={workspace_id}"
|
f"配置不存在,回退到工作空间默认配置: missing_config_id={config_id}, workspace_id={workspace_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
|
||||||
|
|
||||||
|
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
|
|
||||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||||
|
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
|
||||||
|
MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE
|
||||||
# 使用新的仓储层
|
# 使用新的仓储层
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
@@ -12,6 +13,7 @@ async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
|||||||
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||||
"""Add dialogue nodes to Neo4j database.
|
"""Add dialogue nodes to Neo4j database.
|
||||||
|
|
||||||
@@ -127,6 +129,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
|||||||
print(f"Error creating statement nodes: {e}")
|
print(f"Error creating statement nodes: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||||
"""Add chunk nodes to Neo4j in batch.
|
"""Add chunk nodes to Neo4j in batch.
|
||||||
|
|
||||||
@@ -179,8 +182,8 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[
|
||||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
List[str]]:
|
||||||
"""Add memory summary nodes to Neo4j in batch.
|
"""Add memory summary nodes to Neo4j in batch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -211,7 +214,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
|||||||
"summary_embedding": s.summary_embedding if s.summary_embedding else None,
|
"summary_embedding": s.summary_embedding if s.summary_embedding else None,
|
||||||
"config_id": s.config_id, # 添加 config_id
|
"config_id": s.config_id, # 添加 config_id
|
||||||
})
|
})
|
||||||
|
|
||||||
result = await connector.execute_query(
|
result = await connector.execute_query(
|
||||||
MEMORY_SUMMARY_NODE_SAVE,
|
MEMORY_SUMMARY_NODE_SAVE,
|
||||||
summaries=flattened
|
summaries=flattened
|
||||||
@@ -224,3 +227,103 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def add_perceptual_nodes(
    perceptuals: list,
    connector: Neo4jConnector,
    embedder_client=None,
) -> Optional[List[str]]:
    """Add perceptual memory nodes to Neo4j in batch.

    Every item is flattened into a plain property map and all maps are sent
    in a single ``PERCEPTUAL_NODE_SAVE`` query. Summary embedding is
    best-effort: if embedding one summary fails, the failure is printed and
    that node is saved with ``summary_embedding`` set to ``None`` instead of
    aborting the batch.

    Args:
        perceptuals: List of MemoryPerceptualModel objects from PostgreSQL.
            The attributes read here are ``id``, ``end_user_id``,
            ``perceptual_type``, ``file_path``, ``file_name``, ``file_ext``,
            ``summary``, ``meta_data`` and ``created_time``.
        connector: Neo4j connector instance used to execute the query.
        embedder_client: Optional embedder client. When given, its awaitable
            ``response([text])`` is called once per non-empty summary and the
            first element of the result is stored as the embedding.

    Returns:
        List of created node UUIDs, ``[]`` when ``perceptuals`` is empty, or
        ``None`` if building the payload or running the query failed.
    """
    if not perceptuals:
        print("No perceptual nodes to add")
        return []

    try:
        flattened = []
        for p in perceptuals:
            meta = p.meta_data or {}
            # A "content" key that is present but null would defeat a
            # .get() default and crash the whole batch on the next lookup,
            # so fall back with `or` instead.
            content_meta = meta.get("content") or {}

            # Generate the summary embedding (only when an embedder client
            # is available and there is a summary to embed).
            summary_embedding = None
            if embedder_client and p.summary:
                try:
                    summary_embedding = (await embedder_client.response([p.summary]))[0]
                except Exception as emb_err:
                    # Deliberate best-effort: keep the node, skip the vector.
                    print(f"Failed to embed perceptual summary: {emb_err}")

            flattened.append({
                "id": str(p.id),
                "end_user_id": str(p.end_user_id),
                "perceptual_type": p.perceptual_type,
                "file_path": p.file_path or "",
                "file_name": p.file_name or "",
                "file_ext": p.file_ext or "",
                "summary": p.summary or "",
                "keywords": content_meta.get("keywords", []),
                "topic": content_meta.get("topic", ""),
                "domain": content_meta.get("domain", ""),
                "created_at": p.created_time.isoformat() if p.created_time else None,
                "summary_embedding": summary_embedding,
            })

        result = await connector.execute_query(
            PERCEPTUAL_NODE_SAVE,
            perceptuals=flattened,
        )
        created_uuids = [record.get("uuid") for record in result]
        print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j")
        return created_uuids

    except Exception as e:
        print(f"Failed to save Perceptual nodes to Neo4j: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
async def add_perceptual_dialogue_edges(
    perceptuals: list,
    dialog_id: str,
    connector: Neo4jConnector,
) -> Optional[List[str]]:
    """Link Perceptual nodes to a Dialogue node in Neo4j.

    One edge description is built per perceptual item and all of them are
    submitted in a single ``PERCEPTUAL_DIALOGUE_EDGE_SAVE`` query.

    Args:
        perceptuals: List of MemoryPerceptualModel objects; ``id``,
            ``end_user_id`` and ``created_time`` are read from each.
        dialog_id: The dialogue ID (or ref_id) to link to.
        connector: Neo4j connector instance used to execute the query.

    Returns:
        List of created edge element IDs, ``[]`` when there are no
        perceptuals or no dialogue ID, or ``None`` if the operation failed.
    """
    # Nothing to link without both endpoints.
    if not perceptuals:
        return []
    if not dialog_id:
        return []

    try:
        edge_rows = [
            {
                "perceptual_id": str(item.id),
                "dialog_id": dialog_id,
                "end_user_id": str(item.end_user_id),
                "created_at": item.created_time.isoformat() if item.created_time else None,
            }
            for item in perceptuals
        ]

        records = await connector.execute_query(
            PERCEPTUAL_DIALOGUE_EDGE_SAVE,
            edges=edge_rows,
        )
        edge_ids = [rec.get("uuid") for rec in records]
        print(f"Successfully saved {len(edge_ids)} Perceptual-Dialogue edges to Neo4j")
        return edge_ids

    except Exception as e:
        print(f"Failed to save Perceptual-Dialogue edges: {e}")
        return None
|
||||||
|
|||||||
@@ -1323,3 +1323,36 @@ RETURN s.statement AS statement,
|
|||||||
ORDER BY COALESCE(s.activation_value, 0) DESC
|
ORDER BY COALESCE(s.activation_value, 0) DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Save perceptual memory nodes.
# Takes $perceptuals, a list of property maps, and processes one map per row.
# MERGE matches or creates the :Perceptual node keyed on `id`; SET += then
# writes the listed properties onto it. Returns each node's id as `uuid`.
PERCEPTUAL_NODE_SAVE = """
UNWIND $perceptuals AS p
MERGE (n:Perceptual {id: p.id})
SET n += {
id: p.id,
end_user_id: p.end_user_id,
perceptual_type: p.perceptual_type,
file_path: p.file_path,
file_name: p.file_name,
file_ext: p.file_ext,
summary: p.summary,
keywords: p.keywords,
topic: p.topic,
domain: p.domain,
created_at: p.created_at,
summary_embedding: p.summary_embedding
}
RETURN n.id AS uuid
"""
|
||||||
|
|
||||||
|
# Edge associating a perceptual memory with its dialogue.
# Takes $edges, a list of maps with perceptual_id, dialog_id, end_user_id and
# created_at. The Perceptual node is matched by id + end_user_id; the Dialogue
# node must belong to the same end user and have either `id` or `ref_id`
# equal to edge.dialog_id. MERGE keeps a single HAS_PERCEPTUAL relationship
# per (dialogue, perceptual) pair and the query returns its elementId as `uuid`.
PERCEPTUAL_DIALOGUE_EDGE_SAVE = """
UNWIND $edges AS edge
MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
MATCH (d:Dialogue {end_user_id: edge.end_user_id})
WHERE d.id = edge.dialog_id OR d.ref_id = edge.dialog_id
MERGE (d)-[r:HAS_PERCEPTUAL]->(p)
SET r.end_user_id = edge.end_user_id,
r.created_at = edge.created_at
RETURN elementId(r) AS uuid
"""
|
||||||
|
|||||||
@@ -387,6 +387,12 @@ class MemoryConfig:
|
|||||||
|
|
||||||
rerank_model_id: Optional[UUID] = None
|
rerank_model_id: Optional[UUID] = None
|
||||||
rerank_model_name: Optional[str] = None
|
rerank_model_name: Optional[str] = None
|
||||||
|
video_model_id: Optional[UUID] = None
|
||||||
|
video_model_name: Optional[str] = None
|
||||||
|
vision_model_id: Optional[UUID] = None
|
||||||
|
vision_model_name: Optional[str] = None
|
||||||
|
audio_model_id: Optional[UUID] = None
|
||||||
|
audio_model_name: Optional[str] = None
|
||||||
|
|
||||||
llm_params: Dict[str, Any] = field(default_factory=dict)
|
llm_params: Dict[str, Any] = field(default_factory=dict)
|
||||||
embedding_params: Dict[str, Any] = field(default_factory=dict)
|
embedding_params: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ class AppChatService:
|
|||||||
model_type=ModelType.LLM
|
model_type=ModelType.LLM
|
||||||
)
|
)
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
# 调用 Agent(支持多模态)
|
# 调用 Agent(支持多模态)
|
||||||
@@ -339,7 +339,7 @@ class AppChatService:
|
|||||||
model_type=ModelType.LLM
|
model_type=ModelType.LLM
|
||||||
)
|
)
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||||
|
|||||||
@@ -600,7 +600,7 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
provider = api_key_config.get("provider", "openai")
|
provider = api_key_config.get("provider", "openai")
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||||
|
|
||||||
# 7. 知识库检索
|
# 7. 知识库检索
|
||||||
@@ -836,7 +836,7 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
provider = api_key_config.get("provider", "openai")
|
provider = api_key_config.get("provider", "openai")
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||||
|
|
||||||
# 7. 知识库检索
|
# 7. 知识库检索
|
||||||
|
|||||||
@@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.cache import InterestMemoryCache
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logging_config import get_config_logger, get_logger
|
from app.core.logging_config import get_config_logger, get_logger
|
||||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
|
||||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||||
from app.core.memory.agent.utils.messages_tools import (
|
from app.core.memory.agent.utils.messages_tools import (
|
||||||
merge_multiple_search_results,
|
merge_multiple_search_results,
|
||||||
reorder_output_results,
|
reorder_output_results,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||||
|
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
from app.schemas import FileInput
|
||||||
from app.schemas.memory_agent_schema import Write_UserInput
|
from app.schemas.memory_agent_schema import Write_UserInput
|
||||||
from app.schemas.memory_config_schema import ConfigurationError
|
from app.schemas.memory_config_schema import ConfigurationError
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_konwledges_server import (
|
from app.services.memory_konwledges_server import (
|
||||||
write_rag,
|
write_rag,
|
||||||
)
|
)
|
||||||
|
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||||
@@ -271,6 +274,7 @@ class MemoryAgentService:
|
|||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
messages: list[dict],
|
messages: list[dict],
|
||||||
|
file_messages: list[dict],
|
||||||
config_id: Optional[uuid.UUID] | int,
|
config_id: Optional[uuid.UUID] | int,
|
||||||
db: Session,
|
db: Session,
|
||||||
storage_type: str,
|
storage_type: str,
|
||||||
@@ -283,6 +287,7 @@ class MemoryAgentService:
|
|||||||
Args:
|
Args:
|
||||||
end_user_id: Group identifier (also used as end_user_id)
|
end_user_id: Group identifier (also used as end_user_id)
|
||||||
messages: Message to write
|
messages: Message to write
|
||||||
|
file_messages: Messages whose "files" entries are converted into perceptual memories before writing
|
||||||
config_id: Configuration ID from database
|
config_id: Configuration ID from database
|
||||||
db: SQLAlchemy database session
|
db: SQLAlchemy database session
|
||||||
storage_type: Storage type (neo4j or rag)
|
storage_type: Storage type (neo4j or rag)
|
||||||
@@ -342,48 +347,52 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
perceptual_serivce = MemoryPerceptualService(db)
|
||||||
|
file_content = []
|
||||||
|
for message in file_messages:
|
||||||
|
for file in message["files"]:
|
||||||
|
file_object = await perceptual_serivce.generate_perceptual_memory(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
memory_config=memory_config,
|
||||||
|
file=FileInput(**file)
|
||||||
|
)
|
||||||
|
file_content.append(file_object)
|
||||||
|
|
||||||
|
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
try:
|
try:
|
||||||
if storage_type == "rag":
|
if storage_type == "rag":
|
||||||
# For RAG storage, convert messages to single string
|
# For RAG storage, convert messages to single string
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
|
||||||
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||||
return "success"
|
return "success"
|
||||||
else:
|
else:
|
||||||
async with make_write_graph() as graph:
|
await write_neo4j(
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
end_user_id=end_user_id,
|
||||||
# Convert structured messages to LangChain messages
|
messages=messages,
|
||||||
langchain_messages = []
|
file_content=file_content,
|
||||||
for msg in messages:
|
memory_config=memory_config,
|
||||||
if msg['role'] == 'user':
|
ref_id='',
|
||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
language=language
|
||||||
elif msg['role'] == 'assistant':
|
)
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
for lang in ["zh", "en"]:
|
||||||
print(100 * '-')
|
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||||
print(langchain_messages)
|
end_user_id, lang
|
||||||
print(100 * '-')
|
)
|
||||||
# 初始状态 - 包含所有必要字段
|
if deleted:
|
||||||
initial_state = {
|
logger.info(
|
||||||
"messages": langchain_messages,
|
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||||
"end_user_id": end_user_id,
|
return self.writer_messages_deal(
|
||||||
"memory_config": memory_config,
|
"success",
|
||||||
"language": language
|
start_time,
|
||||||
}
|
end_user_id,
|
||||||
|
config_id,
|
||||||
# 获取节点更新信息
|
message_text,
|
||||||
async for update_event in graph.astream(
|
{
|
||||||
initial_state,
|
"status": "success",
|
||||||
stream_mode="updates",
|
"data": messages,
|
||||||
config=config
|
"config_id": memory_config.config_id,
|
||||||
):
|
"config_name": memory_config.config_name
|
||||||
for node_name, node_data in update_event.items():
|
}
|
||||||
if 'save_neo4j' == node_name:
|
)
|
||||||
massages = node_data
|
|
||||||
massagesstatus = massages.get('write_result')['status']
|
|
||||||
contents = massages.get('write_result')
|
|
||||||
# Convert messages back to string for logging
|
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
|
||||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
|
|
||||||
contents)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Ensure proper error handling and logging
|
# Ensure proper error handling and logging
|
||||||
error_msg = f"Write operation failed: {str(e)}"
|
error_msg = f"Write operation failed: {str(e)}"
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class MemoryAPIService:
|
|||||||
2. Maps end_user_id to end_user_id for memory operations
|
2. Maps end_user_id to end_user_id for memory operations
|
||||||
3. Delegates to MemoryAgentService for actual memory read/write operations
|
3. Delegates to MemoryAgentService for actual memory read/write operations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session):
|
||||||
"""Initialize MemoryAPIService.
|
"""Initialize MemoryAPIService.
|
||||||
|
|
||||||
@@ -36,11 +36,11 @@ class MemoryAPIService:
|
|||||||
db: SQLAlchemy database session
|
db: SQLAlchemy database session
|
||||||
"""
|
"""
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
def validate_end_user(
|
def validate_end_user(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
workspace_id: uuid.UUID
|
workspace_id: uuid.UUID
|
||||||
) -> EndUser:
|
) -> EndUser:
|
||||||
"""Validate that end_user exists and belongs to the workspace.
|
"""Validate that end_user exists and belongs to the workspace.
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ class MemoryAPIService:
|
|||||||
BusinessException: If end_user not in authorized workspace
|
BusinessException: If end_user not in authorized workspace
|
||||||
"""
|
"""
|
||||||
logger.info(f"Validating end_user: {end_user_id} for workspace: {workspace_id}")
|
logger.info(f"Validating end_user: {end_user_id} for workspace: {workspace_id}")
|
||||||
|
|
||||||
# Query end_user by ID
|
# Query end_user by ID
|
||||||
try:
|
try:
|
||||||
end_user_uuid = uuid.UUID(end_user_id)
|
end_user_uuid = uuid.UUID(end_user_id)
|
||||||
@@ -66,7 +66,7 @@ class MemoryAPIService:
|
|||||||
message=f"Invalid end_user_id format: {end_user_id}",
|
message=f"Invalid end_user_id format: {end_user_id}",
|
||||||
code=BizCode.INVALID_PARAMETER
|
code=BizCode.INVALID_PARAMETER
|
||||||
)
|
)
|
||||||
|
|
||||||
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first()
|
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first()
|
||||||
|
|
||||||
if not end_user:
|
if not end_user:
|
||||||
@@ -75,13 +75,13 @@ class MemoryAPIService:
|
|||||||
resource_type="EndUser",
|
resource_type="EndUser",
|
||||||
resource_id=end_user_id
|
resource_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify end_user belongs to the workspace via App relationship
|
# Verify end_user belongs to the workspace via App relationship
|
||||||
app = self.db.query(App).filter(
|
app = self.db.query(App).filter(
|
||||||
App.id == end_user.app_id,
|
App.id == end_user.app_id,
|
||||||
App.is_active.is_(True)
|
App.is_active.is_(True)
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not app:
|
if not app:
|
||||||
logger.warning(f"App not found for end_user: {end_user_id}")
|
logger.warning(f"App not found for end_user: {end_user_id}")
|
||||||
# raise ResourceNotFoundException(
|
# raise ResourceNotFoundException(
|
||||||
@@ -99,7 +99,7 @@ class MemoryAPIService:
|
|||||||
# message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}",
|
# message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}",
|
||||||
# code=BizCode.FORBIDDEN
|
# code=BizCode.FORBIDDEN
|
||||||
# )
|
# )
|
||||||
|
|
||||||
logger.info(f"End user {end_user_id} validated successfully")
|
logger.info(f"End user {end_user_id} validated successfully")
|
||||||
return end_user
|
return end_user
|
||||||
|
|
||||||
@@ -125,13 +125,14 @@ class MemoryAPIService:
|
|||||||
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
||||||
|
|
||||||
async def write_memory(
|
async def write_memory(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
config_id: str,
|
config_id: str,
|
||||||
storage_type: str = "neo4j",
|
storage_type: str = "neo4j",
|
||||||
user_rag_memory_id: Optional[str] = None,
|
files: Optional[list]=None,
|
||||||
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Write memory with validation.
|
"""Write memory with validation.
|
||||||
|
|
||||||
@@ -153,14 +154,16 @@ class MemoryAPIService:
|
|||||||
ResourceNotFoundException: If end_user not found
|
ResourceNotFoundException: If end_user not found
|
||||||
BusinessException: If end_user not in authorized workspace or write fails
|
BusinessException: If end_user not in authorized workspace or write fails
|
||||||
"""
|
"""
|
||||||
|
if files is None:
|
||||||
|
files = list()
|
||||||
logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}")
|
logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}")
|
||||||
|
|
||||||
# Validate end_user exists and belongs to workspace
|
# Validate end_user exists and belongs to workspace
|
||||||
self.validate_end_user(end_user_id, workspace_id)
|
self.validate_end_user(end_user_id, workspace_id)
|
||||||
|
|
||||||
# Update end user's memory_config_id
|
# Update end user's memory_config_id
|
||||||
self._update_end_user_config(end_user_id, config_id)
|
self._update_end_user_config(end_user_id, config_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Delegate to MemoryAgentService
|
# Delegate to MemoryAgentService
|
||||||
# Convert string message to list[dict] format expected by MemoryAgentService
|
# Convert string message to list[dict] format expected by MemoryAgentService
|
||||||
@@ -171,11 +174,12 @@ class MemoryAPIService:
|
|||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
db=self.db,
|
db=self.db,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id or ""
|
user_rag_memory_id=user_rag_memory_id or "",
|
||||||
|
files=files
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
||||||
|
|
||||||
# result may be a string "success" or a dict with a "status" key
|
# result may be a string "success" or a dict with a "status" key
|
||||||
# Preserve the full dict so callers don't silently lose extra fields
|
# Preserve the full dict so callers don't silently lose extra fields
|
||||||
# (e.g. error codes, metadata) returned by MemoryAgentService.
|
# (e.g. error codes, metadata) returned by MemoryAgentService.
|
||||||
@@ -189,7 +193,7 @@ class MemoryAPIService:
|
|||||||
"status": result if isinstance(result, str) else "success",
|
"status": result if isinstance(result, str) else "success",
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
except ConfigurationError as e:
|
except ConfigurationError as e:
|
||||||
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
|
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
@@ -204,16 +208,16 @@ class MemoryAPIService:
|
|||||||
message=f"Memory write failed: {str(e)}",
|
message=f"Memory write failed: {str(e)}",
|
||||||
code=BizCode.MEMORY_WRITE_FAILED
|
code=BizCode.MEMORY_WRITE_FAILED
|
||||||
)
|
)
|
||||||
|
|
||||||
async def read_memory(
|
async def read_memory(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
search_switch: str = "0",
|
search_switch: str = "0",
|
||||||
config_id: str = "",
|
config_id: str = "",
|
||||||
storage_type: str = "neo4j",
|
storage_type: str = "neo4j",
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Read memory with validation.
|
"""Read memory with validation.
|
||||||
|
|
||||||
@@ -237,14 +241,13 @@ class MemoryAPIService:
|
|||||||
BusinessException: If end_user not in authorized workspace or read fails
|
BusinessException: If end_user not in authorized workspace or read fails
|
||||||
"""
|
"""
|
||||||
logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}")
|
logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}")
|
||||||
|
|
||||||
# Validate end_user exists and belongs to workspace
|
# Validate end_user exists and belongs to workspace
|
||||||
self.validate_end_user(end_user_id, workspace_id)
|
self.validate_end_user(end_user_id, workspace_id)
|
||||||
|
|
||||||
# Update end user's memory_config_id
|
# Update end user's memory_config_id
|
||||||
self._update_end_user_config(end_user_id, config_id)
|
self._update_end_user_config(end_user_id, config_id)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Delegate to MemoryAgentService
|
# Delegate to MemoryAgentService
|
||||||
result = await MemoryAgentService().read_memory(
|
result = await MemoryAgentService().read_memory(
|
||||||
@@ -257,15 +260,15 @@ class MemoryAPIService:
|
|||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id or ""
|
user_rag_memory_id=user_rag_memory_id or ""
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory read successful for end_user: {end_user_id}")
|
logger.info(f"Memory read successful for end_user: {end_user_id}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"answer": result.get("answer", ""),
|
"answer": result.get("answer", ""),
|
||||||
"intermediate_outputs": result.get("intermediate_outputs", []),
|
"intermediate_outputs": result.get("intermediate_outputs", []),
|
||||||
"end_user_id": end_user_id
|
"end_user_id": end_user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
except ConfigurationError as e:
|
except ConfigurationError as e:
|
||||||
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
|
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
@@ -282,8 +285,8 @@ class MemoryAPIService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def list_memory_configs(
|
def list_memory_configs(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""List all memory configs for a workspace.
|
"""List all memory configs for a workspace.
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ def _validate_config_id(config_id, db: Session = None):
|
|||||||
"""Validate configuration ID format (supports both UUID and integer)."""
|
"""Validate configuration ID format (supports both UUID and integer)."""
|
||||||
if isinstance(config_id, uuid.UUID):
|
if isinstance(config_id, uuid.UUID):
|
||||||
return config_id
|
return config_id
|
||||||
|
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
raise InvalidConfigError(
|
raise InvalidConfigError(
|
||||||
"Configuration ID cannot be None",
|
"Configuration ID cannot be None",
|
||||||
@@ -60,18 +60,18 @@ def _validate_config_id(config_id, db: Session = None):
|
|||||||
if result:
|
if result:
|
||||||
logger.info(f"Found config_id {result.config_id} for user_id {config_id}")
|
logger.info(f"Found config_id {result.config_id} for user_id {config_id}")
|
||||||
return result.config_id
|
return result.config_id
|
||||||
|
|
||||||
return config_id
|
return config_id
|
||||||
|
|
||||||
if isinstance(config_id, str):
|
if isinstance(config_id, str):
|
||||||
config_id_stripped = config_id.strip()
|
config_id_stripped = config_id.strip()
|
||||||
|
|
||||||
# Try parsing as UUID first
|
# Try parsing as UUID first
|
||||||
try:
|
try:
|
||||||
return uuid.UUID(config_id_stripped)
|
return uuid.UUID(config_id_stripped)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Fall back to integer parsing
|
# Fall back to integer parsing
|
||||||
try:
|
try:
|
||||||
parsed_id = int(config_id_stripped)
|
parsed_id = int(config_id_stripped)
|
||||||
@@ -81,17 +81,17 @@ def _validate_config_id(config_id, db: Session = None):
|
|||||||
field_name="config_id",
|
field_name="config_id",
|
||||||
invalid_value=config_id,
|
invalid_value=config_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 如果提供了数据库会话,尝试通过 user_id 查询 config_id
|
# 如果提供了数据库会话,尝试通过 user_id 查询 config_id
|
||||||
if db is not None:
|
if db is not None:
|
||||||
# 查询 user_id 匹配的记录
|
# 查询 user_id 匹配的记录
|
||||||
stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id))
|
stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id))
|
||||||
result = db.execute(stmt).scalars().first()
|
result = db.execute(stmt).scalars().first()
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}")
|
logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}")
|
||||||
return result.config_id
|
return result.config_id
|
||||||
|
|
||||||
return parsed_id
|
return parsed_id
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise InvalidConfigError(
|
raise InvalidConfigError(
|
||||||
@@ -154,10 +154,10 @@ class MemoryConfigService:
|
|||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
def load_memory_config(
|
def load_memory_config(
|
||||||
self,
|
self,
|
||||||
config_id: Optional[UUID] = None,
|
config_id: Optional[UUID] = None,
|
||||||
workspace_id: Optional[UUID] = None,
|
workspace_id: Optional[UUID] = None,
|
||||||
service_name: str = "MemoryConfigService",
|
service_name: str = "MemoryConfigService",
|
||||||
) -> MemoryConfig:
|
) -> MemoryConfig:
|
||||||
"""
|
"""
|
||||||
Load memory configuration from database with optional fallback.
|
Load memory configuration from database with optional fallback.
|
||||||
@@ -194,14 +194,14 @@ class MemoryConfigService:
|
|||||||
try:
|
try:
|
||||||
# Use get_config_with_fallback if workspace_id is provided
|
# Use get_config_with_fallback if workspace_id is provided
|
||||||
memory_config = None
|
memory_config = None
|
||||||
|
validated_config_id = None
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
validated_config_id = None
|
|
||||||
if config_id:
|
if config_id:
|
||||||
try:
|
try:
|
||||||
validated_config_id = _validate_config_id(config_id, self.db)
|
validated_config_id = _validate_config_id(config_id, self.db)
|
||||||
except Exception:
|
except Exception:
|
||||||
validated_config_id = None
|
validated_config_id = None
|
||||||
|
|
||||||
memory_config = self.get_config_with_fallback(
|
memory_config = self.get_config_with_fallback(
|
||||||
memory_config_id=validated_config_id,
|
memory_config_id=validated_config_id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id
|
||||||
@@ -210,7 +210,7 @@ class MemoryConfigService:
|
|||||||
validated_config_id = _validate_config_id(config_id, self.db)
|
validated_config_id = _validate_config_id(config_id, self.db)
|
||||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||||
memory_config = self.db.get(MemoryConfigModel, validated_config_id)
|
memory_config = self.db.get(MemoryConfigModel, validated_config_id)
|
||||||
|
|
||||||
if not memory_config:
|
if not memory_config:
|
||||||
elapsed_ms = (time.time() - start_time) * 1000
|
elapsed_ms = (time.time() - start_time) * 1000
|
||||||
config_logger.error(
|
config_logger.error(
|
||||||
@@ -233,7 +233,7 @@ class MemoryConfigService:
|
|||||||
result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id)
|
result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id)
|
||||||
db_query_time = time.time() - db_query_start
|
db_query_time = time.time() - db_query_start
|
||||||
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
raise ConfigurationError(
|
raise ConfigurationError(
|
||||||
f"Workspace not found for config {memory_config.config_id}"
|
f"Workspace not found for config {memory_config.config_id}"
|
||||||
@@ -243,10 +243,10 @@ class MemoryConfigService:
|
|||||||
|
|
||||||
# Helper function to validate model with workspace fallback
|
# Helper function to validate model with workspace fallback
|
||||||
def _validate_model_with_fallback(
|
def _validate_model_with_fallback(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
model_type: str,
|
model_type: str,
|
||||||
workspace_default: str,
|
workspace_default: str,
|
||||||
required: bool = False
|
required: bool = False
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""Validate model ID, falling back to workspace default if invalid.
|
"""Validate model ID, falling back to workspace default if invalid.
|
||||||
|
|
||||||
@@ -275,7 +275,7 @@ class MemoryConfigService:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"{model_type} model validation failed, trying workspace default: {e}"
|
f"{model_type} model validation failed, trying workspace default: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fallback to workspace default
|
# Fallback to workspace default
|
||||||
if workspace_default:
|
if workspace_default:
|
||||||
try:
|
try:
|
||||||
@@ -297,7 +297,7 @@ class MemoryConfigService:
|
|||||||
logger.error(f"Workspace default {model_type} model also invalid: {e}")
|
logger.error(f"Workspace default {model_type} model also invalid: {e}")
|
||||||
if required:
|
if required:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if required:
|
if required:
|
||||||
raise InvalidConfigError(
|
raise InvalidConfigError(
|
||||||
f"{model_type.title()} model is required but not configured",
|
f"{model_type.title()} model is required but not configured",
|
||||||
@@ -306,7 +306,7 @@ class MemoryConfigService:
|
|||||||
config_id=validated_config_id,
|
config_id=validated_config_id,
|
||||||
workspace_id=workspace.id
|
workspace_id=workspace.id
|
||||||
)
|
)
|
||||||
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# Step 2: Validate embedding model with workspace fallback
|
# Step 2: Validate embedding model with workspace fallback
|
||||||
@@ -343,6 +343,35 @@ class MemoryConfigService:
|
|||||||
if memory_config.rerank_id or workspace.rerank:
|
if memory_config.rerank_id or workspace.rerank:
|
||||||
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
||||||
|
|
||||||
|
vision_uuid, vision_name = validate_and_resolve_model_id(
|
||||||
|
memory_config.vision_id,
|
||||||
|
"llm",
|
||||||
|
self.db,
|
||||||
|
workspace.tenant_id,
|
||||||
|
required=False,
|
||||||
|
config_id=validated_config_id,
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_uuid, audio_name = validate_and_resolve_model_id(
|
||||||
|
memory_config.audio_id,
|
||||||
|
"llm",
|
||||||
|
self.db,
|
||||||
|
workspace.tenant_id,
|
||||||
|
required=False,
|
||||||
|
config_id=validated_config_id,
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
video_uuid, video_name = validate_and_resolve_model_id(
|
||||||
|
memory_config.video_id,
|
||||||
|
"llm",
|
||||||
|
self.db,
|
||||||
|
workspace.tenant_id,
|
||||||
|
required=False,
|
||||||
|
config_id=validated_config_id,
|
||||||
|
workspace_id=workspace.id,
|
||||||
|
)
|
||||||
# Create immutable MemoryConfig object
|
# Create immutable MemoryConfig object
|
||||||
config = MemoryConfig(
|
config = MemoryConfig(
|
||||||
config_id=memory_config.config_id,
|
config_id=memory_config.config_id,
|
||||||
@@ -356,6 +385,12 @@ class MemoryConfigService:
|
|||||||
embedding_model_name=embedding_name,
|
embedding_model_name=embedding_name,
|
||||||
rerank_model_id=rerank_uuid,
|
rerank_model_id=rerank_uuid,
|
||||||
rerank_model_name=rerank_name,
|
rerank_model_name=rerank_name,
|
||||||
|
video_model_id=video_uuid,
|
||||||
|
video_model_name=video_name,
|
||||||
|
vision_model_id=vision_uuid,
|
||||||
|
vision_model_name=vision_name,
|
||||||
|
audio_model_id=audio_uuid,
|
||||||
|
audio_model_name=audio_name,
|
||||||
storage_type=workspace.storage_type or "neo4j",
|
storage_type=workspace.storage_type or "neo4j",
|
||||||
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
||||||
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
||||||
@@ -364,24 +399,31 @@ class MemoryConfigService:
|
|||||||
reflexion_baseline=memory_config.baseline or "Time",
|
reflexion_baseline=memory_config.baseline or "Time",
|
||||||
loaded_at=datetime.now(),
|
loaded_at=datetime.now(),
|
||||||
# Pipeline config: Deduplication
|
# Pipeline config: Deduplication
|
||||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
enable_llm_dedup_blockwise=bool(
|
||||||
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||||
|
enable_llm_disambiguation=bool(
|
||||||
|
memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||||
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
||||||
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
||||||
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
||||||
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
||||||
# Pipeline config: Statement extraction
|
# Pipeline config: Statement extraction
|
||||||
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
statement_granularity=int(
|
||||||
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||||
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
|
include_dialogue_context=bool(
|
||||||
|
memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||||
|
max_dialogue_context_chars=int(
|
||||||
|
memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||||
# Pipeline config: Forgetting engine
|
# Pipeline config: Forgetting engine
|
||||||
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
||||||
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
||||||
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
||||||
# Pipeline config: Pruning
|
# Pipeline config: Pruning
|
||||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
pruning_enabled=bool(
|
||||||
|
memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||||
pruning_scene=memory_config.pruning_scene or "education",
|
pruning_scene=memory_config.pruning_scene or "education",
|
||||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
pruning_threshold=float(
|
||||||
|
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||||
# Ontology scene association
|
# Ontology scene association
|
||||||
scene_id=memory_config.scene_id,
|
scene_id=memory_config.scene_id,
|
||||||
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
||||||
@@ -448,9 +490,9 @@ class MemoryConfigService:
|
|||||||
if not config:
|
if not config:
|
||||||
logger.warning(f"Model ID {model_id} not found")
|
logger.warning(f"Model ID {model_id} not found")
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||||
|
|
||||||
api_config: ModelApiKey = config.api_keys[0]
|
api_config: ModelApiKey = config.api_keys[0]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": api_config.model_name,
|
"model_name": api_config.model_name,
|
||||||
"provider": api_config.provider,
|
"provider": api_config.provider,
|
||||||
@@ -481,9 +523,9 @@ class MemoryConfigService:
|
|||||||
if not config:
|
if not config:
|
||||||
logger.warning(f"Embedding model ID {embedding_id} not found")
|
logger.warning(f"Embedding model ID {embedding_id} not found")
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
||||||
|
|
||||||
api_config: ModelApiKey = config.api_keys[0]
|
api_config: ModelApiKey = config.api_keys[0]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": api_config.model_name,
|
"model_name": api_config.model_name,
|
||||||
"provider": api_config.provider,
|
"provider": api_config.provider,
|
||||||
@@ -571,25 +613,25 @@ class MemoryConfigService:
|
|||||||
"""
|
"""
|
||||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||||
|
|
||||||
if not memory_config.scene_id:
|
if not memory_config.scene_id:
|
||||||
logger.debug("No scene_id configured, skipping ontology type fetch")
|
logger.debug("No scene_id configured, skipping ontology type fetch")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ontology_repo = OntologyClassRepository(self.db)
|
ontology_repo = OntologyClassRepository(self.db)
|
||||||
ontology_classes = ontology_repo.get_classes_by_scene(memory_config.scene_id)
|
ontology_classes = ontology_repo.get_classes_by_scene(memory_config.scene_id)
|
||||||
|
|
||||||
if not ontology_classes:
|
if not ontology_classes:
|
||||||
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
|
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
ontology_types = OntologyTypeList.from_db_models(ontology_classes)
|
ontology_types = OntologyTypeList.from_db_models(ontology_classes)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||||
)
|
)
|
||||||
return ontology_types
|
return ontology_types
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
|
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
|
||||||
@@ -598,8 +640,8 @@ class MemoryConfigService:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_workspace_default_config(
|
def get_workspace_default_config(
|
||||||
self,
|
self,
|
||||||
workspace_id: UUID
|
workspace_id: UUID
|
||||||
) -> Optional["MemoryConfigModel"]:
|
) -> Optional["MemoryConfigModel"]:
|
||||||
"""Get workspace default memory config.
|
"""Get workspace default memory config.
|
||||||
|
|
||||||
@@ -613,19 +655,19 @@ class MemoryConfigService:
|
|||||||
Optional[MemoryConfigModel]: Default config or None if no configs exist
|
Optional[MemoryConfigModel]: Default config or None if no configs exist
|
||||||
"""
|
"""
|
||||||
config = MemoryConfigRepository.get_workspace_default(self.db, workspace_id)
|
config = MemoryConfigRepository.get_workspace_default(self.db, workspace_id)
|
||||||
|
|
||||||
if not config:
|
if not config:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No active memory config found for workspace fallback",
|
"No active memory config found for workspace fallback",
|
||||||
extra={"workspace_id": str(workspace_id)}
|
extra={"workspace_id": str(workspace_id)}
|
||||||
)
|
)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def get_config_with_fallback(
|
def get_config_with_fallback(
|
||||||
self,
|
self,
|
||||||
memory_config_id: Optional[UUID],
|
memory_config_id: Optional[UUID],
|
||||||
workspace_id: UUID
|
workspace_id: UUID
|
||||||
) -> Optional["MemoryConfigModel"]:
|
) -> Optional["MemoryConfigModel"]:
|
||||||
"""Get memory config with fallback to workspace default.
|
"""Get memory config with fallback to workspace default.
|
||||||
|
|
||||||
@@ -644,13 +686,13 @@ class MemoryConfigService:
|
|||||||
"No memory config ID provided, using workspace default",
|
"No memory config ID provided, using workspace default",
|
||||||
extra={"workspace_id": str(workspace_id)}
|
extra={"workspace_id": str(workspace_id)}
|
||||||
)
|
)
|
||||||
|
|
||||||
config = MemoryConfigRepository.get_with_fallback(
|
config = MemoryConfigRepository.get_with_fallback(
|
||||||
self.db,
|
self.db,
|
||||||
memory_config_id,
|
memory_config_id,
|
||||||
workspace_id
|
workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not config and memory_config_id:
|
if not config and memory_config_id:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Memory config not found, falling back to workspace default",
|
"Memory config not found, falling back to workspace default",
|
||||||
@@ -659,13 +701,13 @@ class MemoryConfigService:
|
|||||||
"workspace_id": str(workspace_id)
|
"workspace_id": str(workspace_id)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def delete_config(
|
def delete_config(
|
||||||
self,
|
self,
|
||||||
config_id: UUID | int,
|
config_id: UUID | int,
|
||||||
force: bool = False
|
force: bool = False
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Delete memory config with protection against in-use configs.
|
"""Delete memory config with protection against in-use configs.
|
||||||
|
|
||||||
@@ -687,7 +729,7 @@ class MemoryConfigService:
|
|||||||
from app.core.exceptions import ResourceNotFoundException
|
from app.core.exceptions import ResourceNotFoundException
|
||||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
|
||||||
# 处理旧格式 int 类型的 config_id
|
# 处理旧格式 int 类型的 config_id
|
||||||
if isinstance(config_id, int):
|
if isinstance(config_id, int):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -699,11 +741,11 @@ class MemoryConfigService:
|
|||||||
"message": "旧格式配置ID不支持删除操作,请使用新版配置",
|
"message": "旧格式配置ID不支持删除操作,请使用新版配置",
|
||||||
"legacy_int_id": config_id
|
"legacy_int_id": config_id
|
||||||
}
|
}
|
||||||
|
|
||||||
config = self.db.get(MemoryConfigModel, config_id)
|
config = self.db.get(MemoryConfigModel, config_id)
|
||||||
if not config:
|
if not config:
|
||||||
raise ResourceNotFoundException("MemoryConfig", str(config_id))
|
raise ResourceNotFoundException("MemoryConfig", str(config_id))
|
||||||
|
|
||||||
# Check if this is the default config - default configs cannot be deleted
|
# Check if this is the default config - default configs cannot be deleted
|
||||||
if config.is_default:
|
if config.is_default:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -715,11 +757,11 @@ class MemoryConfigService:
|
|||||||
"message": "默认配置不允许删除",
|
"message": "默认配置不允许删除",
|
||||||
"is_default": True
|
"is_default": True
|
||||||
}
|
}
|
||||||
|
|
||||||
# Use repository to count connected end users
|
# Use repository to count connected end users
|
||||||
end_user_repo = EndUserRepository(self.db)
|
end_user_repo = EndUserRepository(self.db)
|
||||||
connected_count = end_user_repo.count_by_memory_config_id(config_id)
|
connected_count = end_user_repo.count_by_memory_config_id(config_id)
|
||||||
|
|
||||||
if connected_count > 0 and not force:
|
if connected_count > 0 and not force:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Attempted to delete memory config with connected end users",
|
"Attempted to delete memory config with connected end users",
|
||||||
@@ -728,18 +770,18 @@ class MemoryConfigService:
|
|||||||
"connected_count": connected_count
|
"connected_count": connected_count
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "warning",
|
"status": "warning",
|
||||||
"message": f"无法删除记忆配置:{connected_count} 个终端用户正在使用此配置",
|
"message": f"无法删除记忆配置:{connected_count} 个终端用户正在使用此配置",
|
||||||
"connected_count": connected_count,
|
"connected_count": connected_count,
|
||||||
"force_required": True
|
"force_required": True
|
||||||
}
|
}
|
||||||
|
|
||||||
# Force delete: use repository to clear end user references first
|
# Force delete: use repository to clear end user references first
|
||||||
if connected_count > 0 and force:
|
if connected_count > 0 and force:
|
||||||
cleared_count = end_user_repo.clear_memory_config_id(config_id)
|
cleared_count = end_user_repo.clear_memory_config_id(config_id)
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Force deleting memory config, clearing end user references",
|
"Force deleting memory config, clearing end user references",
|
||||||
extra={
|
extra={
|
||||||
@@ -747,11 +789,11 @@ class MemoryConfigService:
|
|||||||
"cleared_end_users": cleared_count
|
"cleared_end_users": cleared_count
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.db.delete(config)
|
self.db.delete(config)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Memory config deleted",
|
"Memory config deleted",
|
||||||
extra={
|
extra={
|
||||||
@@ -760,16 +802,16 @@ class MemoryConfigService:
|
|||||||
"affected_users": connected_count
|
"affected_users": connected_count
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": "记忆配置删除成功",
|
"message": "记忆配置删除成功",
|
||||||
"affected_users": connected_count
|
"affected_users": connected_count
|
||||||
}
|
}
|
||||||
|
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
self.db.rollback()
|
self.db.rollback()
|
||||||
|
|
||||||
# Handle foreign key violation gracefully
|
# Handle foreign key violation gracefully
|
||||||
error_str = str(e.orig) if e.orig else str(e)
|
error_str = str(e.orig) if e.orig else str(e)
|
||||||
if "ForeignKeyViolation" in error_str or "foreign key constraint" in error_str.lower():
|
if "ForeignKeyViolation" in error_str or "foreign key constraint" in error_str.lower():
|
||||||
@@ -785,7 +827,7 @@ class MemoryConfigService:
|
|||||||
"message": "无法删除记忆配置:仍有终端用户引用此配置,请使用 force=true 强制删除",
|
"message": "无法删除记忆配置:仍有终端用户引用此配置,请使用 force=true 强制删除",
|
||||||
"force_required": True
|
"force_required": True
|
||||||
}
|
}
|
||||||
|
|
||||||
# Re-raise other integrity errors
|
# Re-raise other integrity errors
|
||||||
logger.error(
|
logger.error(
|
||||||
"Delete failed due to integrity error",
|
"Delete failed due to integrity error",
|
||||||
@@ -800,9 +842,9 @@ class MemoryConfigService:
|
|||||||
# ==================== 记忆配置提取方法 ====================
|
# ==================== 记忆配置提取方法 ====================
|
||||||
|
|
||||||
def extract_memory_config_id(
|
def extract_memory_config_id(
|
||||||
self,
|
self,
|
||||||
app_type: str,
|
app_type: str,
|
||||||
config: dict
|
config: dict
|
||||||
) -> tuple[Optional[uuid.UUID], bool]:
|
) -> tuple[Optional[uuid.UUID], bool]:
|
||||||
"""从发布配置中提取 memory_config_id(根据应用类型分发)
|
"""从发布配置中提取 memory_config_id(根据应用类型分发)
|
||||||
|
|
||||||
@@ -828,8 +870,8 @@ class MemoryConfigService:
|
|||||||
return None, False
|
return None, False
|
||||||
|
|
||||||
def _extract_memory_config_id_from_agent(
|
def _extract_memory_config_id_from_agent(
|
||||||
self,
|
self,
|
||||||
config: dict
|
config: dict
|
||||||
) -> tuple[Optional[uuid.UUID], bool]:
|
) -> tuple[Optional[uuid.UUID], bool]:
|
||||||
"""从 Agent 应用配置中提取 memory_config_id
|
"""从 Agent 应用配置中提取 memory_config_id
|
||||||
|
|
||||||
@@ -888,8 +930,8 @@ class MemoryConfigService:
|
|||||||
return None, False
|
return None, False
|
||||||
|
|
||||||
def _extract_memory_config_id_from_workflow(
|
def _extract_memory_config_id_from_workflow(
|
||||||
self,
|
self,
|
||||||
config: dict
|
config: dict
|
||||||
) -> tuple[Optional[uuid.UUID], bool]:
|
) -> tuple[Optional[uuid.UUID], bool]:
|
||||||
"""从 Workflow 应用配置中提取 memory_config_id
|
"""从 Workflow 应用配置中提取 memory_config_id
|
||||||
|
|
||||||
@@ -905,14 +947,14 @@ class MemoryConfigService:
|
|||||||
- is_legacy_int: 是否检测到旧格式 int 数据
|
- is_legacy_int: 是否检测到旧格式 int 数据
|
||||||
"""
|
"""
|
||||||
nodes = config.get("nodes", [])
|
nodes = config.get("nodes", [])
|
||||||
|
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
node_type = node.get("type", "")
|
node_type = node.get("type", "")
|
||||||
|
|
||||||
# 检查是否为记忆节点 (support both formats: memory-read/memory-write and MemoryRead/MemoryWrite)
|
# 检查是否为记忆节点 (support both formats: memory-read/memory-write and MemoryRead/MemoryWrite)
|
||||||
if node_type.lower() in ["memoryread", "memorywrite", "memory-read", "memory-write"]:
|
if node_type.lower() in ["memoryread", "memorywrite", "memory-read", "memory-write"]:
|
||||||
config_id = node.get("config", {}).get("config_id")
|
config_id = node.get("config", {}).get("config_id")
|
||||||
|
|
||||||
if config_id:
|
if config_id:
|
||||||
try:
|
try:
|
||||||
# 处理字符串、UUID 和 int(旧数据兼容)三种情况
|
# 处理字符串、UUID 和 int(旧数据兼容)三种情况
|
||||||
@@ -937,6 +979,6 @@ class MemoryConfigService:
|
|||||||
f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, "
|
f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, "
|
||||||
f"node_type={node_type}, error={str(e)}"
|
f"node_type={node_type}, error={str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("工作流配置中未找到记忆节点")
|
logger.debug("工作流配置中未找到记忆节点")
|
||||||
return None, False
|
return None, False
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ from app.core.error_codes import BizCode
|
|||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
from app.models import FileMetadata
|
from app.models import FileMetadata, ModelApiKey, ModelType
|
||||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||||
from app.models.prompt_optimizer_model import RoleType
|
from app.models.prompt_optimizer_model import RoleType
|
||||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||||
from app.schemas import FileType
|
from app.schemas import FileType, FileInput
|
||||||
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
from app.schemas.memory_perceptual_schema import (
|
from app.schemas.memory_perceptual_schema import (
|
||||||
PerceptualQuerySchema,
|
PerceptualQuerySchema,
|
||||||
PerceptualTimelineResponse,
|
PerceptualTimelineResponse,
|
||||||
@@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import (
|
|||||||
AudioModal, Content, VideoModal, TextModal
|
AudioModal, Content, VideoModal, TextModal
|
||||||
)
|
)
|
||||||
from app.schemas.model_schema import ModelInfo
|
from app.schemas.model_schema import ModelInfo
|
||||||
|
from app.services.model_service import ModelApiKeyService
|
||||||
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
|
|
||||||
@@ -195,21 +198,58 @@ class MemoryPerceptualService:
|
|||||||
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
||||||
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
||||||
|
|
||||||
|
def _get_mutlimodal_client(
        self,
        file_type: FileType,
        config: MemoryConfig
) -> tuple[RedBearLLM | None, ModelApiKey | None]:
    """Build the LLM client used to process a file of the given type.

    The model ID is taken from ``config`` according to ``file_type``:
    audio -> ``audio_model_id``, video -> ``video_model_id``,
    document -> ``llm_model_id``, image -> ``vision_model_id``.
    The matching API key record is then looked up through
    ``ModelApiKeyService.get_available_api_key``.

    Args:
        file_type: Type of the file that is going to be processed.
        config: Memory configuration holding the per-modality model IDs.

    Returns:
        A ``(llm, api_key_record)`` tuple. ``llm`` is ``None`` whenever the
        lookup result is falsy; both elements are ``None`` when
        ``file_type`` matches none of the four supported types, because no
        lookup is performed in that case.
    """
    # Ordered (file type, MemoryConfig attribute) pairs; compared with ``==``
    # one after another, first match wins.
    model_id_sources = (
        (FileType.AUDIO, "audio_model_id"),
        (FileType.VIDEO, "video_model_id"),
        (FileType.DOCUMENT, "llm_model_id"),
        (FileType.IMAGE, "vision_model_id"),
    )

    api_key_record = None
    for supported_type, model_id_attr in model_id_sources:
        if file_type == supported_type:
            api_key_record = ModelApiKeyService.get_available_api_key(
                self.db,
                getattr(config, model_id_attr)
            )
            break

    # No usable key record: hand back the (falsy) lookup result without a client.
    if not api_key_record:
        return None, api_key_record

    client_settings = RedBearModelConfig(
        model_name=api_key_record.model_name,
        provider=api_key_record.provider,
        api_key=api_key_record.api_key,
        base_url=api_key_record.api_base,
        is_omni=api_key_record.is_omni
    )
    return RedBearLLM(client_settings), api_key_record
|
||||||
|
|
||||||
async def generate_perceptual_memory(
|
async def generate_perceptual_memory(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
model_config: ModelInfo,
|
memory_config: MemoryConfig,
|
||||||
file_type: str,
|
file: FileInput
|
||||||
file_url: str,
|
|
||||||
file_message: dict,
|
|
||||||
):
|
):
|
||||||
memories = self.repository.get_by_url(file_url)
|
memories = self.repository.get_by_url(file.url)
|
||||||
if memories:
|
if memories:
|
||||||
business_logger.info(f"Perceptual memory already exists: {file_url}")
|
business_logger.info(f"Perceptual memory already exists: {file.url}")
|
||||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||||
memory_cache = memories[0]
|
memory_cache = memories[0]
|
||||||
self.repository.create_perceptual_memory(
|
memory = self.repository.create_perceptual_memory(
|
||||||
end_user_id=uuid.UUID(end_user_id),
|
end_user_id=uuid.UUID(end_user_id),
|
||||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||||
file_path=memory_cache.file_path,
|
file_path=memory_cache.file_path,
|
||||||
@@ -219,20 +259,31 @@ class MemoryPerceptualService:
|
|||||||
meta_data=memory_cache.meta_data
|
meta_data=memory_cache.meta_data
|
||||||
)
|
)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
return memory
|
||||||
return
|
else:
|
||||||
llm = RedBearLLM(RedBearModelConfig(
|
for memory in memories:
|
||||||
|
if memory.end_user_id == uuid.UUID(end_user_id):
|
||||||
|
return memory
|
||||||
|
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
||||||
|
multimodel_service = MultimodalService(self.db, ModelInfo(
|
||||||
model_name=model_config.model_name,
|
model_name=model_config.model_name,
|
||||||
provider=model_config.provider,
|
provider=model_config.provider,
|
||||||
api_key=model_config.api_key,
|
api_key=model_config.api_key,
|
||||||
base_url=model_config.api_base,
|
api_base=model_config.api_base,
|
||||||
is_omni=model_config.is_omni
|
is_omni=model_config.is_omni,
|
||||||
), type=model_config.model_type)
|
capability=model_config.capability,
|
||||||
|
model_type=ModelType.LLM
|
||||||
|
))
|
||||||
|
file_message = await multimodel_service.process_files(
|
||||||
|
files=[file]
|
||||||
|
)
|
||||||
|
if file_message:
|
||||||
|
file_message = file_message[0]
|
||||||
try:
|
try:
|
||||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||||
opt_system_prompt = f.read()
|
opt_system_prompt = f.read()
|
||||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
|
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||||
messages = [
|
messages = [
|
||||||
@@ -242,8 +293,22 @@ class MemoryPerceptualService:
|
|||||||
]}
|
]}
|
||||||
]
|
]
|
||||||
result = await llm.ainvoke(messages)
|
result = await llm.ainvoke(messages)
|
||||||
content = json_repair.repair_json(result.content, return_objects=True)
|
content = result.content
|
||||||
path = urlparse(file_url).path
|
final_output = ""
|
||||||
|
if isinstance(content, list):
|
||||||
|
for msg in content:
|
||||||
|
if isinstance(msg, dict):
|
||||||
|
final_output += msg.get("text", "")
|
||||||
|
elif isinstance(msg, str):
|
||||||
|
final_output += msg
|
||||||
|
elif isinstance(content, dict):
|
||||||
|
final_output += content.get("text", "")
|
||||||
|
elif isinstance(content, str):
|
||||||
|
final_output = content
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexcept Model Output Type: {result.content}")
|
||||||
|
content = json_repair.repair_json(final_output, return_objects=True)
|
||||||
|
path = urlparse(file.url).path
|
||||||
filename = os.path.basename(path)
|
filename = os.path.basename(path)
|
||||||
filename = unquote(filename)
|
filename = unquote(filename)
|
||||||
file_ext = os.path.splitext(filename)[1]
|
file_ext = os.path.splitext(filename)[1]
|
||||||
@@ -260,13 +325,13 @@ class MemoryPerceptualService:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
business_logger.debug(f"Remote file, file_id={filename}")
|
business_logger.debug(f"Remote file, file_id={filename}")
|
||||||
if not file_ext:
|
if not file_ext:
|
||||||
if file_type == FileType.AUDIO:
|
if file.type == FileType.AUDIO:
|
||||||
file_ext = ".mp3"
|
file_ext = ".mp3"
|
||||||
elif file_type == FileType.VIDEO:
|
elif file.type == FileType.VIDEO:
|
||||||
file_ext = ".mp4"
|
file_ext = ".mp4"
|
||||||
elif file_type == FileType.DOCUMENT:
|
elif file.type == FileType.DOCUMENT:
|
||||||
file_ext = ".txt"
|
file_ext = ".txt"
|
||||||
elif file_type == FileType.IMAGE:
|
elif file.type == FileType.IMAGE:
|
||||||
file_ext = ".jpg"
|
file_ext = ".jpg"
|
||||||
filename += file_ext
|
filename += file_ext
|
||||||
file_content = {
|
file_content = {
|
||||||
@@ -274,11 +339,11 @@ class MemoryPerceptualService:
|
|||||||
"topic": content.get("topic"),
|
"topic": content.get("topic"),
|
||||||
"domain": content.get("domain")
|
"domain": content.get("domain")
|
||||||
}
|
}
|
||||||
if file_type in [FileType.IMAGE, FileType.VIDEO]:
|
if file.type in [FileType.IMAGE, FileType.VIDEO]:
|
||||||
file_modalities = {
|
file_modalities = {
|
||||||
"scene": content.get("scene", [])
|
"scene": content.get("scene", [])
|
||||||
}
|
}
|
||||||
elif file_type in [FileType.DOCUMENT]:
|
elif file.type in [FileType.DOCUMENT]:
|
||||||
file_modalities = {
|
file_modalities = {
|
||||||
"section_count": content.get("section_count", 0),
|
"section_count": content.get("section_count", 0),
|
||||||
"title": content.get("title", ""),
|
"title": content.get("title", ""),
|
||||||
@@ -288,10 +353,10 @@ class MemoryPerceptualService:
|
|||||||
file_modalities = {
|
file_modalities = {
|
||||||
"speaker_count": content.get("speaker_count", 0)
|
"speaker_count": content.get("speaker_count", 0)
|
||||||
}
|
}
|
||||||
self.repository.create_perceptual_memory(
|
memory = self.repository.create_perceptual_memory(
|
||||||
end_user_id=uuid.UUID(end_user_id),
|
end_user_id=uuid.UUID(end_user_id),
|
||||||
perceptual_type=PerceptualType.trans_from_file_type(file_type),
|
perceptual_type=PerceptualType.trans_from_file_type(file.type),
|
||||||
file_path=file_url,
|
file_path=file.url,
|
||||||
file_name=filename,
|
file_name=filename,
|
||||||
file_ext=file_ext,
|
file_ext=file_ext,
|
||||||
summary=content.get('summary', ""),
|
summary=content.get('summary', ""),
|
||||||
@@ -301,3 +366,4 @@ class MemoryPerceptualService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
return memory
|
||||||
|
|||||||
@@ -9,14 +9,12 @@
|
|||||||
- OpenAI: 支持 URL 和 base64 格式
|
- OpenAI: 支持 URL 和 base64 格式
|
||||||
"""
|
"""
|
||||||
import base64
|
import base64
|
||||||
|
import csv
|
||||||
import io
|
import io
|
||||||
import uuid
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
import csv
|
|
||||||
import json
|
|
||||||
|
|
||||||
import PyPDF2
|
import PyPDF2
|
||||||
import httpx
|
import httpx
|
||||||
import magic
|
import magic
|
||||||
@@ -33,7 +31,6 @@ from app.models.file_metadata_model import FileMetadata
|
|||||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||||
from app.schemas.model_schema import ModelInfo
|
from app.schemas.model_schema import ModelInfo
|
||||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||||
from app.tasks import write_perceptual_memory
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -342,15 +339,12 @@ class MultimodalService:
|
|||||||
|
|
||||||
async def process_files(
|
async def process_files(
|
||||||
self,
|
self,
|
||||||
end_user_id: uuid.UUID | str,
|
|
||||||
files: Optional[List[FileInput]],
|
files: Optional[List[FileInput]],
|
||||||
|
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
处理文件列表,返回 LLM 可用的格式
|
处理文件列表,返回 LLM 可用的格式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 用户ID
|
|
||||||
files: 文件输入列表
|
files: 文件输入列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -358,8 +352,6 @@ class MultimodalService:
|
|||||||
"""
|
"""
|
||||||
if not files:
|
if not files:
|
||||||
return []
|
return []
|
||||||
if isinstance(end_user_id, uuid.UUID):
|
|
||||||
end_user_id = str(end_user_id)
|
|
||||||
|
|
||||||
# 获取对应的策略
|
# 获取对应的策略
|
||||||
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
||||||
@@ -380,23 +372,15 @@ class MultimodalService:
|
|||||||
if file.type == FileType.IMAGE and "vision" in self.capability:
|
if file.type == FileType.IMAGE and "vision" in self.capability:
|
||||||
is_support, content = await self._process_image(file, strategy)
|
is_support, content = await self._process_image(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if is_support:
|
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
|
||||||
elif file.type == FileType.DOCUMENT:
|
elif file.type == FileType.DOCUMENT:
|
||||||
is_support, content = await self._process_document(file, strategy)
|
is_support, content = await self._process_document(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if is_support:
|
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
|
||||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||||
is_support, content = await self._process_audio(file, strategy)
|
is_support, content = await self._process_audio(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if is_support:
|
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
|
||||||
elif file.type == FileType.VIDEO and "video" in self.capability:
|
elif file.type == FileType.VIDEO and "video" in self.capability:
|
||||||
is_support, content = await self._process_video(file, strategy)
|
is_support, content = await self._process_video(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if is_support:
|
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"不支持的文件类型: {file.type}")
|
logger.warning(f"不支持的文件类型: {file.type}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -418,17 +402,6 @@ class MultimodalService:
|
|||||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def write_perceptual_memory(
|
|
||||||
self,
|
|
||||||
end_user_id: str,
|
|
||||||
file_type: str,
|
|
||||||
file_url: str,
|
|
||||||
file_message: dict
|
|
||||||
):
|
|
||||||
"""写入感知记忆"""
|
|
||||||
if end_user_id and self.api_config:
|
|
||||||
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
|
|
||||||
|
|
||||||
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
处理图片文件
|
处理图片文件
|
||||||
|
|||||||
@@ -1080,12 +1080,14 @@ def write_message_task(
|
|||||||
config_id: str | int,
|
config_id: str | int,
|
||||||
storage_type: str,
|
storage_type: str,
|
||||||
user_rag_memory_id: str,
|
user_rag_memory_id: str,
|
||||||
|
file_messages: list[dict] | None,
|
||||||
language: str = "zh"
|
language: str = "zh"
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Celery task to process a write message via MemoryAgentService.
|
"""Celery task to process a write message via MemoryAgentService.
|
||||||
Args:
|
Args:
|
||||||
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
||||||
message: Message to write
|
message: Message to write
|
||||||
|
file_messages: Files to write
|
||||||
config_id: Configuration ID (can be UUID string, integer, or config_id_old)
|
config_id: Configuration ID (can be UUID string, integer, or config_id_old)
|
||||||
storage_type: Storage type (neo4j or rag)
|
storage_type: Storage type (neo4j or rag)
|
||||||
user_rag_memory_id: User RAG memory ID
|
user_rag_memory_id: User RAG memory ID
|
||||||
@@ -1097,6 +1099,8 @@ def write_message_task(
|
|||||||
Raises:
|
Raises:
|
||||||
Exception on failure
|
Exception on failure
|
||||||
"""
|
"""
|
||||||
|
if file_messages is None:
|
||||||
|
file_messages = []
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
||||||
@@ -1142,7 +1146,7 @@ def write_message_task(
|
|||||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
||||||
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||||
service = MemoryAgentService()
|
service = MemoryAgentService()
|
||||||
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
|
result = await service.write_memory(end_user_id, message, file_messages, actual_config_id, db, storage_type,
|
||||||
user_rag_memory_id, language)
|
user_rag_memory_id, language)
|
||||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||||
return result
|
return result
|
||||||
|
|||||||
Reference in New Issue
Block a user