feat(memory, model): update multi-modal memory write and model list API

- Adjust multi-modal memory write behavior for text and visual data
- Mask API keys in model list response to prevent exposure
- Add capability-based filtering to the model list API
This commit is contained in:
Eternity
2026-03-24 13:54:15 +08:00
parent 2ff81ba101
commit 6bba574ca6
21 changed files with 389 additions and 401 deletions

View File

@@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import (
StatementNode,
ExtractedEntityNode,
EntityEntityEdge,
PerceptualNode,
PerceptualEdge,
)
import logging
logger = logging.getLogger(__name__)
async def save_entities_and_relationships(
entity_nodes: List[ExtractedEntityNode],
entity_entity_edges: List[EntityEntityEdge],
connector: Neo4jConnector
entity_nodes: List[ExtractedEntityNode],
entity_entity_edges: List[EntityEntityEdge],
connector: Neo4jConnector
):
"""Save entities and their relationships using graph models"""
all_entities = [entity.model_dump() for entity in entity_nodes]
@@ -73,8 +78,8 @@ async def save_entities_and_relationships(
async def save_chunk_nodes(
chunk_nodes: List[ChunkNode],
connector: Neo4jConnector
chunk_nodes: List[ChunkNode],
connector: Neo4jConnector
):
"""Save chunk nodes using graph models"""
if not chunk_nodes:
@@ -89,8 +94,8 @@ async def save_chunk_nodes(
async def save_statement_chunk_edges(
statement_chunk_edges: List[StatementChunkEdge],
connector: Neo4jConnector
statement_chunk_edges: List[StatementChunkEdge],
connector: Neo4jConnector
):
"""Save statement-chunk edges using graph models"""
if not statement_chunk_edges:
@@ -118,8 +123,8 @@ async def save_statement_chunk_edges(
async def save_statement_entity_edges(
statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector
statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector
):
"""Save statement-entity edges using graph models"""
if not statement_entity_edges:
@@ -142,7 +147,7 @@ async def save_statement_entity_edges(
if all_se_edges:
try:
await connector.execute_query(
STATEMENT_ENTITY_EDGE_SAVE,
STATEMENT_ENTITY_EDGE_SAVE,
relationships=all_se_edges
)
except Exception:
@@ -154,9 +159,11 @@ async def save_dialog_and_statements_to_neo4j(
chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode],
perceptual_nodes: List[PerceptualNode],
entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge],
perceptual_edges: List[PerceptualEdge],
connector: Neo4jConnector,
) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
@@ -169,9 +176,11 @@ async def save_dialog_and_statements_to_neo4j(
chunk_nodes: List of ChunkNode objects to save
statement_nodes: List of StatementNode objects to save
entity_nodes: List of ExtractedEntityNode objects to save
perceptual_nodes: List of PerceptualNode objects to save
entity_edges: List of EntityEntityEdge objects to save
statement_chunk_edges: List of StatementChunkEdge objects to save
statement_entity_edges: List of StatementEntityEdge objects to save
perceptual_edges: List of PerceptualEdge objects to save
connector: Neo4j connector instance
Returns:
@@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j(
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
dialogue_uuids = [record["uuid"] async for record in result]
results['dialogues'] = dialogue_uuids
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
# 2. Save all chunk nodes in batch
if chunk_nodes:
@@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j(
results['chunks'] = chunk_uuids
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
if perceptual_nodes:
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE
perceptual_data = [node.model_dump() for node in perceptual_nodes]
result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data)
perceptual_uuids = [record["uuid"] async for record in result]
results["perceptuals"] = perceptual_uuids
logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j")
# 3. Save all statement nodes in batch
if statement_nodes:
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
@@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j(
results['statement_entity_edges'] = se_uuids
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
if perceptual_edges:
from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE
perceptual_edge_data = []
for edge in perceptual_edges:
print(edge.source, edge.target)
perceptual_edge_data.append({
"perceptual_id": edge.source,
"chunk_id": edge.target,
"end_user_id": edge.end_user_id,
"created_at": edge.created_at.isoformat() if edge.created_at else None,
})
result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data)
perceptual_edges_uuids = [record["uuid"] async for record in result]
results['perceptual_chunk_edges'] = perceptual_edges_uuids
logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j")
return results
try:
@@ -304,9 +337,9 @@ async def save_dialog_and_statements_to_neo4j(
def schedule_clustering_after_write(
entity_nodes: List,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
entity_nodes: List,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
) -> None:
"""
写入 Neo4j 成功后,调度后台聚类任务。
@@ -325,14 +358,15 @@ def schedule_clustering_after_write(
end_user_id = entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in entity_nodes]
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id,
embedding_model_id=embedding_model_id))
async def _trigger_clustering(
new_entity_ids: List[str],
end_user_id: str,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
new_entity_ids: List[str],
end_user_id: str,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
) -> None:
"""
聚类触发函数,自动判断全量初始化还是增量更新。