feat(memory): support perception-aware memory writing in workflow and Neo4j nodes
This commit is contained in:
@@ -141,7 +141,7 @@ class AppChatService:
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 调用 Agent(支持多模态)
|
||||
@@ -339,7 +339,7 @@ class AppChatService:
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||
|
||||
@@ -600,7 +600,7 @@ class AgentRunService:
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -836,7 +836,7 @@ class AgentRunService:
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
|
||||
@@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.cache import InterestMemoryCache
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.messages_tools import (
|
||||
merge_multiple_search_results,
|
||||
reorder_output_results,
|
||||
)
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas import FileInput
|
||||
from app.schemas.memory_agent_schema import Write_UserInput
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
@@ -271,6 +274,7 @@ class MemoryAgentService:
|
||||
self,
|
||||
end_user_id: str,
|
||||
messages: list[dict],
|
||||
file_messages: list[dict],
|
||||
config_id: Optional[uuid.UUID] | int,
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
@@ -283,6 +287,7 @@ class MemoryAgentService:
|
||||
Args:
|
||||
end_user_id: Group identifier (also used as end_user_id)
|
||||
messages: Message to write
|
||||
files: Files to write
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
@@ -342,48 +347,52 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
perceptual_serivce = MemoryPerceptualService(db)
|
||||
file_content = []
|
||||
for message in file_messages:
|
||||
for file in message["files"]:
|
||||
file_object = await perceptual_serivce.generate_perceptual_memory(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
file=FileInput(**file)
|
||||
)
|
||||
file_content.append(file_object)
|
||||
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
try:
|
||||
if storage_type == "rag":
|
||||
# For RAG storage, convert messages to single string
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||
return "success"
|
||||
else:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# Convert structured messages to LangChain messages
|
||||
langchain_messages = []
|
||||
for msg in messages:
|
||||
if msg['role'] == 'user':
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
print(100 * '-')
|
||||
print(langchain_messages)
|
||||
print(100 * '-')
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
"end_user_id": end_user_id,
|
||||
"memory_config": memory_config,
|
||||
"language": language
|
||||
}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
|
||||
contents)
|
||||
await write_neo4j(
|
||||
end_user_id=end_user_id,
|
||||
messages=messages,
|
||||
file_content=file_content,
|
||||
memory_config=memory_config,
|
||||
ref_id='',
|
||||
language=language
|
||||
)
|
||||
for lang in ["zh", "en"]:
|
||||
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||
end_user_id, lang
|
||||
)
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||
return self.writer_messages_deal(
|
||||
"success",
|
||||
start_time,
|
||||
end_user_id,
|
||||
config_id,
|
||||
message_text,
|
||||
{
|
||||
"status": "success",
|
||||
"data": messages,
|
||||
"config_id": memory_config.config_id,
|
||||
"config_name": memory_config.config_name
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
|
||||
@@ -28,7 +28,7 @@ class MemoryAPIService:
|
||||
2. Maps end_user_id to end_user_id for memory operations
|
||||
3. Delegates to MemoryAgentService for actual memory read/write operations
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""Initialize MemoryAPIService.
|
||||
|
||||
@@ -36,11 +36,11 @@ class MemoryAPIService:
|
||||
db: SQLAlchemy database session
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
|
||||
def validate_end_user(
|
||||
self,
|
||||
end_user_id: str,
|
||||
workspace_id: uuid.UUID
|
||||
self,
|
||||
end_user_id: str,
|
||||
workspace_id: uuid.UUID
|
||||
) -> EndUser:
|
||||
"""Validate that end_user exists and belongs to the workspace.
|
||||
|
||||
@@ -56,7 +56,7 @@ class MemoryAPIService:
|
||||
BusinessException: If end_user not in authorized workspace
|
||||
"""
|
||||
logger.info(f"Validating end_user: {end_user_id} for workspace: {workspace_id}")
|
||||
|
||||
|
||||
# Query end_user by ID
|
||||
try:
|
||||
end_user_uuid = uuid.UUID(end_user_id)
|
||||
@@ -66,7 +66,7 @@ class MemoryAPIService:
|
||||
message=f"Invalid end_user_id format: {end_user_id}",
|
||||
code=BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
|
||||
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first()
|
||||
|
||||
if not end_user:
|
||||
@@ -75,13 +75,13 @@ class MemoryAPIService:
|
||||
resource_type="EndUser",
|
||||
resource_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
# Verify end_user belongs to the workspace via App relationship
|
||||
app = self.db.query(App).filter(
|
||||
App.id == end_user.app_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
|
||||
if not app:
|
||||
logger.warning(f"App not found for end_user: {end_user_id}")
|
||||
# raise ResourceNotFoundException(
|
||||
@@ -99,7 +99,7 @@ class MemoryAPIService:
|
||||
# message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}",
|
||||
# code=BizCode.FORBIDDEN
|
||||
# )
|
||||
|
||||
|
||||
logger.info(f"End user {end_user_id} validated successfully")
|
||||
return end_user
|
||||
|
||||
@@ -125,13 +125,14 @@ class MemoryAPIService:
|
||||
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
||||
|
||||
async def write_memory(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
config_id: str,
|
||||
storage_type: str = "neo4j",
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
config_id: str,
|
||||
storage_type: str = "neo4j",
|
||||
files: Optional[list]=None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Write memory with validation.
|
||||
|
||||
@@ -153,14 +154,16 @@ class MemoryAPIService:
|
||||
ResourceNotFoundException: If end_user not found
|
||||
BusinessException: If end_user not in authorized workspace or write fails
|
||||
"""
|
||||
if files is None:
|
||||
files = list()
|
||||
logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}")
|
||||
|
||||
|
||||
# Validate end_user exists and belongs to workspace
|
||||
self.validate_end_user(end_user_id, workspace_id)
|
||||
|
||||
|
||||
# Update end user's memory_config_id
|
||||
self._update_end_user_config(end_user_id, config_id)
|
||||
|
||||
|
||||
try:
|
||||
# Delegate to MemoryAgentService
|
||||
# Convert string message to list[dict] format expected by MemoryAgentService
|
||||
@@ -171,11 +174,12 @@ class MemoryAPIService:
|
||||
config_id=config_id,
|
||||
db=self.db,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id or ""
|
||||
user_rag_memory_id=user_rag_memory_id or "",
|
||||
files=files
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
||||
|
||||
|
||||
# result may be a string "success" or a dict with a "status" key
|
||||
# Preserve the full dict so callers don't silently lose extra fields
|
||||
# (e.g. error codes, metadata) returned by MemoryAgentService.
|
||||
@@ -189,7 +193,7 @@ class MemoryAPIService:
|
||||
"status": result if isinstance(result, str) else "success",
|
||||
"end_user_id": end_user_id,
|
||||
}
|
||||
|
||||
|
||||
except ConfigurationError as e:
|
||||
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
|
||||
raise BusinessException(
|
||||
@@ -204,16 +208,16 @@ class MemoryAPIService:
|
||||
message=f"Memory write failed: {str(e)}",
|
||||
code=BizCode.MEMORY_WRITE_FAILED
|
||||
)
|
||||
|
||||
|
||||
async def read_memory(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
search_switch: str = "0",
|
||||
config_id: str = "",
|
||||
storage_type: str = "neo4j",
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
search_switch: str = "0",
|
||||
config_id: str = "",
|
||||
storage_type: str = "neo4j",
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Read memory with validation.
|
||||
|
||||
@@ -237,14 +241,13 @@ class MemoryAPIService:
|
||||
BusinessException: If end_user not in authorized workspace or read fails
|
||||
"""
|
||||
logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}")
|
||||
|
||||
|
||||
# Validate end_user exists and belongs to workspace
|
||||
self.validate_end_user(end_user_id, workspace_id)
|
||||
|
||||
|
||||
# Update end user's memory_config_id
|
||||
self._update_end_user_config(end_user_id, config_id)
|
||||
|
||||
|
||||
try:
|
||||
# Delegate to MemoryAgentService
|
||||
result = await MemoryAgentService().read_memory(
|
||||
@@ -257,15 +260,15 @@ class MemoryAPIService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id or ""
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Memory read successful for end_user: {end_user_id}")
|
||||
|
||||
|
||||
return {
|
||||
"answer": result.get("answer", ""),
|
||||
"intermediate_outputs": result.get("intermediate_outputs", []),
|
||||
"end_user_id": end_user_id
|
||||
}
|
||||
|
||||
|
||||
except ConfigurationError as e:
|
||||
logger.error(f"Memory configuration error for end_user {end_user_id}: {e}")
|
||||
raise BusinessException(
|
||||
@@ -282,8 +285,8 @@ class MemoryAPIService:
|
||||
)
|
||||
|
||||
def list_memory_configs(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
) -> Dict[str, Any]:
|
||||
"""List all memory configs for a workspace.
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
"""Validate configuration ID format (supports both UUID and integer)."""
|
||||
if isinstance(config_id, uuid.UUID):
|
||||
return config_id
|
||||
|
||||
|
||||
if config_id is None:
|
||||
raise InvalidConfigError(
|
||||
"Configuration ID cannot be None",
|
||||
@@ -60,18 +60,18 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
if result:
|
||||
logger.info(f"Found config_id {result.config_id} for user_id {config_id}")
|
||||
return result.config_id
|
||||
|
||||
|
||||
return config_id
|
||||
|
||||
if isinstance(config_id, str):
|
||||
config_id_stripped = config_id.strip()
|
||||
|
||||
|
||||
# Try parsing as UUID first
|
||||
try:
|
||||
return uuid.UUID(config_id_stripped)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
# Fall back to integer parsing
|
||||
try:
|
||||
parsed_id = int(config_id_stripped)
|
||||
@@ -81,17 +81,17 @@ def _validate_config_id(config_id, db: Session = None):
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
|
||||
# 如果提供了数据库会话,尝试通过 user_id 查询 config_id
|
||||
if db is not None:
|
||||
# 查询 user_id 匹配的记录
|
||||
stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id))
|
||||
result = db.execute(stmt).scalars().first()
|
||||
|
||||
|
||||
if result:
|
||||
logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}")
|
||||
return result.config_id
|
||||
|
||||
|
||||
return parsed_id
|
||||
except ValueError:
|
||||
raise InvalidConfigError(
|
||||
@@ -154,10 +154,10 @@ class MemoryConfigService:
|
||||
self.db = db
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
service_name: str = "MemoryConfigService",
|
||||
self,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database with optional fallback.
|
||||
@@ -194,14 +194,14 @@ class MemoryConfigService:
|
||||
try:
|
||||
# Use get_config_with_fallback if workspace_id is provided
|
||||
memory_config = None
|
||||
validated_config_id = None
|
||||
if workspace_id:
|
||||
validated_config_id = None
|
||||
if config_id:
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id, self.db)
|
||||
except Exception:
|
||||
validated_config_id = None
|
||||
|
||||
|
||||
memory_config = self.get_config_with_fallback(
|
||||
memory_config_id=validated_config_id,
|
||||
workspace_id=workspace_id
|
||||
@@ -210,7 +210,7 @@ class MemoryConfigService:
|
||||
validated_config_id = _validate_config_id(config_id, self.db)
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
memory_config = self.db.get(MemoryConfigModel, validated_config_id)
|
||||
|
||||
|
||||
if not memory_config:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
config_logger.error(
|
||||
@@ -233,7 +233,7 @@ class MemoryConfigService:
|
||||
result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id)
|
||||
db_query_time = time.time() - db_query_start
|
||||
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
||||
|
||||
|
||||
if not result:
|
||||
raise ConfigurationError(
|
||||
f"Workspace not found for config {memory_config.config_id}"
|
||||
@@ -243,10 +243,10 @@ class MemoryConfigService:
|
||||
|
||||
# Helper function to validate model with workspace fallback
|
||||
def _validate_model_with_fallback(
|
||||
model_id: str,
|
||||
model_type: str,
|
||||
workspace_default: str,
|
||||
required: bool = False
|
||||
model_id: str,
|
||||
model_type: str,
|
||||
workspace_default: str,
|
||||
required: bool = False
|
||||
) -> tuple:
|
||||
"""Validate model ID, falling back to workspace default if invalid.
|
||||
|
||||
@@ -275,7 +275,7 @@ class MemoryConfigService:
|
||||
logger.warning(
|
||||
f"{model_type} model validation failed, trying workspace default: {e}"
|
||||
)
|
||||
|
||||
|
||||
# Fallback to workspace default
|
||||
if workspace_default:
|
||||
try:
|
||||
@@ -297,7 +297,7 @@ class MemoryConfigService:
|
||||
logger.error(f"Workspace default {model_type} model also invalid: {e}")
|
||||
if required:
|
||||
raise
|
||||
|
||||
|
||||
if required:
|
||||
raise InvalidConfigError(
|
||||
f"{model_type.title()} model is required but not configured",
|
||||
@@ -306,7 +306,7 @@ class MemoryConfigService:
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id
|
||||
)
|
||||
|
||||
|
||||
return None, None
|
||||
|
||||
# Step 2: Validate embedding model with workspace fallback
|
||||
@@ -343,6 +343,35 @@ class MemoryConfigService:
|
||||
if memory_config.rerank_id or workspace.rerank:
|
||||
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
||||
|
||||
vision_uuid, vision_name = validate_and_resolve_model_id(
|
||||
memory_config.vision_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
audio_uuid, audio_name = validate_and_resolve_model_id(
|
||||
memory_config.audio_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
video_uuid, video_name = validate_and_resolve_model_id(
|
||||
memory_config.video_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
# Create immutable MemoryConfig object
|
||||
config = MemoryConfig(
|
||||
config_id=memory_config.config_id,
|
||||
@@ -356,6 +385,12 @@ class MemoryConfigService:
|
||||
embedding_model_name=embedding_name,
|
||||
rerank_model_id=rerank_uuid,
|
||||
rerank_model_name=rerank_name,
|
||||
video_model_id=video_uuid,
|
||||
video_model_name=video_name,
|
||||
vision_model_id=vision_uuid,
|
||||
vision_model_name=vision_name,
|
||||
audio_model_id=audio_uuid,
|
||||
audio_model_name=audio_name,
|
||||
storage_type=workspace.storage_type or "neo4j",
|
||||
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
||||
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
||||
@@ -364,24 +399,31 @@ class MemoryConfigService:
|
||||
reflexion_baseline=memory_config.baseline or "Time",
|
||||
loaded_at=datetime.now(),
|
||||
# Pipeline config: Deduplication
|
||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
enable_llm_dedup_blockwise=bool(
|
||||
memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(
|
||||
memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
||||
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
||||
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
||||
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
||||
# Pipeline config: Statement extraction
|
||||
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
statement_granularity=int(
|
||||
memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(
|
||||
memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(
|
||||
memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
# Pipeline config: Forgetting engine
|
||||
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
||||
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
||||
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
||||
# Pipeline config: Pruning
|
||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_enabled=bool(
|
||||
memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
pruning_threshold=float(
|
||||
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
# Ontology scene association
|
||||
scene_id=memory_config.scene_id,
|
||||
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
||||
@@ -448,9 +490,9 @@ class MemoryConfigService:
|
||||
if not config:
|
||||
logger.warning(f"Model ID {model_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
@@ -481,9 +523,9 @@ class MemoryConfigService:
|
||||
if not config:
|
||||
logger.warning(f"Embedding model ID {embedding_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
||||
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
@@ -571,25 +613,25 @@ class MemoryConfigService:
|
||||
"""
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
|
||||
if not memory_config.scene_id:
|
||||
logger.debug("No scene_id configured, skipping ontology type fetch")
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
ontology_repo = OntologyClassRepository(self.db)
|
||||
ontology_classes = ontology_repo.get_classes_by_scene(memory_config.scene_id)
|
||||
|
||||
|
||||
if not ontology_classes:
|
||||
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
|
||||
return None
|
||||
|
||||
|
||||
ontology_types = OntologyTypeList.from_db_models(ontology_classes)
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
)
|
||||
return ontology_types
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
|
||||
@@ -598,8 +640,8 @@ class MemoryConfigService:
|
||||
return None
|
||||
|
||||
def get_workspace_default_config(
|
||||
self,
|
||||
workspace_id: UUID
|
||||
self,
|
||||
workspace_id: UUID
|
||||
) -> Optional["MemoryConfigModel"]:
|
||||
"""Get workspace default memory config.
|
||||
|
||||
@@ -613,19 +655,19 @@ class MemoryConfigService:
|
||||
Optional[MemoryConfigModel]: Default config or None if no configs exist
|
||||
"""
|
||||
config = MemoryConfigRepository.get_workspace_default(self.db, workspace_id)
|
||||
|
||||
|
||||
if not config:
|
||||
logger.warning(
|
||||
"No active memory config found for workspace fallback",
|
||||
extra={"workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
|
||||
return config
|
||||
|
||||
def get_config_with_fallback(
|
||||
self,
|
||||
memory_config_id: Optional[UUID],
|
||||
workspace_id: UUID
|
||||
self,
|
||||
memory_config_id: Optional[UUID],
|
||||
workspace_id: UUID
|
||||
) -> Optional["MemoryConfigModel"]:
|
||||
"""Get memory config with fallback to workspace default.
|
||||
|
||||
@@ -644,13 +686,13 @@ class MemoryConfigService:
|
||||
"No memory config ID provided, using workspace default",
|
||||
extra={"workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
|
||||
config = MemoryConfigRepository.get_with_fallback(
|
||||
self.db,
|
||||
memory_config_id,
|
||||
workspace_id
|
||||
)
|
||||
|
||||
|
||||
if not config and memory_config_id:
|
||||
logger.warning(
|
||||
"Memory config not found, falling back to workspace default",
|
||||
@@ -659,13 +701,13 @@ class MemoryConfigService:
|
||||
"workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return config
|
||||
|
||||
def delete_config(
|
||||
self,
|
||||
config_id: UUID | int,
|
||||
force: bool = False
|
||||
self,
|
||||
config_id: UUID | int,
|
||||
force: bool = False
|
||||
) -> dict:
|
||||
"""Delete memory config with protection against in-use configs.
|
||||
|
||||
@@ -687,7 +729,7 @@ class MemoryConfigService:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
|
||||
# 处理旧格式 int 类型的 config_id
|
||||
if isinstance(config_id, int):
|
||||
logger.warning(
|
||||
@@ -699,11 +741,11 @@ class MemoryConfigService:
|
||||
"message": "旧格式配置ID不支持删除操作,请使用新版配置",
|
||||
"legacy_int_id": config_id
|
||||
}
|
||||
|
||||
|
||||
config = self.db.get(MemoryConfigModel, config_id)
|
||||
if not config:
|
||||
raise ResourceNotFoundException("MemoryConfig", str(config_id))
|
||||
|
||||
|
||||
# Check if this is the default config - default configs cannot be deleted
|
||||
if config.is_default:
|
||||
logger.warning(
|
||||
@@ -715,11 +757,11 @@ class MemoryConfigService:
|
||||
"message": "默认配置不允许删除",
|
||||
"is_default": True
|
||||
}
|
||||
|
||||
|
||||
# Use repository to count connected end users
|
||||
end_user_repo = EndUserRepository(self.db)
|
||||
connected_count = end_user_repo.count_by_memory_config_id(config_id)
|
||||
|
||||
|
||||
if connected_count > 0 and not force:
|
||||
logger.warning(
|
||||
"Attempted to delete memory config with connected end users",
|
||||
@@ -728,18 +770,18 @@ class MemoryConfigService:
|
||||
"connected_count": connected_count
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"status": "warning",
|
||||
"message": f"无法删除记忆配置:{connected_count} 个终端用户正在使用此配置",
|
||||
"connected_count": connected_count,
|
||||
"force_required": True
|
||||
}
|
||||
|
||||
|
||||
# Force delete: use repository to clear end user references first
|
||||
if connected_count > 0 and force:
|
||||
cleared_count = end_user_repo.clear_memory_config_id(config_id)
|
||||
|
||||
|
||||
logger.warning(
|
||||
"Force deleting memory config, clearing end user references",
|
||||
extra={
|
||||
@@ -747,11 +789,11 @@ class MemoryConfigService:
|
||||
"cleared_end_users": cleared_count
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
self.db.delete(config)
|
||||
self.db.commit()
|
||||
|
||||
|
||||
logger.info(
|
||||
"Memory config deleted",
|
||||
extra={
|
||||
@@ -760,16 +802,16 @@ class MemoryConfigService:
|
||||
"affected_users": connected_count
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "记忆配置删除成功",
|
||||
"affected_users": connected_count
|
||||
}
|
||||
|
||||
|
||||
except IntegrityError as e:
|
||||
self.db.rollback()
|
||||
|
||||
|
||||
# Handle foreign key violation gracefully
|
||||
error_str = str(e.orig) if e.orig else str(e)
|
||||
if "ForeignKeyViolation" in error_str or "foreign key constraint" in error_str.lower():
|
||||
@@ -785,7 +827,7 @@ class MemoryConfigService:
|
||||
"message": "无法删除记忆配置:仍有终端用户引用此配置,请使用 force=true 强制删除",
|
||||
"force_required": True
|
||||
}
|
||||
|
||||
|
||||
# Re-raise other integrity errors
|
||||
logger.error(
|
||||
"Delete failed due to integrity error",
|
||||
@@ -800,9 +842,9 @@ class MemoryConfigService:
|
||||
# ==================== 记忆配置提取方法 ====================
|
||||
|
||||
def extract_memory_config_id(
|
||||
self,
|
||||
app_type: str,
|
||||
config: dict
|
||||
self,
|
||||
app_type: str,
|
||||
config: dict
|
||||
) -> tuple[Optional[uuid.UUID], bool]:
|
||||
"""从发布配置中提取 memory_config_id(根据应用类型分发)
|
||||
|
||||
@@ -828,8 +870,8 @@ class MemoryConfigService:
|
||||
return None, False
|
||||
|
||||
def _extract_memory_config_id_from_agent(
|
||||
self,
|
||||
config: dict
|
||||
self,
|
||||
config: dict
|
||||
) -> tuple[Optional[uuid.UUID], bool]:
|
||||
"""从 Agent 应用配置中提取 memory_config_id
|
||||
|
||||
@@ -888,8 +930,8 @@ class MemoryConfigService:
|
||||
return None, False
|
||||
|
||||
def _extract_memory_config_id_from_workflow(
|
||||
self,
|
||||
config: dict
|
||||
self,
|
||||
config: dict
|
||||
) -> tuple[Optional[uuid.UUID], bool]:
|
||||
"""从 Workflow 应用配置中提取 memory_config_id
|
||||
|
||||
@@ -905,14 +947,14 @@ class MemoryConfigService:
|
||||
- is_legacy_int: 是否检测到旧格式 int 数据
|
||||
"""
|
||||
nodes = config.get("nodes", [])
|
||||
|
||||
|
||||
for node in nodes:
|
||||
node_type = node.get("type", "")
|
||||
|
||||
|
||||
# 检查是否为记忆节点 (support both formats: memory-read/memory-write and MemoryRead/MemoryWrite)
|
||||
if node_type.lower() in ["memoryread", "memorywrite", "memory-read", "memory-write"]:
|
||||
config_id = node.get("config", {}).get("config_id")
|
||||
|
||||
|
||||
if config_id:
|
||||
try:
|
||||
# 处理字符串、UUID 和 int(旧数据兼容)三种情况
|
||||
@@ -937,6 +979,6 @@ class MemoryConfigService:
|
||||
f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, "
|
||||
f"node_type={node_type}, error={str(e)}"
|
||||
)
|
||||
|
||||
|
||||
logger.debug("工作流配置中未找到记忆节点")
|
||||
return None, False
|
||||
|
||||
@@ -12,11 +12,12 @@ from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models import FileMetadata
|
||||
from app.models import FileMetadata, ModelApiKey, ModelType
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||
from app.schemas import FileType
|
||||
from app.schemas import FileType, FileInput
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from app.schemas.memory_perceptual_schema import (
|
||||
PerceptualQuerySchema,
|
||||
PerceptualTimelineResponse,
|
||||
@@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import (
|
||||
AudioModal, Content, VideoModal, TextModal
|
||||
)
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
@@ -195,21 +198,58 @@ class MemoryPerceptualService:
|
||||
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
||||
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
||||
|
||||
def _get_mutlimodal_client(
|
||||
self,
|
||||
file_type: FileType,
|
||||
config: MemoryConfig
|
||||
) -> tuple[RedBearLLM | None, ModelApiKey | None]:
|
||||
model_config = None
|
||||
if file_type == FileType.AUDIO:
|
||||
model_config = ModelApiKeyService.get_available_api_key(
|
||||
self.db,
|
||||
config.audio_model_id
|
||||
)
|
||||
elif file_type == FileType.VIDEO:
|
||||
model_config = ModelApiKeyService.get_available_api_key(
|
||||
self.db,
|
||||
config.video_model_id
|
||||
)
|
||||
elif file_type == FileType.DOCUMENT:
|
||||
model_config = ModelApiKeyService.get_available_api_key(
|
||||
self.db,
|
||||
config.llm_model_id
|
||||
)
|
||||
elif file_type == FileType.IMAGE:
|
||||
model_config = ModelApiKeyService.get_available_api_key(
|
||||
self.db,
|
||||
config.vision_model_id
|
||||
)
|
||||
llm = None
|
||||
if model_config:
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_config.model_name,
|
||||
provider=model_config.provider,
|
||||
api_key=model_config.api_key,
|
||||
base_url=model_config.api_base,
|
||||
is_omni=model_config.is_omni
|
||||
)
|
||||
)
|
||||
return llm, model_config
|
||||
|
||||
async def generate_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
model_config: ModelInfo,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict,
|
||||
memory_config: MemoryConfig,
|
||||
file: FileInput
|
||||
):
|
||||
memories = self.repository.get_by_url(file_url)
|
||||
memories = self.repository.get_by_url(file.url)
|
||||
if memories:
|
||||
business_logger.info(f"Perceptual memory already exists: {file_url}")
|
||||
business_logger.info(f"Perceptual memory already exists: {file.url}")
|
||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||
memory_cache = memories[0]
|
||||
self.repository.create_perceptual_memory(
|
||||
memory = self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||
file_path=memory_cache.file_path,
|
||||
@@ -219,20 +259,31 @@ class MemoryPerceptualService:
|
||||
meta_data=memory_cache.meta_data
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
return
|
||||
llm = RedBearLLM(RedBearModelConfig(
|
||||
return memory
|
||||
else:
|
||||
for memory in memories:
|
||||
if memory.end_user_id == uuid.UUID(end_user_id):
|
||||
return memory
|
||||
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
||||
multimodel_service = MultimodalService(self.db, ModelInfo(
|
||||
model_name=model_config.model_name,
|
||||
provider=model_config.provider,
|
||||
api_key=model_config.api_key,
|
||||
base_url=model_config.api_base,
|
||||
is_omni=model_config.is_omni
|
||||
), type=model_config.model_type)
|
||||
api_base=model_config.api_base,
|
||||
is_omni=model_config.is_omni,
|
||||
capability=model_config.capability,
|
||||
model_type=ModelType.LLM
|
||||
))
|
||||
file_message = await multimodel_service.process_files(
|
||||
files=[file]
|
||||
)
|
||||
if file_message:
|
||||
file_message = file_message[0]
|
||||
try:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
|
||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
|
||||
except FileNotFoundError:
|
||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||
messages = [
|
||||
@@ -242,8 +293,22 @@ class MemoryPerceptualService:
|
||||
]}
|
||||
]
|
||||
result = await llm.ainvoke(messages)
|
||||
content = json_repair.repair_json(result.content, return_objects=True)
|
||||
path = urlparse(file_url).path
|
||||
content = result.content
|
||||
final_output = ""
|
||||
if isinstance(content, list):
|
||||
for msg in content:
|
||||
if isinstance(msg, dict):
|
||||
final_output += msg.get("text", "")
|
||||
elif isinstance(msg, str):
|
||||
final_output += msg
|
||||
elif isinstance(content, dict):
|
||||
final_output += content.get("text", "")
|
||||
elif isinstance(content, str):
|
||||
final_output = content
|
||||
else:
|
||||
raise ValueError(f"Unexcept Model Output Type: {result.content}")
|
||||
content = json_repair.repair_json(final_output, return_objects=True)
|
||||
path = urlparse(file.url).path
|
||||
filename = os.path.basename(path)
|
||||
filename = unquote(filename)
|
||||
file_ext = os.path.splitext(filename)[1]
|
||||
@@ -260,13 +325,13 @@ class MemoryPerceptualService:
|
||||
except ValueError:
|
||||
business_logger.debug(f"Remote file, file_id={filename}")
|
||||
if not file_ext:
|
||||
if file_type == FileType.AUDIO:
|
||||
if file.type == FileType.AUDIO:
|
||||
file_ext = ".mp3"
|
||||
elif file_type == FileType.VIDEO:
|
||||
elif file.type == FileType.VIDEO:
|
||||
file_ext = ".mp4"
|
||||
elif file_type == FileType.DOCUMENT:
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
file_ext = ".txt"
|
||||
elif file_type == FileType.IMAGE:
|
||||
elif file.type == FileType.IMAGE:
|
||||
file_ext = ".jpg"
|
||||
filename += file_ext
|
||||
file_content = {
|
||||
@@ -274,11 +339,11 @@ class MemoryPerceptualService:
|
||||
"topic": content.get("topic"),
|
||||
"domain": content.get("domain")
|
||||
}
|
||||
if file_type in [FileType.IMAGE, FileType.VIDEO]:
|
||||
if file.type in [FileType.IMAGE, FileType.VIDEO]:
|
||||
file_modalities = {
|
||||
"scene": content.get("scene", [])
|
||||
}
|
||||
elif file_type in [FileType.DOCUMENT]:
|
||||
elif file.type in [FileType.DOCUMENT]:
|
||||
file_modalities = {
|
||||
"section_count": content.get("section_count", 0),
|
||||
"title": content.get("title", ""),
|
||||
@@ -288,10 +353,10 @@ class MemoryPerceptualService:
|
||||
file_modalities = {
|
||||
"speaker_count": content.get("speaker_count", 0)
|
||||
}
|
||||
self.repository.create_perceptual_memory(
|
||||
memory = self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType.trans_from_file_type(file_type),
|
||||
file_path=file_url,
|
||||
perceptual_type=PerceptualType.trans_from_file_type(file.type),
|
||||
file_path=file.url,
|
||||
file_name=filename,
|
||||
file_ext=file_ext,
|
||||
summary=content.get('summary', ""),
|
||||
@@ -301,3 +366,4 @@ class MemoryPerceptualService:
|
||||
}
|
||||
)
|
||||
self.db.commit()
|
||||
return memory
|
||||
|
||||
@@ -9,14 +9,12 @@
|
||||
- OpenAI: 支持 URL 和 base64 格式
|
||||
"""
|
||||
import base64
|
||||
import csv
|
||||
import io
|
||||
import uuid
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
import csv
|
||||
import json
|
||||
|
||||
import PyPDF2
|
||||
import httpx
|
||||
import magic
|
||||
@@ -33,7 +31,6 @@ from app.models.file_metadata_model import FileMetadata
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||
from app.tasks import write_perceptual_memory
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -342,15 +339,12 @@ class MultimodalService:
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
end_user_id: uuid.UUID | str,
|
||||
files: Optional[List[FileInput]],
|
||||
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理文件列表,返回 LLM 可用的格式
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
files: 文件输入列表
|
||||
|
||||
Returns:
|
||||
@@ -358,8 +352,6 @@ class MultimodalService:
|
||||
"""
|
||||
if not files:
|
||||
return []
|
||||
if isinstance(end_user_id, uuid.UUID):
|
||||
end_user_id = str(end_user_id)
|
||||
|
||||
# 获取对应的策略
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
||||
@@ -380,23 +372,15 @@ class MultimodalService:
|
||||
if file.type == FileType.IMAGE and "vision" in self.capability:
|
||||
is_support, content = await self._process_image(file, strategy)
|
||||
result.append(content)
|
||||
if is_support:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
is_support, content = await self._process_document(file, strategy)
|
||||
result.append(content)
|
||||
if is_support:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||
is_support, content = await self._process_audio(file, strategy)
|
||||
result.append(content)
|
||||
if is_support:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.VIDEO and "video" in self.capability:
|
||||
is_support, content = await self._process_video(file, strategy)
|
||||
result.append(content)
|
||||
if is_support:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {file.type}")
|
||||
except Exception as e:
|
||||
@@ -418,17 +402,6 @@ class MultimodalService:
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||
return result
|
||||
|
||||
def write_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict
|
||||
):
|
||||
"""写入感知记忆"""
|
||||
if end_user_id and self.api_config:
|
||||
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
|
||||
|
||||
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
||||
"""
|
||||
处理图片文件
|
||||
|
||||
Reference in New Issue
Block a user