diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index 6204a745..71fd41ad 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -42,6 +42,7 @@ def get_model_strategies(): @router.get("", response_model=ApiResponse) def get_model_list( type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"), + capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"), provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"), is_active: Optional[bool] = Query(None, description="激活状态筛选"), is_public: Optional[bool] = Query(None, description="公开状态筛选"), @@ -74,10 +75,21 @@ def get_model_list( unique_flat_type = list(dict.fromkeys(flat_type)) type_list = [ModelType(t.lower()) for t in unique_flat_type] + capability_list = [] + if capability is not None: + flat_capability = [] + for item in capability: + split_items = [c.strip() for c in item.split(', ') if c.strip()] + flat_capability.extend(split_items) + + unique_flat_capability = list(dict.fromkeys(flat_capability)) + capability_list = unique_flat_capability + api_logger.error(f"获取模型type_list: {type_list}") query = model_schema.ModelConfigQuery( type=type_list, provider=provider, + capability=capability_list, is_active=is_active, is_public=is_public, search=search, diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index be796ff9..3ce1df6e 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -5,7 +5,7 @@ from typing import Optional import datetime from sqlalchemy.orm import Session -from fastapi import APIRouter, Depends,Header +from fastapi import APIRouter, Depends, Header from app.db import get_db from app.core.language_utils import get_language_from_header @@ -19,7 +19,7 @@ from app.services.user_memory_service import ( analytics_graph_data, analytics_community_graph_data, ) -from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction +from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest from app.repositories.workspace_repository import WorkspaceRepository @@ -45,9 +45,9 @@ router = APIRouter( @router.get("/analytics/memory_insight/report", response_model=ApiResponse) async def get_memory_insight_report_api( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 获取缓存的记忆洞察报告 @@ -73,10 +73,10 @@ async def get_memory_insight_report_api( @router.get("/analytics/user_summary", response_model=ApiResponse) async def get_user_summary_api( - end_user_id: str, - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 获取缓存的用户摘要 @@ -90,7 +90,7 @@ async def get_user_summary_api( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + workspace_id = current_user.current_workspace_id workspace_repo = WorkspaceRepository(db) workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) @@ -102,7 +102,7 @@ async def get_user_summary_api( api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}") try: # 调用服务层获取缓存数据 - result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language) + result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language) if result["is_cached"]: api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") @@ -117,10 +117,10 @@ async def get_user_summary_api( @router.post("/analytics/generate_cache", response_model=ApiResponse) async def generate_cache_api( - request: GenerateCacheRequest, - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + request: GenerateCacheRequest, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 手动触发缓存生成 @@ -134,7 +134,7 @@ async def generate_cache_api( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 @@ -155,10 +155,12 @@ async def generate_cache_api( api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}") # 生成记忆洞察 - insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language) + insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, + language=language) # 生成用户摘要 - summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language) + summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, + language=language) # 构建响应 result = { @@ -209,9 +211,9 @@ async def generate_cache_api( @router.get("/analytics/node_statistics", response_model=ApiResponse) async def get_node_statistics_api( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id @@ -220,7 +222,8 @@ async def get_node_statistics_api( api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") + api_logger.info( + f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") try: # 调用新的记忆类型统计函数 @@ -228,21 +231,23 @@ async def get_node_statistics_api( # 计算总数用于日志 total_count = sum(item["count"] for item in result) - api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") + api_logger.info( + f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") return success(data=result, msg="查询成功") except Exception as e: api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e)) + @router.get("/analytics/graph_data", response_model=ApiResponse) async def get_graph_data_api( - end_user_id: str, - node_types: Optional[str] = None, - limit: int = 100, - depth: int = 1, - center_node_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + node_types: Optional[str] = None, + limit: int = 100, + depth: int = 1, + center_node_id: Optional[str] = None, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id @@ -298,9 +303,9 @@ async def get_graph_data_api( @router.get("/analytics/community_graph", response_model=ApiResponse) async def get_community_graph_data_api( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id @@ -334,9 +339,9 @@ async def get_community_graph_data_api( @router.get("/read_end_user/profile", response_model=ApiResponse) async def get_end_user_profile( - end_user_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: workspace_id = current_user.current_workspace_id workspace_repo = WorkspaceRepository(db) @@ -385,9 +390,9 @@ async def get_end_user_profile( @router.post("/updated_end_user/profile", response_model=ApiResponse) async def update_end_user_profile( - profile_update: EndUserProfileUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + profile_update: EndUserProfileUpdate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ) -> dict: """ 更新终端用户的基本信息 @@ -417,7 +422,7 @@ async def update_end_user_profile( else: error_msg = result["error"] api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}") - + # 根据错误类型映射到合适的业务错误码 if error_msg == "终端用户不存在": return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg) @@ -427,15 +432,18 @@ async def update_end_user_profile( # 只有未预期的错误才使用 INTERNAL_ERROR return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg) + @router.get("/memory_space/timeline_memories", response_model=ApiResponse) -async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - ): +async def memory_space_timeline_of_shared_memories( + id: str, label: str, + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): # 使用集中化的语言校验 language = get_language_from_header(language_type) - - workspace_id=current_user.current_workspace_id + + workspace_id = current_user.current_workspace_id workspace_repo = WorkspaceRepository(db) workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) @@ -447,11 +455,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_ timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language) return success(data=timeline_memories_result, msg="共同记忆时间线") + + @router.get("/memory_space/relationship_evolution", response_model=ApiResponse) async def memory_space_relationship_evolution(id: str, label: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - ): + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + ): try: api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}") diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index 3b06defe..4c667061 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -11,7 +11,7 @@ async def get_chunked_dialogs( chunker_strategy: str = "RecursiveChunker", end_user_id: str = "group_1", messages: list = None, - ref_id: str = "wyl_20251027", + ref_id: str = "", config_id: str = None ) -> List[DialogData]: """Generate chunks from structured messages using the specified chunker strategy. @@ -40,12 +40,13 @@ async def get_chunked_dialogs( role = msg['role'] content = msg['content'] + files = msg.get("file_content", []) if role not in ['user', 'assistant']: raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") if content.strip(): - conversation_messages.append(ConversationMessage(role=role, msg=content.strip())) + conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files)) if not conversation_messages: raise ValueError("Message list cannot be empty after filtering") diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 147a0316..413f54da 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -5,8 +5,8 @@ This module provides the main write function for executing the knowledge extract pipeline. Only MemoryConfig is needed - clients are constructed internally. """ import asyncio -import uuid import time +import uuid from datetime import datetime from dotenv import load_dotenv @@ -19,10 +19,8 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.mem from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context -from app.models import MemoryPerceptualModel from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges -from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \ - add_perceptual_dialogue_edges +from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig @@ -36,7 +34,6 @@ async def write( end_user_id: str, memory_config: MemoryConfig, messages: list, - file_content: list[MemoryPerceptualModel], ref_id: str = "", language: str = "zh", ) -> None: @@ -47,7 +44,6 @@ async def write( end_user_id: Group identifier memory_config: MemoryConfig object containing all configuration messages: Structured message list [{"role": "user", "content": "..."}, ...] - file_content: mutilmodal message list ref_id: Reference ID, defaults to "" language: 语言类型 ("zh" 中文, "en" 英文),默认中文 """ @@ -142,9 +138,11 @@ async def write( all_chunk_nodes, all_statement_nodes, all_entity_nodes, + all_perceptual_nodes, all_statement_chunk_edges, all_statement_entity_edges, all_entity_entity_edges, + all_perceptual_edges, all_dedup_details, ) = await orchestrator.run(chunked_dialogs, is_pilot_run=False) @@ -169,9 +167,11 @@ async def write( chunk_nodes=all_chunk_nodes, statement_nodes=all_statement_nodes, entity_nodes=all_entity_nodes, + perceptual_nodes=all_perceptual_nodes, statement_chunk_edges=all_statement_chunk_edges, statement_entity_edges=all_statement_entity_edges, entity_edges=all_entity_entity_edges, + perceptual_edges=all_perceptual_edges, connector=neo4j_connector, ) if success: @@ -230,34 +230,6 @@ async def write( finally: log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file) - # Step 5: Save perceptual memory to Neo4j - step_start = time.time() - if file_content: - try: - pc_connector = Neo4jConnector() - try: - created_ids = await add_perceptual_nodes( - perceptuals=file_content, - connector=pc_connector, - embedder_client=embedder_client, - ) - # 如果有 ref_id,建立感知记忆与对话的关联 - if ref_id and created_ids: - await add_perceptual_dialogue_edges( - perceptuals=file_content, - dialog_id=ref_id, - connector=pc_connector, - ) - logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j") - finally: - try: - await pc_connector.close() - except Exception: - pass - except Exception as e: - logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True) - log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file) - # Log total pipeline time total_time = time.time() - pipeline_start log_time("TOTAL PIPELINE TIME", total_time, log_file) diff --git a/api/app/core/memory/llm_tools/chunker_client.py b/api/app/core/memory/llm_tools/chunker_client.py index 93a2df82..51d15aab 100644 --- a/api/app/core/memory/llm_tools/chunker_client.py +++ b/api/app/core/memory/llm_tools/chunker_client.py @@ -1,10 +1,10 @@ -from typing import Any, List -import re -import os import asyncio import json -import numpy as np import logging +import os +from typing import Any, List + +import numpy as np # Fix tokenizer parallelism warning os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -246,6 +246,7 @@ class ChunkerClient: "total_sub_chunks": len(sub_chunks), "chunker_strategy": self.chunker_config.chunker_strategy, }, + files=msg.files ) dialogue.chunks.append(chunk) else: @@ -258,6 +259,7 @@ class ChunkerClient: "message_role": msg.role, "chunker_strategy": self.chunker_config.chunker_strategy, }, + files=msg.files ) dialogue.chunks.append(chunk) diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index fb251f1f..1b8c9d52 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -114,7 +114,7 @@ class Edge(BaseModel): end_user_id: str = Field(..., description="The end user ID of the edge.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") - expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.") + expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.") class ChunkEdge(Edge): @@ -175,6 +175,12 @@ class EntityEntityEdge(Edge): return parse_historical_datetime(v) +class PerceptualEdge(Edge): + """Edge connecting perceptual nodes to their source chunks + """ + pass + + class Node(BaseModel): """Base class for all graph nodes in the knowledge graph. @@ -555,19 +561,16 @@ class MemorySummaryNode(Node): ) -class MutlimodalNode(Node): +class PerceptualNode(Node): """Node representing a multimodal message in the knowledge graph. - - Attributes: - dialog_id: ID of the parent dialog - message_id: ID of the message - metadata: Additional message metadata - embedding: Optional embedding vector for the message """ - dialog_id: str = Field(..., description="ID of the parent dialog") - message_id: str = Field(..., description="ID of the message") - summary: str = Field(..., description="The text content of the message") - file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')") - file_path: List[str] = Field(..., description="List of file paths for multimodal content") - metadata: dict = Field(default_factory=dict, description="Additional message metadata") - embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message") + perceptual_type: int + file_path: str + file_name: str + file_ext: str + summary: str + keywords: list[str] + topic: str + domain: str + file_type: str + summary_embedding: list[float] | None diff --git a/api/app/core/memory/models/message_models.py b/api/app/core/memory/models/message_models.py index 2f8660af..66203067 100644 --- a/api/app/core/memory/models/message_models.py +++ b/api/app/core/memory/models/message_models.py @@ -30,6 +30,7 @@ class ConversationMessage(BaseModel): """ role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').") msg: str = Field(..., description="The text content of the message.") + files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True) class TemporalValidityRange(BaseModel): @@ -130,7 +131,8 @@ class Chunk(BaseModel): content: str = Field(..., description="The content of the chunk as a string.") speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).") statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") - chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.") + files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.") + chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") @classmethod diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 6e94a84f..da10c497 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -31,7 +31,9 @@ from app.core.memory.models.graph_models import ( ExtractedEntityNode, StatementChunkEdge, StatementEntityEdge, - StatementNode + StatementNode, + PerceptualEdge, + PerceptualNode ) from app.core.memory.models.message_models import DialogData from app.core.memory.models.ontology_extraction_models import OntologyTypeList @@ -170,9 +172,11 @@ class ExtractionOrchestrator: list[ChunkNode], list[StatementNode], list[ExtractedEntityNode], + list[PerceptualNode], list[StatementChunkEdge], list[StatementEntityEdge], list[EntityEntityEdge], + list[PerceptualEdge], dict ]: """ @@ -259,9 +263,11 @@ class ExtractionOrchestrator: chunk_nodes, statement_nodes, entity_nodes, + perceptual_nodes, statement_chunk_edges, statement_entity_edges, entity_entity_edges, + perceptual_edges ) = await self._create_nodes_and_edges(dialog_data_list) # 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总) @@ -275,7 +281,16 @@ class ExtractionOrchestrator: # 注意:deduplication 消息已在创建节点和边完成后立即发送 - result = await self._run_dedup_and_write_summary( + ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + entity_nodes, + statement_chunk_edges, + statement_entity_edges, + entity_entity_edges, + dialog_data_list, + ) = await self._run_dedup_and_write_summary( dialogue_nodes, chunk_nodes, statement_nodes, @@ -287,7 +302,18 @@ class ExtractionOrchestrator: ) logger.info(f"知识提取流水线运行完成({mode_str})") - return result + return ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + entity_nodes, + perceptual_nodes, + statement_chunk_edges, + statement_entity_edges, + entity_entity_edges, + perceptual_edges, + dialog_data_list, + ) except Exception as e: logger.error(f"知识提取流水线运行失败: {e}", exc_info=True) @@ -1000,9 +1026,11 @@ class ExtractionOrchestrator: List[ChunkNode], List[StatementNode], List[ExtractedEntityNode], + List[PerceptualNode], List[StatementChunkEdge], List[StatementEntityEdge], - List[EntityEntityEdge] + List[EntityEntityEdge], + List[PerceptualEdge] ]: """ 创建图数据库节点和边 @@ -1026,6 +1054,8 @@ class ExtractionOrchestrator: statement_chunk_edges = [] statement_entity_edges = [] entity_entity_edges = [] + perceptual_nodes = [] + perceptual_edges = [] # 用于去重的集合 entity_id_set = set() @@ -1069,6 +1099,46 @@ class ExtractionOrchestrator: metadata=chunk.metadata, ) chunk_nodes.append(chunk_node) + logger.error(f"chunk file: {chunk.files}") + + for p, file_type in chunk.files: + + meta = p.meta_data or {} + content_meta = meta.get("content", {}) + + # 生成 summary embedding(如果有 embedder_client) + summary_embedding = None + if self.embedder_client and p.summary: + try: + summary_embedding = (await self.embedder_client.response([p.summary]))[0] + except Exception as emb_err: + print(f"Failed to embed perceptual summary: {emb_err}") + + perceptual = PerceptualNode( + name=f"Perceptual_{p.id}", + **{ + "id": str(p.id), + "end_user_id": str(p.end_user_id), + "perceptual_type": p.perceptual_type, + "file_path": p.file_path or "", + "file_name": p.file_name or "", + "file_ext": p.file_ext or "", + "summary": p.summary or "", + "keywords": content_meta.get("keywords", []), + "topic": content_meta.get("topic", ""), + "domain": content_meta.get("domain", ""), + "created_at": p.created_time.isoformat() if p.created_time else None, + "file_type": file_type, + "summary_embedding": summary_embedding, + }) + perceptual_nodes.append(perceptual) + perceptual_edges.append(PerceptualEdge( + source=perceptual.id, + target=chunk.id, + end_user_id=dialog_data.end_user_id, + run_id=dialog_data.run_id, + created_at=dialog_data.created_at, + )) # 处理每个陈述句 for statement in chunk.statements: @@ -1248,9 +1318,11 @@ class ExtractionOrchestrator: chunk_nodes, statement_nodes, entity_nodes, + perceptual_nodes, statement_chunk_edges, statement_entity_edges, entity_entity_edges, + perceptual_edges ) async def _run_dedup_and_write_summary( diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index cbdad0fa..a28247e4 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -72,7 +72,6 @@ class MemoryWriteNode(BaseNode): if not end_user_id: raise RuntimeError("End user id is required") messages = [] - multimodal_memories = [] if self.typed_config.message: messages.append({ "role": "user", @@ -104,19 +103,15 @@ class MemoryWriteNode(BaseNode): url=file_instence.value.url, file_type=file_instence.value.origin_file_type ).model_dump()) - multimodal_memories.append({ - "role": message.role, - "files": file_info - }) messages.append({ "role": message.role, - "content": self._render_template(content, variable_pool) + "content": self._render_template(content, variable_pool), + "files": file_info }) write_message_task.delay( end_user_id=end_user_id, message=messages, - file_messages=multimodal_memories, config_id=str(self.typed_config.config_id), storage_type=state["memory_storage_type"], user_rag_memory_id=state["user_rag_memory_id"] diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 23fafcef..44a844d0 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -2,10 +2,11 @@ import datetime import uuid from enum import StrEnum -from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text -from sqlalchemy.dialects.postgresql import UUID, JSON +from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text +from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY from sqlalchemy.orm import relationship from sqlalchemy.sql import func + from app.db import Base diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 6fb41914..e64d19a3 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -408,6 +408,9 @@ class MemoryConfigRepository: "llm_id": db_config.llm_id, "embedding_id": db_config.embedding_id, "rerank_id": db_config.rerank_id, + "vision_id": db_config.vision_id, + "audio_id": db_config.audio_id, + "video_id": db_config.video_id, "enable_llm_dedup_blockwise": db_config.enable_llm_dedup_blockwise, "enable_llm_disambiguation": db_config.enable_llm_disambiguation, "deep_retrieval": db_config.deep_retrieval, diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index f49227d3..fd95c793 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -1,14 +1,15 @@ -from sqlalchemy.orm import Session, joinedload, selectinload -from sqlalchemy import and_, or_, func, desc, select -from typing import List, Optional, Dict, Any, Tuple import uuid +from typing import List, Optional, Dict, Any, Tuple +from sqlalchemy import and_, or_, func, desc +from sqlalchemy.orm import Session, joinedload + +from app.core.logging_config import get_db_logger from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association from app.schemas.model_schema import ( ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate, ModelConfigQuery, ModelConfigQueryNew ) -from app.core.logging_config import get_db_logger # 获取数据库专用日志器 db_logger = get_db_logger() @@ -137,6 +138,9 @@ class ModelConfigRepository: type_values.append(ModelType.LLM) filters.append(ModelConfig.type.in_(type_values)) + if query.capability: + filters.append(ModelConfig.capability.contains(query.capability)) + if query.is_active is not None: filters.append(ModelConfig.is_active == query.is_active) diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 3a017089..a53ca289 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -1,16 +1,19 @@ from typing import List, Optional +from app.core.logging_config import get_logger from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \ - MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE + MEMORY_SUMMARY_NODE_SAVE # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector +logger = get_logger(__name__) + async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): """Delete all nodes in the database.""" result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n") - print(f"All end_user_id: {end_user_id} node and edge deleted successfully") + logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully") return result @@ -25,7 +28,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn List of created node UUIDs or None if failed """ if not dialogues: - print("No dialogues to save") + logger.info("No dialogues to save") return [] try: @@ -50,11 +53,11 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn ) created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}") + logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}") return created_uuids except Exception as e: - print(f"Error creating dialogue nodes: {e}") + logger.info(f"Error creating dialogue nodes: {e}") return None @@ -69,7 +72,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC List of created node UUIDs or None if failed """ if not statements: - print("No statements to save") + logger.info("No statements to save") return [] try: @@ -122,11 +125,11 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC ) created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} statement nodes") + logger.info(f"Successfully created {len(created_uuids)} statement nodes") return created_uuids except Exception as e: - print(f"Error creating statement nodes: {e}") + logger.info(f"Error creating statement nodes: {e}") return None @@ -141,7 +144,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> List of created chunk UUIDs or None if failed """ if not chunks: - print("No chunk nodes to add") + logger.info("No chunk nodes to add") return [] try: @@ -174,16 +177,18 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> ) created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} chunk nodes") + logger.info(f"Successfully created {len(created_uuids)} chunk nodes") return created_uuids except Exception as e: - print(f"Error creating chunk nodes: {e}") + logger.info(f"Error creating chunk nodes: {e}") return None -async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[ - List[str]]: +async def add_memory_summary_nodes( + summaries: List[MemorySummaryNode], + connector: Neo4jConnector +) -> Optional[List[str]]: """Add memory summary nodes to Neo4j in batch. Args: @@ -194,7 +199,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector List of created summary node ids or None if failed """ if not summaries: - print("No memory summary nodes to add") + logger.info("No memory summary nodes to add") return [] try: @@ -220,110 +225,8 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector summaries=flattened ) created_ids = [record.get("uuid") for record in result] - print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") + logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") return created_ids except Exception as e: - print(f"Failed to save MemorySummary nodes to Neo4j: {e}") - return None - - -async def add_perceptual_nodes( - perceptuals: list, - connector: Neo4jConnector, - embedder_client=None, -) -> Optional[List[str]]: - """Add perceptual memory nodes to Neo4j in batch. - - Args: - perceptuals: List of MemoryPerceptualModel objects from PostgreSQL - connector: Neo4j connector instance - embedder_client: Optional embedder client for generating summary embeddings - - Returns: - List of created node UUIDs or None if failed - """ - if not perceptuals: - print("No perceptual nodes to add") - return [] - - try: - flattened = [] - for p in perceptuals: - meta = p.meta_data or {} - content_meta = meta.get("content", {}) - - # 生成 summary embedding(如果有 embedder_client) - summary_embedding = None - if embedder_client and p.summary: - try: - summary_embedding = (await embedder_client.response([p.summary]))[0] - except Exception as emb_err: - print(f"Failed to embed perceptual summary: {emb_err}") - - flattened.append({ - "id": str(p.id), - "end_user_id": str(p.end_user_id), - "perceptual_type": p.perceptual_type, - "file_path": p.file_path or "", - "file_name": p.file_name or "", - "file_ext": p.file_ext or "", - "summary": p.summary or "", - "keywords": content_meta.get("keywords", []), - "topic": content_meta.get("topic", ""), - "domain": content_meta.get("domain", ""), - "created_at": p.created_time.isoformat() if p.created_time else None, - "summary_embedding": summary_embedding, - }) - - result = await connector.execute_query( - PERCEPTUAL_NODE_SAVE, - perceptuals=flattened, - ) - created_uuids = [record.get("uuid") for record in result] - print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j") - return created_uuids - - except Exception as e: - print(f"Failed to save Perceptual nodes to Neo4j: {e}") - return None - - -async def add_perceptual_dialogue_edges( - perceptuals: list, - dialog_id: str, - connector: Neo4jConnector, -) -> Optional[List[str]]: - """Add edges between Perceptual nodes and Dialogue nodes. - - Args: - perceptuals: List of MemoryPerceptualModel objects - dialog_id: The dialogue ID (or ref_id) to link to - connector: Neo4j connector instance - - Returns: - List of created edge element IDs or None if failed - """ - if not perceptuals or not dialog_id: - return [] - - try: - edges = [] - for p in perceptuals: - edges.append({ - "perceptual_id": str(p.id), - "dialog_id": dialog_id, - "end_user_id": str(p.end_user_id), - "created_at": p.created_time.isoformat() if p.created_time else None, - }) - - result = await connector.execute_query( - PERCEPTUAL_DIALOGUE_EDGE_SAVE, - edges=edges, - ) - created_ids = [record.get("uuid") for record in result] - print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j") - return created_ids - - except Exception as e: - print(f"Failed to save Perceptual-Dialogue edges: {e}") + logger.info(f"Failed to save MemorySummary nodes to Neo4j: {e}") return None diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 49dbe2a5..d70a30e9 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1003,60 +1003,69 @@ RETURN DISTINCT """ Graph_Node_query = """ - MATCH (n:MemorySummary) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 0 AS priority - LIMIT $limit +MATCH (n:MemorySummary) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 0 AS priority +LIMIT $limit - UNION ALL +UNION ALL - MATCH (n:Dialogue) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 1 AS priority - LIMIT 1 +MATCH (n:Dialogue) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority +LIMIT 1 - UNION ALL +UNION ALL - MATCH (n:Statement) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 1 AS priority - LIMIT $limit +MATCH (n:Statement) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 1 AS priority +LIMIT $limit - UNION ALL +UNION ALL - MATCH (n:ExtractedEntity) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 2 AS priority - LIMIT $limit +MATCH (n:ExtractedEntity) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 2 AS priority +LIMIT $limit - UNION ALL +UNION ALL - MATCH (n:Chunk) - WHERE n.end_user_id = $end_user_id - RETURN - elementId(n) AS id, - labels(n) AS labels, - properties(n) AS properties, - 3 AS priority - LIMIT $limit +MATCH (n:Chunk) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 3 AS priority +LIMIT $limit - """ +UNION ALL +MATCH (n:Perceptual) +WHERE n.end_user_id = $end_user_id +RETURN + elementId(n) AS id, + labels(n) AS labels, + properties(n) AS properties, + 4 AS priority + +""" # ============================================================ # Community 节点 & BELONGS_TO_COMMUNITY 边 @@ -1340,19 +1349,19 @@ SET n += { topic: p.topic, domain: p.domain, created_at: p.created_at, + file_type: p.file_type, summary_embedding: p.summary_embedding } RETURN n.id AS uuid """ # 感知记忆与对话的关联边 -PERCEPTUAL_DIALOGUE_EDGE_SAVE = """ +PERCEPTUAL_CHUNK_EDGE_SAVE = """ UNWIND $edges AS edge MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id}) -MATCH (d:Dialogue {end_user_id: edge.end_user_id}) -WHERE d.id = edge.dialog_id OR d.ref_id = edge.dialog_id -MERGE (d)-[r:HAS_PERCEPTUAL]->(p) -SET r.end_user_id = edge.end_user_id, +MATCH (c:Chunk {id: edge.chunk_id, end_user_id: edge.end_user_id}) +MERGE (c)-[r:HAS_PERCEPTUAL]->(p) +ON CREATE SET r.end_user_id = edge.end_user_id, r.created_at = edge.created_at RETURN elementId(r) AS uuid """ diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 34497d5b..d78dcef6 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -22,13 +22,18 @@ from app.core.memory.models.graph_models import ( StatementNode, ExtractedEntityNode, EntityEntityEdge, + PerceptualNode, + PerceptualEdge, ) import logging + logger = logging.getLogger(__name__) + + async def save_entities_and_relationships( - entity_nodes: List[ExtractedEntityNode], - entity_entity_edges: List[EntityEntityEdge], - connector: Neo4jConnector + entity_nodes: List[ExtractedEntityNode], + entity_entity_edges: List[EntityEntityEdge], + connector: Neo4jConnector ): """Save entities and their relationships using graph models""" all_entities = [entity.model_dump() for entity in entity_nodes] @@ -73,8 +78,8 @@ async def save_entities_and_relationships( async def save_chunk_nodes( - chunk_nodes: List[ChunkNode], - connector: Neo4jConnector + chunk_nodes: List[ChunkNode], + connector: Neo4jConnector ): """Save chunk nodes using graph models""" if not chunk_nodes: @@ -89,8 +94,8 @@ async def save_chunk_nodes( async def save_statement_chunk_edges( - statement_chunk_edges: List[StatementChunkEdge], - connector: Neo4jConnector + statement_chunk_edges: List[StatementChunkEdge], + connector: Neo4jConnector ): """Save statement-chunk edges using graph models""" if not statement_chunk_edges: @@ -118,8 +123,8 @@ async def save_statement_chunk_edges( async def save_statement_entity_edges( - statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector + statement_entity_edges: List[StatementEntityEdge], + connector: Neo4jConnector ): """Save statement-entity edges using graph models""" if not statement_entity_edges: @@ -142,7 +147,7 @@ async def save_statement_entity_edges( if all_se_edges: try: await connector.execute_query( - STATEMENT_ENTITY_EDGE_SAVE, + STATEMENT_ENTITY_EDGE_SAVE, relationships=all_se_edges ) except Exception: @@ -154,9 +159,11 @@ async def save_dialog_and_statements_to_neo4j( chunk_nodes: List[ChunkNode], statement_nodes: List[StatementNode], entity_nodes: List[ExtractedEntityNode], + perceptual_nodes: List[PerceptualNode], entity_edges: List[EntityEntityEdge], statement_chunk_edges: List[StatementChunkEdge], statement_entity_edges: List[StatementEntityEdge], + perceptual_edges: List[PerceptualEdge], connector: Neo4jConnector, ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. @@ -169,9 +176,11 @@ async def save_dialog_and_statements_to_neo4j( chunk_nodes: List of ChunkNode objects to save statement_nodes: List of StatementNode objects to save entity_nodes: List of ExtractedEntityNode objects to save + perceptual_nodes: List of PerceptualNode objects to save entity_edges: List of EntityEntityEdge objects to save statement_chunk_edges: List of StatementChunkEdge objects to save statement_entity_edges: List of StatementEntityEdge objects to save + perceptual_edges: List of PerceptualEdge objects to save connector: Neo4j connector instance Returns: @@ -190,7 +199,7 @@ async def save_dialog_and_statements_to_neo4j( result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data) dialogue_uuids = [record["uuid"] async for record in result] results['dialogues'] = dialogue_uuids - print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") + logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") # 2. Save all chunk nodes in batch if chunk_nodes: @@ -201,6 +210,14 @@ async def save_dialog_and_statements_to_neo4j( results['chunks'] = chunk_uuids logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") + if perceptual_nodes: + from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE + perceptual_data = [node.model_dump() for node in perceptual_nodes] + result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data) + perceptual_uuids = [record["uuid"] async for record in result] + results["perceptuals"] = perceptual_uuids + logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j") + # 3. Save all statement nodes in batch if statement_nodes: from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE @@ -281,6 +298,22 @@ async def save_dialog_and_statements_to_neo4j( results['statement_entity_edges'] = se_uuids logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") + if perceptual_edges: + from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE + perceptual_edge_data = [] + for edge in perceptual_edges: + print(edge.source, edge.target) + perceptual_edge_data.append({ + "perceptual_id": edge.source, + "chunk_id": edge.target, + "end_user_id": edge.end_user_id, + "created_at": edge.created_at.isoformat() if edge.created_at else None, + }) + result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data) + perceptual_edges_uuids = [record["uuid"] async for record in result] + results['perceptual_chunk_edges'] = perceptual_edges_uuids + logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j") + return results try: @@ -304,9 +337,9 @@ async def save_dialog_and_statements_to_neo4j( def schedule_clustering_after_write( - entity_nodes: List, - llm_model_id: Optional[str] = None, - embedding_model_id: Optional[str] = None, + entity_nodes: List, + llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ 写入 Neo4j 成功后,调度后台聚类任务。 @@ -325,14 +358,15 @@ def schedule_clustering_after_write( end_user_id = entity_nodes[0].end_user_id new_entity_ids = [e.id for e in entity_nodes] logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") - asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)) + asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, + embedding_model_id=embedding_model_id)) async def _trigger_clustering( - new_entity_ids: List[str], - end_user_id: str, - llm_model_id: Optional[str] = None, - embedding_model_id: Optional[str] = None, + new_entity_ids: List[str], + end_user_id: str, + llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ 聚类触发函数,自动判断全量初始化还是增量更新。 diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index 058f082d..668a84a8 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -81,6 +81,12 @@ class ModelConfig(ModelConfigBase): updated_at: datetime.datetime api_keys: List["ModelApiKey"] = [] + @staticmethod + def mask_api_key(key: str, prefix: int = 4, suffix: int = 4) -> str: + if not key or len(key) <= prefix + suffix: + return "*" * len(key) + return key[:prefix] + "*" * (len(key) - prefix - suffix) + key[-suffix:] + @field_validator("api_keys", mode="after") @classmethod def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]: @@ -90,6 +96,15 @@ class ModelConfig(ModelConfigBase): def _serialize_created_at(self, dt: datetime.datetime | None): return int(dt.timestamp() * 1000) if dt else None + @field_serializer("api_keys", when_used="json") + def _serialize_api_keys(self, api_keys: List["ModelApiKey"]): + result = [] + for api_key in api_keys: + data = api_key.model_dump() + data["api_key"] = self.mask_api_key(api_key.api_key) + result.append(data) + return result + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -165,20 +180,20 @@ class ModelApiKey(ModelApiKeyBase): if hasattr(self.model_configs, '__iter__') and not isinstance(self.model_configs, dict): self.model_config_ids = [ mc.id for mc in self.model_configs - if hasattr(mc, 'id') - and not getattr(mc, 'is_composite', False) - and getattr(mc, 'name', None) == self.model_name + if hasattr(mc, 'id') + and not getattr(mc, 'is_composite', False) + and getattr(mc, 'name', None) == self.model_name ] # 情况2:字典列表 elif isinstance(self.model_configs, list): self.model_config_ids = [ mc['id'] if isinstance(mc, dict) else mc.id for mc in self.model_configs - if ((isinstance(mc, dict) - and 'id' in mc + if ((isinstance(mc, dict) + and 'id' in mc and not mc.get('is_composite', False) - and mc.get('name') == self.model_name) or - (hasattr(mc, 'id') + and mc.get('name') == self.model_name) or + (hasattr(mc, 'id') and not getattr(mc, 'is_composite', False) and getattr(mc, 'name', None) == self.model_name)) ] @@ -193,11 +208,10 @@ class ModelApiKey(ModelApiKeyBase): validate_assignment=True # 确保赋值触发校验 ) - @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -211,6 +225,7 @@ class ModelConfigQuery(BaseModel): """模型配置查询Schema""" type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)") provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)") + capability: Optional[List[str]] = Field(None, description="能力筛选(支持多个)") is_active: Optional[bool] = Field(None, description="激活状态筛选") is_public: Optional[bool] = Field(None, description="公开状态筛选") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) @@ -228,6 +243,7 @@ class ModelConfigQueryNew(BaseModel): is_composite: Optional[bool] = Field(None, description="组合模型筛选") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) + class ModelMarketplace(BaseModel): """模型广场响应Schema""" llm_models: List[ModelConfig] = [] @@ -304,7 +320,7 @@ class ModelBaseUpdate(BaseModel): class ModelBase(BaseModel): """基础模型Schema""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID name: str type: str @@ -327,6 +343,7 @@ class ModelBaseQuery(BaseModel): is_deprecated: Optional[bool] = Field(None, description="是否弃用") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) + class ModelInfo(BaseModel): """模型信息Schema""" model_name: str = Field(..., description="模型名称") @@ -336,4 +353,3 @@ class ModelInfo(BaseModel): is_omni: bool = Field(default=False, description="是否为omni模型") model_type: ModelType = Field(..., description="模型类型") capability: List[str] = Field(default_factory=list, description="模型能力列表") - diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 875f02bb..8bb6538d 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -274,7 +274,6 @@ class MemoryAgentService: self, end_user_id: str, messages: list[dict], - file_messages: list[dict], config_id: Optional[uuid.UUID] | int, db: Session, storage_type: str, @@ -287,7 +286,6 @@ class MemoryAgentService: Args: end_user_id: Group identifier (also used as end_user_id) messages: Message to write - files: Files to write config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) @@ -348,15 +346,15 @@ class MemoryAgentService: raise ValueError(error_msg) perceptual_serivce = MemoryPerceptualService(db) - file_content = [] - for message in file_messages: + for message in messages: + message["file_content"] = [] for file in message["files"]: file_object = await perceptual_serivce.generate_perceptual_memory( end_user_id=end_user_id, memory_config=memory_config, file=FileInput(**file) ) - file_content.append(file_object) + message["file_content"].append((file_object, file["type"])) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) try: @@ -368,7 +366,6 @@ class MemoryAgentService: await write_neo4j( end_user_id=end_user_id, messages=messages, - file_content=file_content, memory_config=memory_config, ref_id='', language=language @@ -380,19 +377,23 @@ class MemoryAgentService: if deleted: logger.info( f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}") - return self.writer_messages_deal( - "success", - start_time, - end_user_id, - config_id, - message_text, - { - "status": "success", - "data": messages, - "config_id": memory_config.config_id, - "config_name": memory_config.config_name - } - ) + for message in messages: + message["file_content"] = [ + perceptual[0].file_path for perceptual in message["file_content"] + ] + return self.writer_messages_deal( + "success", + start_time, + end_user_id, + config_id, + message_text, + { + "status": "success", + "data": messages, + "config_id": memory_config.config_id, + "config_name": memory_config.config_name + } + ) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index d6c1de87..8255dbbe 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -317,11 +317,11 @@ class MemoryPerceptualService: stmt = select(FileMetadata).where( FileMetadata.id == file_id ) - file = self.db.execute(stmt).scalar_one_or_none() + file_obj = self.db.execute(stmt).scalar_one_or_none() - if file: - filename = file.file_name - file_ext = file.file_ext + if file_obj: + filename = file_obj.file_name + file_ext = file_obj.file_ext except ValueError: business_logger.debug(f"Remote file, file_id={filename}") if not file_ext: diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index fc749157..4617946b 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -297,9 +297,12 @@ async def run_pilot_extraction( chunk_nodes, statement_nodes, entity_nodes, + _, statement_chunk_edges, statement_entity_edges, entity_edges, + _, + _ ) = extraction_result log_time("Extraction Pipeline", time.time() - step_start, log_file) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index d5d19e0d..29516acc 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -1888,7 +1888,8 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_ "Chunk": ["content", "created_at"], "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption","emotion_keywords","emotion_type","emotion_subject"], "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption","aliases","connect_strength"], - "MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段 + "MemorySummary": ["summary", "content", "created_at", "caption"], # 添加 content 字段 + "Perceptual": ["file_name", "file_path", "file_type", "domain", "topic", "keywords", "summary"] } # 获取该节点类型的白名单字段 diff --git a/api/app/tasks.py b/api/app/tasks.py index 8afb2194..f243eac3 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1080,14 +1080,12 @@ def write_message_task( config_id: str | int, storage_type: str, user_rag_memory_id: str, - file_messages: list[dict] | None, language: str = "zh" ) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. Args: end_user_id: Group ID for the memory agent (also used as end_user_id) message: Message to write - file_messages: Files to write config_id: Configuration ID (can be UUID string, integer, or config_id_old) storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID @@ -1099,9 +1097,6 @@ def write_message_task( Raises: Exception on failure """ - if file_messages is None: - file_messages = [] - logger.info( f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, " f"config_id={config_id} (type: {type(config_id).__name__}), " @@ -1146,7 +1141,7 @@ def write_message_task( f"[CELERY WRITE] Executing MemoryAgentService.write_memory " f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() - result = await service.write_memory(end_user_id, message, file_messages, actual_config_id, db, storage_type, + result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id, language) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result @@ -2617,57 +2612,6 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ } -@celery_app.task( - name="app.tasks.write_perceptual_memory", - bind=True, - ignore_result=True, - max_retries=0, - acks_late=False, - time_limit=3600, - soft_time_limit=3300, -) -def write_perceptual_memory( - self, - end_user_id: str, - model_api_config: dict, - file_type: str, - file_url: str, - file_message: dict -): - """ - Write perceptual memory for a user into PostgreSQL and Neo4j. - - This task generates or updates the user's perceptual memory - in the backend databases. It is intended to be executed asynchronously - via Celery. - - Args: - end_user_id (uuid.UUID): The unique identifier of the end user. - model_api_config (ModelInfo): API configuration for the model - used to generate perceptual memory. - file_type (str): The file type - file_url (url): The url of file - file_message (dict): The file message containing details about the file - to be processed. - - Returns: - None - """ - file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest() - set_asyncio_event_loop() - with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()): - model_info = ModelInfo(**model_api_config) - with get_db_context() as db: - memory_perceptual_service = MemoryPerceptualService(db) - return asyncio.run(memory_perceptual_service.generate_perceptual_memory( - end_user_id, - model_info, - file_type, - file_url, - file_message, - )) - - # ============================================================================= # 社区聚类补全任务(触发型) # =============================================================================