feat(memory): support perception-aware memory writing in workflow and Neo4j nodes

This commit is contained in:
Eternity
2026-03-23 16:33:25 +08:00
parent 31085ed678
commit 2ff81ba101
22 changed files with 820 additions and 519 deletions

View File

@@ -141,7 +141,7 @@ class AppChatService:
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 调用 Agent支持多模态
@@ -339,7 +339,7 @@ class AppChatService:
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 流式调用 Agent支持多模态同时并行启动 TTS

View File

@@ -600,7 +600,7 @@ class AgentRunService:
)
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索
@@ -836,7 +836,7 @@ class AgentRunService:
)
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索

View File

@@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
from uuid import UUID
import redis
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.cache import InterestMemoryCache
from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.messages_tools import (
merge_multiple_search_results,
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write as write_neo4j
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas import FileInput
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
from app.services.memory_perceptual_service import MemoryPerceptualService
try:
from app.core.memory.utils.log.audit_logger import audit_logger
@@ -271,6 +274,7 @@ class MemoryAgentService:
self,
end_user_id: str,
messages: list[dict],
file_messages: list[dict],
config_id: Optional[uuid.UUID] | int,
db: Session,
storage_type: str,
@@ -283,6 +287,7 @@ class MemoryAgentService:
Args:
end_user_id: Group identifier (also used as end_user_id)
messages: Message to write
files: Files to write
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
@@ -342,48 +347,52 @@ class MemoryAgentService:
raise ValueError(error_msg)
perceptual_serivce = MemoryPerceptualService(db)
file_content = []
for message in file_messages:
for file in message["files"]:
file_object = await perceptual_serivce.generate_perceptual_memory(
end_user_id=end_user_id,
memory_config=memory_config,
file=FileInput(**file)
)
file_content.append(file_object)
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
try:
if storage_type == "rag":
# For RAG storage, convert messages to single string
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
await write_rag(end_user_id, message_text, user_rag_memory_id)
return "success"
else:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# Convert structured messages to LangChain messages
langchain_messages = []
for msg in messages:
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
print(100 * '-')
print(langchain_messages)
print(100 * '-')
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"end_user_id": end_user_id,
"memory_config": memory_config,
"language": language
}
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
contents)
await write_neo4j(
end_user_id=end_user_id,
messages=messages,
file_content=file_content,
memory_config=memory_config,
ref_id='',
language=language
)
for lang in ["zh", "en"]:
deleted = await InterestMemoryCache.delete_interest_distribution(
end_user_id, lang
)
if deleted:
logger.info(
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
return self.writer_messages_deal(
"success",
start_time,
end_user_id,
config_id,
message_text,
{
"status": "success",
"data": messages,
"config_id": memory_config.config_id,
"config_name": memory_config.config_name
}
)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"

View File

@@ -28,7 +28,7 @@ class MemoryAPIService:
2. Maps end_user_id to end_user_id for memory operations
3. Delegates to MemoryAgentService for actual memory read/write operations
"""
def __init__(self, db: Session):
"""Initialize MemoryAPIService.
@@ -36,11 +36,11 @@ class MemoryAPIService:
db: SQLAlchemy database session
"""
self.db = db
def validate_end_user(
self,
end_user_id: str,
workspace_id: uuid.UUID
self,
end_user_id: str,
workspace_id: uuid.UUID
) -> EndUser:
"""Validate that end_user exists and belongs to the workspace.
@@ -56,7 +56,7 @@ class MemoryAPIService:
BusinessException: If end_user not in authorized workspace
"""
logger.info(f"Validating end_user: {end_user_id} for workspace: {workspace_id}")
# Query end_user by ID
try:
end_user_uuid = uuid.UUID(end_user_id)
@@ -66,7 +66,7 @@ class MemoryAPIService:
message=f"Invalid end_user_id format: {end_user_id}",
code=BizCode.INVALID_PARAMETER
)
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first()
if not end_user:
@@ -75,13 +75,13 @@ class MemoryAPIService:
resource_type="EndUser",
resource_id=end_user_id
)
# Verify end_user belongs to the workspace via App relationship
app = self.db.query(App).filter(
App.id == end_user.app_id,
App.is_active.is_(True)
).first()
if not app:
logger.warning(f"App not found for end_user: {end_user_id}")
# raise ResourceNotFoundException(
@@ -99,7 +99,7 @@ class MemoryAPIService:
# message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}",
# code=BizCode.FORBIDDEN
# )
logger.info(f"End user {end_user_id} validated successfully")
return end_user
@@ -125,13 +125,14 @@ class MemoryAPIService:
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
async def write_memory(
self,
workspace_id: uuid.UUID,
end_user_id: str,
message: str,
config_id: str,
storage_type: str = "neo4j",
user_rag_memory_id: Optional[str] = None,
self,
workspace_id: uuid.UUID,
end_user_id: str,
message: str,
config_id: str,
storage_type: str = "neo4j",
files: Optional[list]=None,
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Write memory with validation.
@@ -153,14 +154,16 @@ class MemoryAPIService:
ResourceNotFoundException: If end_user not found
BusinessException: If end_user not in authorized workspace or write fails
"""
if files is None:
files = list()
logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}")
# Validate end_user exists and belongs to workspace
self.validate_end_user(end_user_id, workspace_id)
# Update end user's memory_config_id
self._update_end_user_config(end_user_id, config_id)
try:
# Delegate to MemoryAgentService
# Convert string message to list[dict] format expected by MemoryAgentService
@@ -171,11 +174,12 @@ class MemoryAPIService:
config_id=config_id,
db=self.db,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id or ""
user_rag_memory_id=user_rag_memory_id or "",
files=files
)
logger.info(f"Memory write successful for end_user: {end_user_id}")
# result may be a string "success" or a dict with a "status" key
# Preserve the full dict so callers don't silently lose extra fields
# (e.g. error codes, metadata) returned by MemoryAgentService.
@@ -189,7 +193,7 @@ class MemoryAPIService:
"status": result if isinstance(result, str) else "success",
"end_user_id": end_user_id,
}
except ConfigurationError as e:
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
raise BusinessException(
@@ -204,16 +208,16 @@ class MemoryAPIService:
message=f"Memory write failed: {str(e)}",
code=BizCode.MEMORY_WRITE_FAILED
)
async def read_memory(
self,
workspace_id: uuid.UUID,
end_user_id: str,
message: str,
search_switch: str = "0",
config_id: str = "",
storage_type: str = "neo4j",
user_rag_memory_id: Optional[str] = None,
self,
workspace_id: uuid.UUID,
end_user_id: str,
message: str,
search_switch: str = "0",
config_id: str = "",
storage_type: str = "neo4j",
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Read memory with validation.
@@ -237,14 +241,13 @@ class MemoryAPIService:
BusinessException: If end_user not in authorized workspace or read fails
"""
logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}")
# Validate end_user exists and belongs to workspace
self.validate_end_user(end_user_id, workspace_id)
# Update end user's memory_config_id
self._update_end_user_config(end_user_id, config_id)
try:
# Delegate to MemoryAgentService
result = await MemoryAgentService().read_memory(
@@ -257,15 +260,15 @@ class MemoryAPIService:
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id or ""
)
logger.info(f"Memory read successful for end_user: {end_user_id}")
return {
"answer": result.get("answer", ""),
"intermediate_outputs": result.get("intermediate_outputs", []),
"end_user_id": end_user_id
}
except ConfigurationError as e:
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
raise BusinessException(
@@ -282,8 +285,8 @@ class MemoryAPIService:
)
def list_memory_configs(
self,
workspace_id: uuid.UUID,
self,
workspace_id: uuid.UUID,
) -> Dict[str, Any]:
"""List all memory configs for a workspace.

View File

@@ -37,7 +37,7 @@ def _validate_config_id(config_id, db: Session = None):
"""Validate configuration ID format (supports both UUID and integer)."""
if isinstance(config_id, uuid.UUID):
return config_id
if config_id is None:
raise InvalidConfigError(
"Configuration ID cannot be None",
@@ -60,18 +60,18 @@ def _validate_config_id(config_id, db: Session = None):
if result:
logger.info(f"Found config_id {result.config_id} for user_id {config_id}")
return result.config_id
return config_id
if isinstance(config_id, str):
config_id_stripped = config_id.strip()
# Try parsing as UUID first
try:
return uuid.UUID(config_id_stripped)
except ValueError:
pass
# Fall back to integer parsing
try:
parsed_id = int(config_id_stripped)
@@ -81,17 +81,17 @@ def _validate_config_id(config_id, db: Session = None):
field_name="config_id",
invalid_value=config_id,
)
# 如果提供了数据库会话,尝试通过 user_id 查询 config_id
if db is not None:
# 查询 user_id 匹配的记录
stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id))
result = db.execute(stmt).scalars().first()
if result:
logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}")
return result.config_id
return parsed_id
except ValueError:
raise InvalidConfigError(
@@ -154,10 +154,10 @@ class MemoryConfigService:
self.db = db
def load_memory_config(
self,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
service_name: str = "MemoryConfigService",
self,
config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""
Load memory configuration from database with optional fallback.
@@ -194,14 +194,14 @@ class MemoryConfigService:
try:
# Use get_config_with_fallback if workspace_id is provided
memory_config = None
validated_config_id = None
if workspace_id:
validated_config_id = None
if config_id:
try:
validated_config_id = _validate_config_id(config_id, self.db)
except Exception:
validated_config_id = None
memory_config = self.get_config_with_fallback(
memory_config_id=validated_config_id,
workspace_id=workspace_id
@@ -210,7 +210,7 @@ class MemoryConfigService:
validated_config_id = _validate_config_id(config_id, self.db)
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
memory_config = self.db.get(MemoryConfigModel, validated_config_id)
if not memory_config:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
@@ -233,7 +233,7 @@ class MemoryConfigService:
result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result:
raise ConfigurationError(
f"Workspace not found for config {memory_config.config_id}"
@@ -243,10 +243,10 @@ class MemoryConfigService:
# Helper function to validate model with workspace fallback
def _validate_model_with_fallback(
model_id: str,
model_type: str,
workspace_default: str,
required: bool = False
model_id: str,
model_type: str,
workspace_default: str,
required: bool = False
) -> tuple:
"""Validate model ID, falling back to workspace default if invalid.
@@ -275,7 +275,7 @@ class MemoryConfigService:
logger.warning(
f"{model_type} model validation failed, trying workspace default: {e}"
)
# Fallback to workspace default
if workspace_default:
try:
@@ -297,7 +297,7 @@ class MemoryConfigService:
logger.error(f"Workspace default {model_type} model also invalid: {e}")
if required:
raise
if required:
raise InvalidConfigError(
f"{model_type.title()} model is required but not configured",
@@ -306,7 +306,7 @@ class MemoryConfigService:
config_id=validated_config_id,
workspace_id=workspace.id
)
return None, None
# Step 2: Validate embedding model with workspace fallback
@@ -343,6 +343,35 @@ class MemoryConfigService:
if memory_config.rerank_id or workspace.rerank:
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
vision_uuid, vision_name = validate_and_resolve_model_id(
memory_config.vision_id,
"llm",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
audio_uuid, audio_name = validate_and_resolve_model_id(
memory_config.audio_id,
"llm",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
video_uuid, video_name = validate_and_resolve_model_id(
memory_config.video_id,
"llm",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Create immutable MemoryConfig object
config = MemoryConfig(
config_id=memory_config.config_id,
@@ -356,6 +385,12 @@ class MemoryConfigService:
embedding_model_name=embedding_name,
rerank_model_id=rerank_uuid,
rerank_model_name=rerank_name,
video_model_id=video_uuid,
video_model_name=video_name,
vision_model_id=vision_uuid,
vision_model_name=vision_name,
audio_model_id=audio_uuid,
audio_model_name=audio_name,
storage_type=workspace.storage_type or "neo4j",
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
reflexion_enabled=memory_config.enable_self_reflexion or False,
@@ -364,24 +399,31 @@ class MemoryConfigService:
reflexion_baseline=memory_config.baseline or "Time",
loaded_at=datetime.now(),
# Pipeline config: Deduplication
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
enable_llm_dedup_blockwise=bool(
memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
enable_llm_disambiguation=bool(
memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
# Pipeline config: Statement extraction
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
statement_granularity=int(
memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
include_dialogue_context=bool(
memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
max_dialogue_context_chars=int(
memory_config.max_context) if memory_config.max_context is not None else 1000,
# Pipeline config: Forgetting engine
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
# Pipeline config: Pruning
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_enabled=bool(
memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
pruning_threshold=float(
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
# Ontology scene association
scene_id=memory_config.scene_id,
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
@@ -448,9 +490,9 @@ class MemoryConfigService:
if not config:
logger.warning(f"Model ID {model_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
@@ -481,9 +523,9 @@ class MemoryConfigService:
if not config:
logger.warning(f"Embedding model ID {embedding_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
@@ -571,25 +613,25 @@ class MemoryConfigService:
"""
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.repositories.ontology_class_repository import OntologyClassRepository
if not memory_config.scene_id:
logger.debug("No scene_id configured, skipping ontology type fetch")
return None
try:
ontology_repo = OntologyClassRepository(self.db)
ontology_classes = ontology_repo.get_classes_by_scene(memory_config.scene_id)
if not ontology_classes:
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
return None
ontology_types = OntologyTypeList.from_db_models(ontology_classes)
logger.info(
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
)
return ontology_types
except Exception as e:
logger.warning(
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
@@ -598,8 +640,8 @@ class MemoryConfigService:
return None
def get_workspace_default_config(
self,
workspace_id: UUID
self,
workspace_id: UUID
) -> Optional["MemoryConfigModel"]:
"""Get workspace default memory config.
@@ -613,19 +655,19 @@ class MemoryConfigService:
Optional[MemoryConfigModel]: Default config or None if no configs exist
"""
config = MemoryConfigRepository.get_workspace_default(self.db, workspace_id)
if not config:
logger.warning(
"No active memory config found for workspace fallback",
extra={"workspace_id": str(workspace_id)}
)
return config
def get_config_with_fallback(
self,
memory_config_id: Optional[UUID],
workspace_id: UUID
self,
memory_config_id: Optional[UUID],
workspace_id: UUID
) -> Optional["MemoryConfigModel"]:
"""Get memory config with fallback to workspace default.
@@ -644,13 +686,13 @@ class MemoryConfigService:
"No memory config ID provided, using workspace default",
extra={"workspace_id": str(workspace_id)}
)
config = MemoryConfigRepository.get_with_fallback(
self.db,
memory_config_id,
workspace_id
)
if not config and memory_config_id:
logger.warning(
"Memory config not found, falling back to workspace default",
@@ -659,13 +701,13 @@ class MemoryConfigService:
"workspace_id": str(workspace_id)
}
)
return config
def delete_config(
self,
config_id: UUID | int,
force: bool = False
self,
config_id: UUID | int,
force: bool = False
) -> dict:
"""Delete memory config with protection against in-use configs.
@@ -687,7 +729,7 @@ class MemoryConfigService:
from app.core.exceptions import ResourceNotFoundException
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
from app.repositories.end_user_repository import EndUserRepository
# 处理旧格式 int 类型的 config_id
if isinstance(config_id, int):
logger.warning(
@@ -699,11 +741,11 @@ class MemoryConfigService:
"message": "旧格式配置ID不支持删除操作请使用新版配置",
"legacy_int_id": config_id
}
config = self.db.get(MemoryConfigModel, config_id)
if not config:
raise ResourceNotFoundException("MemoryConfig", str(config_id))
# Check if this is the default config - default configs cannot be deleted
if config.is_default:
logger.warning(
@@ -715,11 +757,11 @@ class MemoryConfigService:
"message": "默认配置不允许删除",
"is_default": True
}
# Use repository to count connected end users
end_user_repo = EndUserRepository(self.db)
connected_count = end_user_repo.count_by_memory_config_id(config_id)
if connected_count > 0 and not force:
logger.warning(
"Attempted to delete memory config with connected end users",
@@ -728,18 +770,18 @@ class MemoryConfigService:
"connected_count": connected_count
}
)
return {
"status": "warning",
"message": f"无法删除记忆配置:{connected_count} 个终端用户正在使用此配置",
"connected_count": connected_count,
"force_required": True
}
# Force delete: use repository to clear end user references first
if connected_count > 0 and force:
cleared_count = end_user_repo.clear_memory_config_id(config_id)
logger.warning(
"Force deleting memory config, clearing end user references",
extra={
@@ -747,11 +789,11 @@ class MemoryConfigService:
"cleared_end_users": cleared_count
}
)
try:
self.db.delete(config)
self.db.commit()
logger.info(
"Memory config deleted",
extra={
@@ -760,16 +802,16 @@ class MemoryConfigService:
"affected_users": connected_count
}
)
return {
"status": "success",
"message": "记忆配置删除成功",
"affected_users": connected_count
}
except IntegrityError as e:
self.db.rollback()
# Handle foreign key violation gracefully
error_str = str(e.orig) if e.orig else str(e)
if "ForeignKeyViolation" in error_str or "foreign key constraint" in error_str.lower():
@@ -785,7 +827,7 @@ class MemoryConfigService:
"message": "无法删除记忆配置:仍有终端用户引用此配置,请使用 force=true 强制删除",
"force_required": True
}
# Re-raise other integrity errors
logger.error(
"Delete failed due to integrity error",
@@ -800,9 +842,9 @@ class MemoryConfigService:
# ==================== 记忆配置提取方法 ====================
def extract_memory_config_id(
self,
app_type: str,
config: dict
self,
app_type: str,
config: dict
) -> tuple[Optional[uuid.UUID], bool]:
"""从发布配置中提取 memory_config_id根据应用类型分发
@@ -828,8 +870,8 @@ class MemoryConfigService:
return None, False
def _extract_memory_config_id_from_agent(
self,
config: dict
self,
config: dict
) -> tuple[Optional[uuid.UUID], bool]:
"""从 Agent 应用配置中提取 memory_config_id
@@ -888,8 +930,8 @@ class MemoryConfigService:
return None, False
def _extract_memory_config_id_from_workflow(
self,
config: dict
self,
config: dict
) -> tuple[Optional[uuid.UUID], bool]:
"""从 Workflow 应用配置中提取 memory_config_id
@@ -905,14 +947,14 @@ class MemoryConfigService:
- is_legacy_int: 是否检测到旧格式 int 数据
"""
nodes = config.get("nodes", [])
for node in nodes:
node_type = node.get("type", "")
# 检查是否为记忆节点 (support both formats: memory-read/memory-write and MemoryRead/MemoryWrite)
if node_type.lower() in ["memoryread", "memorywrite", "memory-read", "memory-write"]:
config_id = node.get("config", {}).get("config_id")
if config_id:
try:
# 处理字符串、UUID 和 int旧数据兼容三种情况
@@ -937,6 +979,6 @@ class MemoryConfigService:
f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, "
f"node_type={node_type}, error={str(e)}"
)
logger.debug("工作流配置中未找到记忆节点")
return None, False

View File

@@ -12,11 +12,12 @@ from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models import FileMetadata
from app.models import FileMetadata, ModelApiKey, ModelType
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
from app.models.prompt_optimizer_model import RoleType
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
from app.schemas import FileType
from app.schemas import FileType, FileInput
from app.schemas.memory_config_schema import MemoryConfig
from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema,
PerceptualTimelineResponse,
@@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import (
AudioModal, Content, VideoModal, TextModal
)
from app.schemas.model_schema import ModelInfo
from app.services.model_service import ModelApiKeyService
from app.services.multimodal_service import MultimodalService
business_logger = get_business_logger()
@@ -195,21 +198,58 @@ class MemoryPerceptualService:
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
def _get_mutlimodal_client(
self,
file_type: FileType,
config: MemoryConfig
) -> tuple[RedBearLLM | None, ModelApiKey | None]:
model_config = None
if file_type == FileType.AUDIO:
model_config = ModelApiKeyService.get_available_api_key(
self.db,
config.audio_model_id
)
elif file_type == FileType.VIDEO:
model_config = ModelApiKeyService.get_available_api_key(
self.db,
config.video_model_id
)
elif file_type == FileType.DOCUMENT:
model_config = ModelApiKeyService.get_available_api_key(
self.db,
config.llm_model_id
)
elif file_type == FileType.IMAGE:
model_config = ModelApiKeyService.get_available_api_key(
self.db,
config.vision_model_id
)
llm = None
if model_config:
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_config.model_name,
provider=model_config.provider,
api_key=model_config.api_key,
base_url=model_config.api_base,
is_omni=model_config.is_omni
)
)
return llm, model_config
async def generate_perceptual_memory(
self,
end_user_id: str,
model_config: ModelInfo,
file_type: str,
file_url: str,
file_message: dict,
memory_config: MemoryConfig,
file: FileInput
):
memories = self.repository.get_by_url(file_url)
memories = self.repository.get_by_url(file.url)
if memories:
business_logger.info(f"Perceptual memory already exists: {file_url}")
business_logger.info(f"Perceptual memory already exists: {file.url}")
if end_user_id not in [memory.end_user_id for memory in memories]:
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0]
self.repository.create_perceptual_memory(
memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path,
@@ -219,20 +259,31 @@ class MemoryPerceptualService:
meta_data=memory_cache.meta_data
)
self.db.commit()
return
llm = RedBearLLM(RedBearModelConfig(
return memory
else:
for memory in memories:
if memory.end_user_id == uuid.UUID(end_user_id):
return memory
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
multimodel_service = MultimodalService(self.db, ModelInfo(
model_name=model_config.model_name,
provider=model_config.provider,
api_key=model_config.api_key,
base_url=model_config.api_base,
is_omni=model_config.is_omni
), type=model_config.model_type)
api_base=model_config.api_base,
is_omni=model_config.is_omni,
capability=model_config.capability,
model_type=ModelType.LLM
))
file_message = await multimodel_service.process_files(
files=[file]
)
if file_message:
file_message = file_message[0]
try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
messages = [
@@ -242,8 +293,22 @@ class MemoryPerceptualService:
]}
]
result = await llm.ainvoke(messages)
content = json_repair.repair_json(result.content, return_objects=True)
path = urlparse(file_url).path
content = result.content
final_output = ""
if isinstance(content, list):
for msg in content:
if isinstance(msg, dict):
final_output += msg.get("text", "")
elif isinstance(msg, str):
final_output += msg
elif isinstance(content, dict):
final_output += content.get("text", "")
elif isinstance(content, str):
final_output = content
else:
raise ValueError(f"Unexcept Model Output Type: {result.content}")
content = json_repair.repair_json(final_output, return_objects=True)
path = urlparse(file.url).path
filename = os.path.basename(path)
filename = unquote(filename)
file_ext = os.path.splitext(filename)[1]
@@ -260,13 +325,13 @@ class MemoryPerceptualService:
except ValueError:
business_logger.debug(f"Remote file, file_id={filename}")
if not file_ext:
if file_type == FileType.AUDIO:
if file.type == FileType.AUDIO:
file_ext = ".mp3"
elif file_type == FileType.VIDEO:
elif file.type == FileType.VIDEO:
file_ext = ".mp4"
elif file_type == FileType.DOCUMENT:
elif file.type == FileType.DOCUMENT:
file_ext = ".txt"
elif file_type == FileType.IMAGE:
elif file.type == FileType.IMAGE:
file_ext = ".jpg"
filename += file_ext
file_content = {
@@ -274,11 +339,11 @@ class MemoryPerceptualService:
"topic": content.get("topic"),
"domain": content.get("domain")
}
if file_type in [FileType.IMAGE, FileType.VIDEO]:
if file.type in [FileType.IMAGE, FileType.VIDEO]:
file_modalities = {
"scene": content.get("scene", [])
}
elif file_type in [FileType.DOCUMENT]:
elif file.type in [FileType.DOCUMENT]:
file_modalities = {
"section_count": content.get("section_count", 0),
"title": content.get("title", ""),
@@ -288,10 +353,10 @@ class MemoryPerceptualService:
file_modalities = {
"speaker_count": content.get("speaker_count", 0)
}
self.repository.create_perceptual_memory(
memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType.trans_from_file_type(file_type),
file_path=file_url,
perceptual_type=PerceptualType.trans_from_file_type(file.type),
file_path=file.url,
file_name=filename,
file_ext=file_ext,
summary=content.get('summary', ""),
@@ -301,3 +366,4 @@ class MemoryPerceptualService:
}
)
self.db.commit()
return memory

View File

@@ -9,14 +9,12 @@
- OpenAI: 支持 URL 和 base64 格式
"""
import base64
import csv
import io
import uuid
import json
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
import csv
import json
import PyPDF2
import httpx
import magic
@@ -33,7 +31,6 @@ from app.models.file_metadata_model import FileMetadata
from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.schemas.model_schema import ModelInfo
from app.services.audio_transcription_service import AudioTranscriptionService
from app.tasks import write_perceptual_memory
logger = get_business_logger()
@@ -342,15 +339,12 @@ class MultimodalService:
async def process_files(
self,
end_user_id: uuid.UUID | str,
files: Optional[List[FileInput]],
) -> List[Dict[str, Any]]:
"""
处理文件列表,返回 LLM 可用的格式
Args:
end_user_id: 用户ID
files: 文件输入列表
Returns:
@@ -358,8 +352,6 @@ class MultimodalService:
"""
if not files:
return []
if isinstance(end_user_id, uuid.UUID):
end_user_id = str(end_user_id)
# 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式
@@ -380,23 +372,15 @@ class MultimodalService:
if file.type == FileType.IMAGE and "vision" in self.capability:
is_support, content = await self._process_image(file, strategy)
result.append(content)
if is_support:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.DOCUMENT:
is_support, content = await self._process_document(file, strategy)
result.append(content)
if is_support:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.AUDIO and "audio" in self.capability:
is_support, content = await self._process_audio(file, strategy)
result.append(content)
if is_support:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.VIDEO and "video" in self.capability:
is_support, content = await self._process_video(file, strategy)
result.append(content)
if is_support:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
else:
logger.warning(f"不支持的文件类型: {file.type}")
except Exception as e:
@@ -418,17 +402,6 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result
def write_perceptual_memory(
self,
end_user_id: str,
file_type: str,
file_url: str,
file_message: dict
):
"""写入感知记忆"""
if end_user_id and self.api_config:
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
"""
处理图片文件