diff --git a/api/app/aioRedis.py b/api/app/aioRedis.py
index f758dd15..aac2aa84 100644
--- a/api/app/aioRedis.py
+++ b/api/app/aioRedis.py
@@ -1,10 +1,11 @@
-import os
import asyncio
import json
import logging
from typing import Dict, Any, Optional
+
import redis.asyncio as redis
from redis.asyncio import ConnectionPool
+
from app.core.config import settings
# 设置日志记录器
diff --git a/api/app/celery_app.py b/api/app/celery_app.py
index 21ee291d..60c22855 100644
--- a/api/app/celery_app.py
+++ b/api/app/celery_app.py
@@ -62,7 +62,7 @@ celery_app.conf.update(
task_serializer='json',
accept_content=['json'],
result_serializer='json',
-
+
# 时区
timezone='Asia/Shanghai',
enable_utc=False,
@@ -70,43 +70,44 @@ celery_app.conf.update(
# 任务追踪
task_track_started=True,
task_ignore_result=False,
-
+
# 超时设置
task_time_limit=3600, # 60分钟硬超时
task_soft_time_limit=3000, # 50分钟软超时
-
+
# Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
-
+
# 结果过期时间
result_expires=3600, # 结果保存1小时
-
+
# 任务确认设置
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_disable_rate_limits=True,
-
+
# FLower setting
worker_send_task_events=True,
task_send_sent_event=True,
-
+
# task routing
task_routes={
# Memory tasks → memory_tasks queue (threads worker)
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
-
+ 'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
+
# Long-term storage tasks → memory_tasks queue (batched write strategies)
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
-
+
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
-
+
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
@@ -131,7 +132,7 @@ implicit_emotions_update_schedule = crontab(
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
)
-#构建定时任务配置
+# 构建定时任务配置
beat_schedule_config = {
"run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task",
diff --git a/api/app/celery_worker.py b/api/app/celery_worker.py
index 7d3ee686..4ea4fee1 100644
--- a/api/app/celery_worker.py
+++ b/api/app/celery_worker.py
@@ -13,4 +13,4 @@ logger.info("Celery worker logging initialized")
# 导入任务模块以注册任务
import app.tasks
-__all__ = ['celery_app']
\ No newline at end of file
+__all__ = ['celery_app']
diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py
index 5d1551d8..31451a7d 100644
--- a/api/app/controllers/app_controller.py
+++ b/api/app/controllers/app_controller.py
@@ -808,6 +808,15 @@ async def draft_run_compare(
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
service._validate_app_accessible(app, workspace_id)
+ if payload.user_id is None:
+ end_user_repo = EndUserRepository(db)
+ new_end_user = end_user_repo.get_or_create_end_user(
+ app_id=app_id,
+ other_id=str(current_user.id),
+ original_user_id=str(current_user.id) # Save original user_id to other_id
+ )
+ payload.user_id = str(new_end_user.id)
+
# 2. 获取 Agent 配置
from sqlalchemy import select
from app.models import AgentConfig
@@ -853,6 +862,8 @@ async def draft_run_compare(
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
})
+
+
# 流式返回
if payload.stream:
async def event_generator():
@@ -864,7 +875,7 @@ async def draft_run_compare(
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
- user_id=payload.user_id or str(current_user.id),
+ user_id=payload.user_id,
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
@@ -895,7 +906,7 @@ async def draft_run_compare(
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
- user_id=payload.user_id or str(current_user.id),
+ user_id=payload.user_id,
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py
index 829f26c4..ca08db76 100644
--- a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py
+++ b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py
@@ -4,13 +4,13 @@ from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
def content_input_node(state: ReadState) -> ReadState:
"""
Start node - Extract content and maintain state information
-
+
Extracts the content from the first message in the state and returns it
as the data field while preserving all other state information.
-
+
Args:
state: ReadState containing messages and other state data
-
+
Returns:
ReadState: Updated state with extracted content in data field
"""
@@ -19,19 +19,20 @@ def content_input_node(state: ReadState) -> ReadState:
# Return content and maintain all state information
return {"data": content}
+
def content_input_write(state: WriteState) -> WriteState:
"""
Start node - Extract content and maintain state information for write operations
-
+
Extracts the content from the first message in the state for write operations.
-
+
Args:
state: WriteState containing messages and other state data
-
+
Returns:
WriteState: Updated state with extracted content in data field
"""
content = state['messages'][0].content if state.get('messages') else ''
# Return content and maintain all state information
- return {"data": content}
\ No newline at end of file
+ return {"data": content}
diff --git a/api/app/core/memory/agent/langgraph_graph/routing/routers.py b/api/app/core/memory/agent/langgraph_graph/routing/routers.py
index 004e03b3..d6ca3333 100644
--- a/api/app/core/memory/agent/langgraph_graph/routing/routers.py
+++ b/api/app/core/memory/agent/langgraph_graph/routing/routers.py
@@ -1,13 +1,13 @@
-
from typing import Literal
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
-
logger = get_agent_logger(__name__)
counter = COUNTState(limit=3)
-def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
+
+
+def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing based on search_switch value.
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
return 'Input_Summary'
return 'Split_The_Problem' # 默认情况
+
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing based on search_switch value.
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
elif search_switch == '1':
return 'Retrieve_Summary'
return 'Retrieve_Summary' # Default based on business logic
+
+
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
- status=state.get('verify', '')['status']
+ status = state.get('verify', '')['status']
# loop_count = counter.get_total()
if "success" in status:
# counter.reset()
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
# if loop_count < 2: # Maximum loop count is 3
# return "content_input"
# else:
- # counter.reset()
+ # counter.reset()
return "Summary_fails"
else:
# Add default return value to avoid returning None
diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py
index ddb6ca3e..6176caf5 100644
--- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py
+++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py
@@ -2,32 +2,32 @@ import json
import os
from app.core.logging_config import get_agent_logger
-from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
-from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
-
+from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
-from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store
+from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
-from app.db import get_db_context, get_db
+from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
+
logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
+
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
"""
Write messages to RAG storage system
-
+
Combines user and AI messages into a single string format and stores them
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
-
+
Args:
end_user_id: User identifier for the conversation
user_message: User's input message content
@@ -38,14 +38,24 @@ async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
-async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
- actual_config_id, long_term_messages=[]):
+
+
+async def write(
+ storage_type,
+ end_user_id,
+ user_message,
+ ai_message,
+ user_rag_memory_id,
+ actual_end_user_id,
+ actual_config_id,
+ long_term_messages=None
+):
"""
Write memory with structured message support
-
+
Handles memory writing operations for different storage types (Neo4j/RAG).
Supports both individual message pairs and batch long-term message processing.
-
+
Args:
storage_type: Storage type identifier ("neo4j" or "rag")
end_user_id: Terminal user identifier
@@ -55,7 +65,7 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
actual_end_user_id: Actual user identifier for storage
actual_config_id: Configuration identifier
long_term_messages: Optional list of structured messages for batch processing
-
+
Logic explanation:
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
- Neo4j mode: Uses structured message lists
@@ -64,8 +74,9 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
3. Each message is converted to independent Chunk, preserving speaker field
"""
- db = next(get_db())
- try:
+ if long_term_messages is None:
+ long_term_messages = []
+ with get_db_context() as db:
actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j mode: Use structured message lists
structured_messages = []
@@ -105,17 +116,16 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
- finally:
- db.close()
-async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
+
+async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
"""
Save long-term memory data to database
-
+
Handles the storage of long-term memory data based on different strategies
(chunk-based or aggregate-based) and manages the transition from short-term
to long-term memory storage.
-
+
Args:
long_term_messages: Long-term message data to be saved
actual_config_id: Configuration identifier for memory settings
@@ -126,13 +136,12 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
with get_db_context() as db_session:
repo = LongTermMemoryRepository(db_session)
-
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
- if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
+ if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
- if len(chunk_data)==scope:
+ if len(chunk_data) == scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'---------写入短长期-----------')
else:
@@ -142,22 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
logger.info(f'写入短长期:')
-
"""Window-based dialogue processing"""
-async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
+
+
+async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
"""
Process dialogue based on window size and write to Neo4j
-
+
Manages conversation data based on a sliding window approach. When the window
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
-
+
Args:
end_user_id: Terminal user identifier
memory_config: Memory configuration object containing settings
langchain_messages: Original message data list
scope: Window size determining when to trigger long-term storage
"""
- scope=scope
+ scope = scope
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
@@ -174,42 +184,53 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
config_id = memory_config.config_id
else:
config_id = memory_config
-
- await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
- config_id, formatted_messages)
+
+ await write(
+ AgentMemory_Long_Term.STORAGE_NEO4J,
+ end_user_id,
+ "",
+ "",
+ None,
+ end_user_id,
+ config_id,
+ formatted_messages
+ )
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""Time-based memory processing"""
-async def memory_long_term_storage(end_user_id,memory_config,time):
+
+
+async def memory_long_term_storage(end_user_id, memory_config, time):
"""
Process memory storage based on time intervals and write to Neo4j
-
+
Retrieves Redis data based on time intervals and writes it to Neo4j for
long-term storage. This function handles time-based memory consolidation.
-
+
Args:
end_user_id: Terminal user identifier
memory_config: Memory configuration object containing settings
time: Time interval for data retrieval
"""
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
- format_messages = (long_time_data)
- messages=[]
- memory_config=memory_config.config_id
+ format_messages = long_time_data
+ messages = []
+ memory_config = memory_config.config_id
for i in format_messages:
- message=json.loads(i['Query'])
- messages+= message
- if format_messages!=[]:
+ message = json.loads(i['Query'])
+ messages += message
+ if format_messages:
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
-"""Aggregation judgment processing"""
+
+
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
"""
Aggregation judgment function: determine if input sentence and historical messages describe the same event
-
+
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
historical data or stored as separate events. This helps optimize memory storage and retrieval.
@@ -217,11 +238,11 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
end_user_id: Terminal user identifier
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: Memory configuration object containing LLM settings
-
+
Returns:
dict: Aggregation judgment result containing is_same_event flag and processed output
"""
-
+ history = None
try:
# 1. Get historical session data (using new method)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
@@ -255,7 +276,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
output_value = structured.output
if isinstance(output_value, list):
output_value = [
- {"role": msg.role, "content": msg.content}
+ {"role": msg.role, "content": msg.content}
for msg in output_value
]
@@ -268,16 +289,16 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
await write("neo4j", end_user_id, "", "", None, end_user_id,
memory_config.config_id, output_value)
return result_dict
-
+
except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}")
import traceback
traceback.print_exc()
-
+
return {
"is_same_event": False,
"output": ori_messages,
"messages": ori_messages,
"history": history if 'history' in locals() else [],
"error": str(e)
- }
\ No newline at end of file
+ }
diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py
index bee77ddf..9bd2b2cf 100644
--- a/api/app/core/memory/agent/langgraph_graph/tools/tool.py
+++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py
@@ -2,26 +2,25 @@ import asyncio
import json
from datetime import datetime, timedelta
-
from langchain.tools import tool
from pydantic import BaseModel, Field
-
from app.core.memory.src.search import (
search_by_temporal,
search_by_keyword_temporal,
)
+
def extract_tool_message_content(response):
"""
Extract ToolMessage content and tool names from agent response
-
+
Parses agent response messages to extract tool execution results and metadata.
Handles JSON parsing and provides structured access to tool output data.
-
+
Args:
response: Agent response dictionary containing messages
-
+
Returns:
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
- tool_name: Name of the executed tool
@@ -61,10 +60,10 @@ def extract_tool_message_content(response):
class TimeRetrievalInput(BaseModel):
"""
Input schema for time retrieval tool
-
+
Defines the expected input parameters for time-based retrieval operations.
Used for validation and documentation of tool parameters.
-
+
Attributes:
context: User input query content for search
end_user_id: Group ID for filtering search results, defaults to test user
@@ -72,25 +71,26 @@ class TimeRetrievalInput(BaseModel):
context: str = Field(description="用户输入的查询内容")
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
+
def create_time_retrieval_tool(end_user_id: str):
"""
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
-
+
Creates a specialized time-based retrieval tool that searches for statements within
specified time ranges. Includes field cleaning functionality to remove unnecessary
metadata from search results.
-
+
Args:
end_user_id: User identifier for scoping search results
-
+
Returns:
function: Configured TimeRetrievalWithGroupId tool function
"""
-
+
def clean_temporal_result_fields(data):
"""
Clean unnecessary fields from temporal search results and modify structure
-
+
Removes metadata fields that are not needed for end-user consumption and
restructures the response format for better usability.
@@ -102,10 +102,10 @@ def create_time_retrieval_tool(end_user_id: str):
"""
# List of fields to filter out
fields_to_remove = {
- 'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
+ 'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
'valid_at', 'invalid_at', 'statement_ids'
}
-
+
if isinstance(data, dict):
cleaned = {}
for key, value in data.items():
@@ -126,15 +126,16 @@ def create_time_retrieval_tool(end_user_id: str):
return [clean_temporal_result_fields(item) for item in data]
else:
return data
-
+
@tool
- def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
+ def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
+ end_user_id_param: str = None, clean_output: bool = True) -> str:
"""
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
-
+
Performs time-based search operations with automatic metadata filtering. Supports
flexible date range specification and provides clean, user-friendly output.
-
+
Explicit parameters:
- context: Query context content
- start_date: Start time (optional, format: YYYY-MM-DD)
@@ -142,10 +143,11 @@ def create_time_retrieval_tool(end_user_id: str):
- end_user_id_param: Group ID (optional, overrides default group ID)
- clean_output: Whether to clean metadata fields from output
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
-
+
Returns:
str: JSON formatted search results with temporal data
"""
+
async def _async_search():
# Use passed parameters or default values
actual_end_user_id = end_user_id_param or end_user_id
@@ -167,18 +169,19 @@ def create_time_retrieval_tool(end_user_id: str):
cleaned_results = results
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
-
+
return asyncio.run(_async_search())
@tool
- def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
+ def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
+ clean_output: bool = True) -> str:
"""
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
-
+
Performs combined keyword and temporal search operations with automatic metadata
filtering. Provides more targeted search results by combining content relevance
with time-based filtering.
-
+
Explicit parameters:
- context: Query content for keyword matching
- days_back: Number of days to search backwards, default 7 days
@@ -186,10 +189,11 @@ def create_time_retrieval_tool(end_user_id: str):
- end_date: End time (optional, format: YYYY-MM-DD)
- clean_output: Whether to clean metadata fields from output
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
-
+
Returns:
str: JSON formatted search results combining keyword and temporal data
"""
+
async def _async_search():
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
@@ -212,29 +216,29 @@ def create_time_retrieval_tool(end_user_id: str):
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
return asyncio.run(_async_search())
-
+
return TimeRetrievalWithGroupId
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"""
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
-
+
Creates an advanced hybrid search tool that combines multiple search strategies
(keyword, vector, hybrid) with automatic result cleaning and formatting.
Args:
memory_config: Memory configuration object containing LLM and search settings
**search_params: Search parameters including end_user_id, limit, include, etc.
-
+
Returns:
function: Configured HybridSearch tool function with async capabilities
"""
-
+
def clean_result_fields(data):
"""
Recursively clean unnecessary fields from results
-
+
Removes metadata fields that are not needed for end-user consumption,
improving readability and reducing response size.
@@ -247,11 +251,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
# List of fields to filter out
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
fields_to_remove = {
- 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
- 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
- 'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
+ 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
+ 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
+ 'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
}
-
+
if isinstance(data, dict):
# Clean dictionary
cleaned = {}
@@ -265,7 +269,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
else:
# Return other types directly
return data
-
+
@tool
async def HybridSearch(
context: str,
@@ -279,7 +283,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
) -> str:
"""
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
-
+
Provides comprehensive search capabilities combining multiple search strategies
with intelligent result ranking and automatic metadata filtering for clean output.
@@ -292,7 +296,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
use_forgetting_rerank: Whether to use forgetting-based reranking
use_llm_rerank: Whether to use LLM-based reranking
clean_output: Whether to clean metadata fields from output
-
+
Returns:
str: JSON formatted comprehensive search results
"""
@@ -329,9 +333,9 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"search_type": search_type,
"results": cleaned_results
}
-
+
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
-
+
except Exception as e:
error_result = {
"error": f"混合检索失败: {str(e)}",
@@ -340,35 +344,36 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"timestamp": datetime.now().isoformat()
}
return json.dumps(error_result, ensure_ascii=False, indent=2)
-
+
return HybridSearch
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"""
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
-
+
Creates a synchronous wrapper around the async hybrid search functionality,
making it compatible with synchronous tool execution environments.
Args:
memory_config: Memory configuration object containing search settings
**search_params: Search parameters for configuration
-
+
Returns:
function: Configured HybridSearchSync tool function
"""
+
@tool
def HybridSearchSync(
- context: str,
- search_type: str = "hybrid",
- limit: int = 10,
- end_user_id: str = None,
- clean_output: bool = True
+ context: str,
+ search_type: str = "hybrid",
+ limit: int = 10,
+ end_user_id: str = None,
+ clean_output: bool = True
) -> str:
"""
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
-
+
Provides the same hybrid search capabilities as the async version but in a
synchronous execution context. Automatically handles async-to-sync conversion.
@@ -378,10 +383,11 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
limit: Result quantity limit
end_user_id: Group ID for filtering search results
clean_output: Whether to clean metadata fields from output
-
+
Returns:
str: JSON formatted search results
"""
+
async def _async_search():
# Create async tool and execute
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
@@ -392,7 +398,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"end_user_id": end_user_id,
"clean_output": clean_output
})
-
+
return asyncio.run(_async_search())
-
- return HybridSearchSync
\ No newline at end of file
+
+ return HybridSearchSync
diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py
index aa5e09a6..e11a2085 100644
--- a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py
+++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py
@@ -1,10 +1,12 @@
import json
from langchain_core.messages import HumanMessage, AIMessage
-async def format_parsing(messages: list,type:str='string'):
+
+
+async def format_parsing(messages: list, type: str = 'string'):
"""
Format and parse message lists into different output types
-
+
Processes message lists from storage and converts them into either string format
or dictionary format based on the specified type parameter. Handles JSON parsing
and role-based message organization.
@@ -19,8 +21,8 @@ async def format_parsing(messages: list,type:str='string'):
- 'dict': List of dictionaries mapping user messages to AI responses
"""
result = []
- user=[]
- ai=[]
+ user = []
+ ai = []
for message in messages:
hstory_messages = message['messages']
@@ -30,37 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
role = content['role']
content = content['content']
if type == "string":
- if role == 'human' or role=="user":
+ if role == 'human' or role == "user":
content = '用户:' + content
else:
content = 'AI:' + content
result.append(content)
- if type == "dict" :
- if role == 'human' or role=="user":
- user.append( content)
+ if type == "dict":
+ if role == 'human' or role == "user":
+ user.append(content)
else:
ai.append(content)
if type == "dict":
- for key,values in zip(user,ai):
- result.append({key:values})
+ for key, values in zip(user, ai):
+ result.append({key: values})
return result
+
async def messages_parse(messages: list | dict):
"""
Parse messages from storage format into user-AI conversation pairs
-
+
Extracts and organizes conversation data from stored message format,
separating user and AI messages and pairing them for database storage.
-
+
Args:
messages: List or dictionary containing stored message data with Query fields
-
+
Returns:
list: List of dictionaries containing user-AI message pairs for database storage
"""
- user=[]
- ai=[]
- database=[]
+ user = []
+ ai = []
+ database = []
for message in messages:
Query = message['Query']
Query = json.loads(Query)
@@ -72,20 +75,20 @@ async def messages_parse(messages: list | dict):
ai.append(data['content'])
for key, values in zip(user, ai):
database.append({key, values})
- return database
+ return database
-async def agent_chat_messages(user_content,ai_content):
+async def agent_chat_messages(user_content, ai_content):
"""
Create structured chat message format for agent conversations
-
+
Formats user and AI content into a standardized message structure suitable
for agent processing and storage. Creates role-based message objects.
-
+
Args:
user_content: User's message content string
ai_content: AI's response content string
-
+
Returns:
list: List of structured message dictionaries with role and content fields
"""
diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py
index 15009955..bf3c6597 100644
--- a/api/app/core/memory/agent/langgraph_graph/write_graph.py
+++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py
@@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
-
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
@@ -42,13 +41,15 @@ async def make_write_graph():
yield graph
-async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
+
+async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
+ end_user_id: str = '', scope: int = 6):
"""
Handle long-term memory storage with different strategies
-
- Supports multiple storage strategies including chunk-based, time-based,
+
+ Supports multiple storage strategies including chunk-based, time-based,
and aggregate judgment approaches for long-term memory persistence.
-
+
Args:
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
langchain_messages: List of messages to store
@@ -56,9 +57,10 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
end_user_id: User group identifier
scope: Scope parameter for chunk-based storage (default: 6)
"""
- from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
+ from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
+ aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
- write_store.save_session_write(end_user_id, (langchain_messages))
+ write_store.save_session_write(end_user_id, langchain_messages)
# 获取数据库会话
with get_db_context() as db_session:
config_service = MemoryConfigService(db_session)
@@ -66,25 +68,24 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
)
- if long_term_type==AgentMemory_Long_Term.STRATEGY_CHUNK:
+ if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
- await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
- if long_term_type==AgentMemory_Long_Term.STRATEGY_TIME:
+ await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
+ if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
"""Time-based strategy"""
- await memory_long_term_storage(end_user_id, memory_config,AgentMemory_Long_Term.TIME_SCOPE)
- if long_term_type==AgentMemory_Long_Term.STRATEGY_AGGREGATE:
+ await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
+ if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
"""Strategy 3: Aggregate judgment"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
-
-async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
+async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
"""
Write long-term memory with different storage types
-
+
Handles both RAG-based storage and traditional memory storage approaches.
For traditional storage, uses chunk-based strategy with paired user-AI messages.
-
+
Args:
storage_type: Type of storage (RAG or traditional)
end_user_id: User group identifier
@@ -95,7 +96,7 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
- from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
+ from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else:
@@ -128,4 +129,4 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_
#
# if __name__ == "__main__":
# import asyncio
-# asyncio.run(main())
\ No newline at end of file
+# asyncio.run(main())
diff --git a/api/app/core/memory/agent/utils/llm_tools.py b/api/app/core/memory/agent/utils/llm_tools.py
index 1c183422..ea8add48 100644
--- a/api/app/core/memory/agent/utils/llm_tools.py
+++ b/api/app/core/memory/agent/utils/llm_tools.py
@@ -8,10 +8,11 @@ from langgraph.graph import add_messages
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
+
class WriteState(TypedDict):
- '''
+ """
Langgrapg Writing TypedDict
- '''
+ """
messages: Annotated[list[AnyMessage], add_messages]
end_user_id: str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
@@ -20,6 +21,7 @@ class WriteState(TypedDict):
data: str
language: str # 语言类型 ("zh" 中文, "en" 英文)
+
class ReadState(TypedDict):
"""
LangGraph 工作流状态定义
@@ -43,18 +45,20 @@ class ReadState(TypedDict):
config_id: str
data: str # 新增字段用于传递内容
spit_data: dict # 新增字段用于传递问题分解结果
- problem_extension:dict
+ problem_extension: dict
storage_type: str
user_rag_memory_id: str
llm_id: str
embedding_id: str
memory_config: object # 新增字段用于传递内存配置对象
- retrieve:dict
+ retrieve: dict
RetrieveSummary: dict
InputSummary: dict
verify: dict
SummaryFails: dict
summary: dict
+
+
class COUNTState:
"""
工作流对话检索内容计数器
@@ -99,6 +103,7 @@ class COUNTState:
self.total = 0
print("[COUNTState] 已重置为 0")
+
def deduplicate_entries(entries):
seen = set()
deduped = []
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
deduped.append(entry)
return deduped
+
def merge_to_key_value_pairs(data, query_key, result_key):
grouped = defaultdict(list)
for item in data:
@@ -142,4 +148,4 @@ def convert_extended_question_to_question(data):
return [convert_extended_question_to_question(item) for item in data]
else:
# 其他类型直接返回
- return data
\ No newline at end of file
+ return data
diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py
index 024e320a..147ed777 100644
--- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py
+++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py
@@ -5,7 +5,7 @@ from typing import List, Dict, Optional
from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
-from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
+from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.models.message_models import DialogData, Statement
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
@@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger
logger = get_memory_logger(__name__)
-
class TripletExtractor:
"""Extracts knowledge triplets and entities from statements using LLM"""
def __init__(
- self,
- llm_client: OpenAIClient,
- ontology_types: Optional[OntologyTypeList] = None,
- language: str = "zh"):
+ self,
+ llm_client: OpenAIClient,
+ ontology_types: Optional[OntologyTypeList] = None,
+ language: str = "zh"
+ ):
"""Initialize the TripletExtractor with an LLM client
Args:
@@ -65,7 +65,8 @@ class TripletExtractor:
# Create messages for LLM
messages = [
- {"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
+ {"role": "system",
+ "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
{"role": "user", "content": prompt_content}
]
@@ -116,7 +117,8 @@ class TripletExtractor:
logger.error(f"Error processing statement: {e}", exc_info=True)
return TripletExtractionResponse(triplets=[], entities=[])
- async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]:
+ async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[
+ str, TripletExtractionResponse]:
"""Extract triplets and entities from statements
Args:
diff --git a/api/app/core/memory/utils/prompt/template_render.py b/api/app/core/memory/utils/prompt/template_render.py
index 68e0ffe4..4df8d55b 100644
--- a/api/app/core/memory/utils/prompt/template_render.py
+++ b/api/app/core/memory/utils/prompt/template_render.py
@@ -2,15 +2,15 @@ import os
from jinja2 import Environment, FileSystemLoader
from typing import List, Dict, Any
-
# Setup Jinja2 environment
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
+
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
baseline: str = "TIME",
- memory_verify: bool = False,quality_assessment:bool = False,
- statement_databasets: List[str] = [],language_type:str = "zh") -> str:
+ memory_verify: bool = False, quality_assessment: bool = False,
+ statement_databasets=None, language_type: str = "zh") -> str:
"""
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
@@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
Returns:
Rendered prompt content as string
"""
+ if statement_databasets is None:
+ statement_databasets = []
template = prompt_env.get_template("evaluate.jinja2")
# Convert Pydantic model to JSON schema if needed
@@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
- statement_databasets: List[str] = [],language_type:str = "zh") -> str:
+ statement_databasets=None, language_type: str = "zh") -> str:
"""
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
@@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
Returns:
Rendered prompt content as a string.
"""
+ if statement_databasets is None:
+ statement_databasets = []
template = prompt_env.get_template("reflexion.jinja2")
# Convert Pydantic model to JSON schema if needed
@@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
json_schema = schema
rendered_prompt = template.render(data=data, json_schema=json_schema,
- baseline=baseline,memory_verify=memory_verify,
- statement_databasets=statement_databasets,language_type=language_type)
+ baseline=baseline, memory_verify=memory_verify,
+ statement_databasets=statement_databasets, language_type=language_type)
return rendered_prompt
diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py
index dba6717d..4a453c6b 100644
--- a/api/app/core/models/base.py
+++ b/api/app/core/models/base.py
@@ -1,23 +1,19 @@
from __future__ import annotations
-import asyncio
import os
-import time
-from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, TypeVar
+from typing import Any, Dict, Optional, TypeVar
+
+from langchain_aws import ChatBedrock
+from langchain_community.chat_models import ChatTongyi
+from langchain_core.embeddings import Embeddings
+from langchain_core.language_models import BaseLLM
+from langchain_ollama import OllamaLLM
+from langchain_openai import ChatOpenAI, OpenAI
+from pydantic import BaseModel, Field
-import httpx
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.models.models_model import ModelProvider, ModelType
-from langchain_community.document_compressors import JinaRerank
-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.embeddings import Embeddings
-from langchain_core.language_models import BaseLanguageModel, BaseLLM
-from langchain_core.outputs import Generation, LLMResult
-from langchain_core.retrievers import BaseRetriever
-from langchain_core.runnables import RunnableSerializable
-from pydantic import BaseModel, Field
T = TypeVar("T")
@@ -163,25 +159,17 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
- from langchain_openai import ChatOpenAI
return ChatOpenAI
-
- if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
+ if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
if type == ModelType.LLM:
- from langchain_openai import OpenAI
return OpenAI
elif type == ModelType.CHAT:
- from langchain_openai import ChatOpenAI
return ChatOpenAI
elif provider == ModelProvider.DASHSCOPE:
- from langchain_community.chat_models import ChatTongyi
return ChatTongyi
elif provider == ModelProvider.OLLAMA:
- from langchain_ollama import OllamaLLM
return OllamaLLM
elif provider == ModelProvider.BEDROCK:
- from langchain_aws import ChatBedrock, ChatBedrockConverse
-
return ChatBedrock
else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py
index 39c7887b..0e3fecee 100644
--- a/api/app/core/workflow/nodes/base_node.py
+++ b/api/app/core/workflow/nodes/base_node.py
@@ -16,6 +16,7 @@ from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
from app.schemas import FileInput
+from app.schemas.model_schema import ModelInfo
from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__)
@@ -620,11 +621,12 @@ class BaseNode(ABC):
@staticmethod
async def process_message(
- provider: str,
- is_omni: bool,
+ api_config: ModelInfo,
content: str | dict | FileObject,
+ end_user_id: str,
enable_file=False
) -> list | str | None:
+ provider = api_config.provider
if isinstance(content, dict):
content = FileObject(
type=content.get("type"),
@@ -643,7 +645,7 @@ class BaseNode(ABC):
if content.content_cache.get(provider):
return content.content_cache[provider]
with get_db_read() as db:
- multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
+ multimodel_service = MultimodalService(db, api_config=api_config)
file_obj = FileInput(
type=content.type,
url=content.url,
@@ -653,7 +655,8 @@ class BaseNode(ABC):
)
file_obj.set_content(content.get_content())
message = await multimodel_service.process_files(
- [file_obj]
+ end_user_id,
+ [file_obj],
)
content.set_content(file_obj.get_content())
if message:
diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py
index 29f7085b..7e98efab 100644
--- a/api/app/core/workflow/nodes/if_else/node.py
+++ b/api/app/core/workflow/nodes/if_else/node.py
@@ -5,7 +5,7 @@ from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
-from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
+from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.variable.base_variable import VariableType
@@ -23,6 +23,26 @@ class IfElseNode(BaseNode):
"output": VariableType.STRING
}
+ def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
+ result = []
+ for case in self.typed_config.cases:
+ expressions = []
+ for expression in case.expressions:
+ expressions.append({
+ "left": self.get_variable(expression.left, variable_pool, strict=False),
+ "right": expression.right
+ if expression.input_type == ValueInputType.CONSTANT
+ else self.get_variable(expression.right, variable_pool, strict=False),
+ "operator": expression.operator,
+ })
+ result.append({
+ "expressions": expressions,
+ "logical_operator": case.logical_operator,
+ })
+ return {
+ "cases": result
+ }
+
@staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
match operator:
diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py
index 696298eb..14f789a9 100644
--- a/api/app/core/workflow/nodes/knowledge/node.py
+++ b/api/app/core/workflow/nodes/knowledge/node.py
@@ -30,6 +30,12 @@ class KnowledgeRetrievalNode(BaseNode):
"output": VariableType.ARRAY_STRING
}
+ def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
+ return {
+ "query": self._render_template(self.typed_config.query, variable_pool),
+ "knowledge_bases": [kb_config.model_dump(mode="json") for kb_config in self.typed_config.knowledge_bases],
+ }
+
@staticmethod
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
"""
diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py
index 186c204f..b293d1f4 100644
--- a/api/app/core/workflow/nodes/llm/node.py
+++ b/api/app/core/workflow/nodes/llm/node.py
@@ -20,6 +20,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_context
from app.models import ModelType
+from app.schemas.model_schema import ModelInfo
from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__)
@@ -113,12 +114,15 @@ class LLMNode(BaseNode):
# 在 Session 关闭前提取所有需要的数据
api_config = self.model_balance(config)
- model_name = api_config.model_name
- provider = api_config.provider
- api_key = api_config.api_key
- api_base = api_config.api_base
- is_omni = api_config.is_omni
- model_type = config.type
+ model_info = ModelInfo(
+ model_name=api_config.model_name,
+ model_type=ModelType(config.type),
+ api_key=api_config.api_key,
+ api_base=api_config.api_base,
+ provider=api_config.provider,
+ is_omni=api_config.is_omni,
+ capability=api_config.capability
+ )
# 4. 创建 LLM 实例(使用已提取的数据)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
@@ -126,17 +130,18 @@ class LLMNode(BaseNode):
llm = RedBearLLM(
RedBearModelConfig(
- model_name=model_name,
- provider=provider,
- api_key=api_key,
- base_url=api_base,
+ model_name=model_info.model_name,
+ provider=model_info.provider,
+ api_key=model_info.api_key,
+ base_url=model_info.api_base,
extra_params=extra_params,
- is_omni=is_omni
+ is_omni=model_info.is_omni
),
- type=ModelType(model_type)
+ type=model_info.model_type
)
- logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
+ logger.debug(
+ f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
messages_config = self.typed_config.messages
@@ -148,35 +153,40 @@ class LLMNode(BaseNode):
content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool)
-
+ user_id = self.get_variable("sys.user_id", variable_pool)
# 根据角色创建对应的消息对象
if role == "system":
messages.append({
"role": "system",
- "content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
+ "content": await self.process_message(
+ model_info,
+ content,
+ user_id,
+ self.typed_config.vision,
+ )
})
elif role in ["user", "human"]:
messages.append({
"role": "user",
- "content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
+ "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
})
elif role in ["ai", "assistant"]:
messages.append({
"role": "assistant",
- "content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
+ "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
})
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({
"role": "user",
- "content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
+ "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
})
if self.typed_config.vision_input and self.typed_config.vision:
file_content = []
files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value:
- content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
+ content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
if content:
file_content.extend(content)
if messages and messages[-1]["role"] == 'user':
@@ -190,14 +200,19 @@ class LLMNode(BaseNode):
if isinstance(message["content"], list):
file_content = []
for file in message["content"]:
- content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
+ content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
if content:
file_content.extend(content)
history_message.append(
{"role": message["role"], "content": file_content}
)
else:
- message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
+ message["content"] = await self.process_message(
+ model_info,
+ message["content"],
+ user_id,
+ self.typed_config.vision
+ )
history_message.append(message)
messages = messages[:-1] + history_message + messages[-1:]
self.messages = messages
@@ -293,7 +308,7 @@ class LLMNode(BaseNode):
# 调用 LLM(流式,支持字符串或消息列表)
last_meta_data = {}
- async for chunk in llm.astream(self.messages, stream_usage=True):
+ async for chunk in llm.astream(self.messages):
# 提取内容
if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content)
diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py
index 700ed85f..acac09e4 100644
--- a/api/app/core/workflow/nodes/parameter_extractor/node.py
+++ b/api/app/core/workflow/nodes/parameter_extractor/node.py
@@ -37,6 +37,14 @@ class ParameterExtractorNode(BaseNode):
}
return None
+ def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
+ return {
+ "text": self._render_template(self.typed_config.text, variable_pool),
+ "prompt": self._render_template(self.typed_config.prompt, variable_pool),
+ "params": [param.model_dump(mode="json") for param in self.typed_config.params],
+ "model_id": str(self.typed_config.model_id),
+ }
+
def _output_types(self) -> dict[str, VariableType]:
outputs = {}
for param in self.typed_config.params:
diff --git a/api/app/models/memory_perceptual_model.py b/api/app/models/memory_perceptual_model.py
index cafb18d4..9fed7c5d 100644
--- a/api/app/models/memory_perceptual_model.py
+++ b/api/app/models/memory_perceptual_model.py
@@ -7,6 +7,7 @@ from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql import JSONB
from app.db import Base
+from app.schemas import FileType
class PerceptualType(IntEnum):
@@ -15,6 +16,16 @@ class PerceptualType(IntEnum):
TEXT = 3
CONVERSATION = 4
+ @staticmethod
+ def trans_from_file_type(file_type: FileType | str):
+ type_map = {
+ FileType.IMAGE: PerceptualType.VISION,
+ FileType.AUDIO: PerceptualType.AUDIO,
+ FileType.VIDEO: PerceptualType.VISION,
+ FileType.DOCUMENT: PerceptualType.TEXT
+ }
+ return type_map.get(file_type, PerceptualType.TEXT)
+
class FileStorageService(IntEnum):
LOCAL = 1
diff --git a/api/app/repositories/memory_perceptual_repository.py b/api/app/repositories/memory_perceptual_repository.py
index 9fa9536e..9077af03 100644
--- a/api/app/repositories/memory_perceptual_repository.py
+++ b/api/app/repositories/memory_perceptual_repository.py
@@ -2,7 +2,7 @@ import uuid
from datetime import datetime
from typing import List, Tuple, Optional
-from sqlalchemy import and_, desc
+from sqlalchemy import and_, desc, select
from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
@@ -127,6 +127,17 @@ class MemoryPerceptualRepository:
db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}")
raise
+ def get_by_url(
+ self,
+ file_url: str
+ ) -> list[MemoryPerceptualModel]:
+ try:
+ stmt = select(MemoryPerceptualModel).where(MemoryPerceptualModel.file_path == file_url)
+ return list(self.db.execute(stmt).scalars())
+ except Exception:
+ db_logger.error(f"Failed to query perceptual memories by file_url: file_url={file_url}")
+ raise
+
def get_by_type(
self,
end_user_id: uuid.UUID,
diff --git a/api/app/schemas/memory_perceptual_schema.py b/api/app/schemas/memory_perceptual_schema.py
index 7dfefe01..c9b741ef 100644
--- a/api/app/schemas/memory_perceptual_schema.py
+++ b/api/app/schemas/memory_perceptual_schema.py
@@ -1,5 +1,4 @@
import uuid
-from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
@@ -85,7 +84,6 @@ class Semantic(BaseModel):
class Content(BaseModel):
- summary: str
keywords: list[str]
topic: str
domain: str
diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py
index 4f3878ce..058f082d 100644
--- a/api/app/schemas/model_schema.py
+++ b/api/app/schemas/model_schema.py
@@ -326,3 +326,14 @@ class ModelBaseQuery(BaseModel):
is_official: Optional[bool] = Field(None, description="是否官方模型")
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
+
+class ModelInfo(BaseModel):
+ """模型信息Schema"""
+ model_name: str = Field(..., description="模型名称")
+ provider: str = Field(..., description="模型提供商")
+ api_key: str = Field(..., description="API密钥")
+ api_base: str = Field(..., description="API基础URL")
+ is_omni: bool = Field(default=False, description="是否为omni模型")
+ model_type: ModelType = Field(..., description="模型类型")
+ capability: List[str] = Field(default_factory=list, description="模型能力列表")
+
diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py
index f3cdde2a..9b2b2a77 100644
--- a/api/app/services/app_chat_service.py
+++ b/api/app/services/app_chat_service.py
@@ -8,25 +8,21 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
from fastapi import Depends
from sqlalchemy.orm import Session
-from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
-from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.db import get_db
-from app.models import MultiAgentConfig, AgentConfig
+from app.models import MultiAgentConfig, AgentConfig, ModelType
from app.models import WorkflowConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput
+from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
-from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
- AgentRunService
-from app.services.draft_run_service import create_web_search_tool
+from app.services.draft_run_service import AgentRunService
from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.multimodal_service import MultimodalService
-from app.services.tool_service import ToolService
from app.services.workflow_service import WorkflowService
logger = get_business_logger()
@@ -126,8 +122,17 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
- multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
- processed_files = await multimodal_service.process_files(files)
+ model_info = ModelInfo(
+ model_name=api_key_obj.model_name,
+ provider=api_key_obj.provider,
+ api_key=api_key_obj.api_key,
+ api_base=api_key_obj.api_base,
+ capability=api_key_obj.capability,
+ is_omni=api_key_obj.is_omni,
+ model_type=ModelType.LLM
+ )
+ multimodal_service = MultimodalService(self.db, model_info)
+ processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 调用 Agent(支持多模态)
@@ -266,8 +271,17 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
- multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
- processed_files = await multimodal_service.process_files(files)
+ model_info = ModelInfo(
+ model_name=api_key_obj.model_name,
+ provider=api_key_obj.provider,
+ api_key=api_key_obj.api_key,
+ api_base=api_key_obj.api_base,
+ capability=api_key_obj.capability,
+ is_omni=api_key_obj.is_omni,
+ model_type=ModelType.LLM
+ )
+ multimodal_service = MultimodalService(self.db, model_info)
+ processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 流式调用 Agent(支持多模态)
diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py
index d7914db5..b3b136a1 100644
--- a/api/app/services/draft_run_service.py
+++ b/api/app/services/draft_run_service.py
@@ -23,9 +23,10 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
-from app.models import AgentConfig, ModelConfig
+from app.models import AgentConfig, ModelConfig, ModelType
from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput
+from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service
from app.services.conversation_service import ConversationService
@@ -501,9 +502,18 @@ class AgentRunService:
processed_files = None
if files:
# 获取 provider 信息
+ model_info = ModelInfo(
+ model_name=api_key_config["model_name"],
+ provider=api_key_config["provider"],
+ api_key=api_key_config["api_key"],
+ api_base=api_key_config["api_base"],
+ capability=api_key_config["capability"],
+ is_omni=api_key_config["is_omni"],
+ model_type=ModelType.LLM
+ )
provider = api_key_config.get("provider", "openai")
- multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
- processed_files = await multimodal_service.process_files(files)
+ multimodal_service = MultimodalService(self.db, model_info)
+ processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
# 7. 知识库检索
@@ -704,9 +714,18 @@ class AgentRunService:
processed_files = None
if files:
# 获取 provider 信息
+ model_info = ModelInfo(
+ model_name=api_key_config["model_name"],
+ provider=api_key_config["provider"],
+ api_key=api_key_config["api_key"],
+ api_base=api_key_config["api_base"],
+ capability=api_key_config["capability"],
+ is_omni=api_key_config["is_omni"],
+ model_type=ModelType.LLM
+ )
provider = api_key_config.get("provider", "openai")
- multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
- processed_files = await multimodal_service.process_files(files)
+ multimodal_service = MultimodalService(self.db, model_info)
+ processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
# 7. 知识库检索
@@ -841,7 +860,8 @@ class AgentRunService:
"api_key": api_key.api_key,
"api_base": api_key.api_base,
"api_key_id": api_key.id,
- "is_omni": api_key.is_omni
+ "is_omni": api_key.is_omni,
+ "capability": api_key.capability
}
async def _ensure_conversation(
diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py
index a20b968a..1e1d9e45 100644
--- a/api/app/services/memory_agent_service.py
+++ b/api/app/services/memory_agent_service.py
@@ -274,7 +274,7 @@ class MemoryAgentService:
Args:
end_user_id: Group identifier (also used as end_user_id)
- message: Message to write
+ messages: Message to write
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py
index b9d96a0b..53d935fe 100644
--- a/api/app/services/memory_perceptual_service.py
+++ b/api/app/services/memory_perceptual_service.py
@@ -1,19 +1,27 @@
+import os
import uuid
from typing import Dict, Any, Optional
+from urllib.parse import urlparse, unquote
+import json_repair
+from jinja2 import Template
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
+from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
+from app.models.prompt_optimizer_model import RoleType
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
+from app.schemas import FileType
from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema,
PerceptualTimelineResponse,
PerceptualMemoryItem,
AudioModal, Content, VideoModal, TextModal
)
+from app.schemas.model_schema import ModelInfo
business_logger = get_business_logger()
@@ -99,7 +107,7 @@ class MemoryPerceptualService:
"keywords": content.keywords,
"topic": content.topic,
"domain": content.domain,
- "created_time": int(memory.created_time.timestamp()*1000),
+ "created_time": int(memory.created_time.timestamp() * 1000),
**detail
}
@@ -108,7 +116,8 @@ class MemoryPerceptualService:
return result
except Exception as e:
- business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}")
+ business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
+ exc_info=True)
raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
BizCode.DB_ERROR)
@@ -138,7 +147,7 @@ class MemoryPerceptualService:
for memory in memories:
meta_data = memory.meta_data or {}
content = meta_data.get("content", {})
-
+
# 安全地提取 content 字段,提供默认值
if content:
content_obj = Content(**content)
@@ -149,7 +158,7 @@ class MemoryPerceptualService:
topic = "Unknown"
domain = "Unknown"
keywords = []
-
+
memory_item = PerceptualMemoryItem(
id=memory.id,
perceptual_type=PerceptualType(memory.perceptual_type),
@@ -161,7 +170,7 @@ class MemoryPerceptualService:
topic=topic,
domain=domain,
keywords=keywords,
- created_time=int(memory.created_time.timestamp()*1000),
+ created_time=int(memory.created_time.timestamp() * 1000),
storage_service=FileStorageService(memory.storage_service),
)
memory_items.append(memory_item)
@@ -183,3 +192,98 @@ class MemoryPerceptualService:
except Exception as e:
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
+
+ async def generate_perceptual_memory(
+ self,
+ end_user_id: str,
+ model_config: ModelInfo,
+ file_type: str,
+ file_url: str,
+ file_message: dict,
+ ):
+ memories = self.repository.get_by_url(file_url)
+ if memories:
+ business_logger.info(f"Perceptual memory already exists: {file_url}")
+ if end_user_id not in [memory.end_user_id for memory in memories]:
+ business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
+ memory_cache = memories[0]
+ self.repository.create_perceptual_memory(
+ end_user_id=uuid.UUID(end_user_id),
+ perceptual_type=PerceptualType(memory_cache.perceptual_type),
+ file_path=memory_cache.file_path,
+ file_name=memory_cache.file_name,
+ file_ext=memory_cache.file_ext,
+ summary=memory_cache.summary,
+ meta_data=memory_cache.meta_data
+ )
+ self.db.commit()
+
+ return
+ llm = RedBearLLM(RedBearModelConfig(
+ model_name=model_config.model_name,
+ provider=model_config.provider,
+ api_key=model_config.api_key,
+ base_url=model_config.api_base,
+ is_omni=model_config.is_omni
+ ), type=model_config.model_type)
+ try:
+ prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
+ with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
+ opt_system_prompt = f.read()
+ rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
+ except FileNotFoundError:
+ raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
+ messages = [
+ {"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
+ {"role": RoleType.USER.value, "content": [
+ {"type": "text", "text": "Summarize the following file"}, file_message
+ ]}
+ ]
+ result = await llm.ainvoke(messages)
+ content = json_repair.repair_json(result.content, return_objects=True)
+ path = urlparse(file_url).path
+ filename = os.path.basename(path)
+ filename = unquote(filename)
+ file_ext = os.path.splitext(filename)[1]
+ if not file_ext:
+ if file_type == FileType.AUDIO:
+ file_ext = ".mp3"
+ elif file_type == FileType.VIDEO:
+ file_ext = ".mp4"
+ elif file_type == FileType.DOCUMENT:
+ file_ext = ".txt"
+ elif file_type == FileType.IMAGE:
+ file_ext = ".jpg"
+ filename += file_ext
+ file_content = {
+ "keywords": content.get("keywords", []),
+ "topic": content.get("topic"),
+ "domain": content.get("domain")
+ }
+ if file_type in [FileType.IMAGE, FileType.VIDEO]:
+ file_modalities = {
+ "scene": content.get("scene")
+ }
+ elif file_type in [FileType.DOCUMENT]:
+ file_modalities = {
+ "section_count": content.get("section_count"),
+ "title": content.get("title"),
+ "first_line": content.get("first_line")
+ }
+ else:
+ file_modalities = {
+ "speaker_count": content.get("speaker_count")
+ }
+ self.repository.create_perceptual_memory(
+ end_user_id=uuid.UUID(end_user_id),
+ perceptual_type=PerceptualType.trans_from_file_type(file_type),
+ file_path=file_url,
+ file_name=filename,
+ file_ext=file_ext,
+ summary=content.get('summary'),
+ meta_data={
+ "content": file_content,
+ "modalities": file_modalities
+ }
+ )
+ self.db.commit()
diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py
index fffca2e5..935efafe 100644
--- a/api/app/services/multimodal_service.py
+++ b/api/app/services/multimodal_service.py
@@ -10,6 +10,7 @@
"""
import base64
import io
+import uuid
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
@@ -23,9 +24,12 @@ from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
+from app.models import ModelApiKey
from app.models.file_metadata_model import FileMetadata
from app.schemas.app_schema import FileInput, FileType, TransferMethod
+from app.schemas.model_schema import ModelInfo
from app.services.audio_transcription_service import AudioTranscriptionService
+from app.tasks import write_perceptual_memory
logger = get_business_logger()
@@ -39,6 +43,7 @@ DOC_MIME = [
class MultimodalFormatStrategy(ABC):
"""多模态格式策略基类"""
+
def __init__(self, file: FileInput):
self.file = file
@@ -95,7 +100,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
if transcription:
return {
"type": "text",
- "text": f""
+ "text": f""
}
# 通义千问音频格式:{"type": "audio", "audio": "url"}
return {
@@ -284,34 +289,56 @@ PROVIDER_STRATEGIES = {
class MultimodalService:
- """多模态文件处理服务"""
+ """
+ Service for handling multimodal file processing.
- def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
- enable_audio_transcription: bool = False, is_omni: bool = False):
+ Attributes:
+ db (Session): Database session.
+ model_api_key (str): API key for the model provider.
+ provider (str): Name of the model provider.
+ is_omni (bool): Indicates whether the model supports full multimodal capability.
+ capability (list): Capability configuration of the model.
+ audio_api_key (str | None): API key used for audio transcription.
+ enable_audio_transcription (bool): Whether audio transcription is enabled.
+ """
+
+ def __init__(
+ self,
+ db: Session,
+ api_config: ModelInfo | None = None,
+ audio_api_key: Optional[str] = None,
+ enable_audio_transcription: bool = False,
+ ):
"""
- 初始化多模态服务
-
+ Initialize the multimodal service.
+
Args:
- db: 数据库会话
- provider: 模型提供商(dashscope, bedrock, anthropic, openai 等)
- api_key: API 密钥(用于音频转文本)
- enable_audio_transcription: 是否启用音频转文本
- is_omni: 是否为 Omni 模型(dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
+ db (Session): Database session.
+ api_config (ModelApiKey | None): Model API configuration.
+ audio_api_key (str | None): API key for audio transcription.
+ enable_audio_transcription (bool): Enable audio transcription.
"""
self.db = db
- self.provider = provider.lower()
- self.api_key = api_key
+ self.api_config = api_config
+ if self.api_config is not None:
+ self.model_api_key = api_config.api_key
+ self.provider = api_config.provider.lower()
+ self.is_omni = api_config.is_omni
+ self.capability = api_config.capability
+ self.audio_api_key = audio_api_key
self.enable_audio_transcription = enable_audio_transcription
- self.is_omni = is_omni
async def process_files(
self,
- files: Optional[List[FileInput]]
+ end_user_id: uuid.UUID | str,
+ files: Optional[List[FileInput]],
+
) -> List[Dict[str, Any]]:
"""
处理文件列表,返回 LLM 可用的格式
Args:
+ end_user_id: 用户ID
files: 文件输入列表
Returns:
@@ -319,6 +346,8 @@ class MultimodalService:
"""
if not files:
return []
+ if isinstance(end_user_id, uuid.UUID):
+ end_user_id = str(end_user_id)
# 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式
@@ -333,19 +362,25 @@ class MultimodalService:
result = []
for idx, file in enumerate(files):
strategy = strategy_class(file)
+ if not file.url:
+ file.url = await self.get_file_url(file)
try:
- if file.type == FileType.IMAGE:
+ if file.type == FileType.IMAGE and "vision" in self.capability:
content = await self._process_image(file, strategy)
result.append(content)
+ self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.DOCUMENT:
content = await self._process_document(file, strategy)
result.append(content)
- elif file.type == FileType.AUDIO:
+ self.write_perceptual_memory(end_user_id, file.type, file.url, content)
+ elif file.type == FileType.AUDIO and "audio" in self.capability:
content = await self._process_audio(file, strategy)
result.append(content)
- elif file.type == FileType.VIDEO:
+ self.write_perceptual_memory(end_user_id, file.type, file.url, content)
+ elif file.type == FileType.VIDEO and "video" in self.capability:
content = await self._process_video(file, strategy)
result.append(content)
+ self.write_perceptual_memory(end_user_id, file.type, file.url, content)
else:
logger.warning(f"不支持的文件类型: {file.type}")
except Exception as e:
@@ -355,7 +390,8 @@ class MultimodalService:
"file_index": idx,
"file_type": file.type,
"error": str(e)
- }
+ },
+ exc_info=True
)
# 继续处理其他文件,不中断整个流程
result.append({
@@ -366,6 +402,17 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
return result
+ def write_perceptual_memory(
+ self,
+ end_user_id: str,
+ file_type: str,
+ file_url: str,
+ file_message: dict
+ ):
+ """写入感知记忆"""
+ if end_user_id and self.api_config:
+ write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
+
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理图片文件
@@ -387,43 +434,6 @@ class MultimodalService:
"text": f"[图片处理失败: {str(e)}]"
}
- @staticmethod
- async def _download_and_encode_image(url: str) -> tuple[str, str]:
- """
- 下载图片并转换为 base64
-
- Args:
- url: 图片 URL
-
- Returns:
- tuple: (base64_data, media_type)
- """
- from mimetypes import guess_type
-
- # 下载图片
- async with httpx.AsyncClient(timeout=30.0) as client:
- response = await client.get(url)
- response.raise_for_status()
-
- # 获取图片数据
- image_data = response.content
-
- # 确定 media type
- content_type = response.headers.get("content-type")
- if content_type and content_type.startswith("image/"):
- media_type = content_type
- else:
- # 从 URL 推断
- guessed_type, _ = guess_type(url)
- media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
-
- # 转换为 base64
- base64_data = base64.b64encode(image_data).decode("utf-8")
-
- logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
-
- return base64_data, media_type
-
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理文档文件(PDF、Word 等)
@@ -436,7 +446,6 @@ class MultimodalService:
Dict: 根据 provider 返回不同格式的文档内容
"""
if file.transfer_method == TransferMethod.REMOTE_URL:
- # 远程文档暂不支持提取
return {
"type": "text",
"text": f"\n{await self._extract_document_text(file)}\n"
@@ -471,12 +480,12 @@ class MultimodalService:
# 如果启用音频转文本且有 API Key
transcription = None
- if self.enable_audio_transcription and self.api_key:
+ if self.enable_audio_transcription and self.audio_api_key:
logger.info(f"开始音频转文本: {url}")
if self.provider == "dashscope":
- transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
+ transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.audio_api_key)
elif self.provider == "openai":
- transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
+ transcription = await AudioTranscriptionService.transcribe_openai(url, self.audio_api_key)
else:
logger.warning(f"Provider {self.provider} 不支持音频转文本")
diff --git a/api/app/services/prompt/perceptual_summary_system.jinja2 b/api/app/services/prompt/perceptual_summary_system.jinja2
new file mode 100644
index 00000000..ee5d3eb5
--- /dev/null
+++ b/api/app/services/prompt/perceptual_summary_system.jinja2
@@ -0,0 +1,53 @@
+{% raw %}You are a professional information extraction system.
+
+Your task is to analyze the provided document content and generate structured metadata.
+
+Extract the following fields:
+
+* **summary**: A concise summary of the document in 2–4 sentences.
+* **keywords**: 5–10 important keywords or key phrases that best represent the document. This field MUST be a JSON array of strings.
+* **topic**: The primary topic of the document expressed as a short phrase (3–8 words).
+* **domain**: The broader knowledge domain or field the document belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.).
+
+STRICT RULES:
+
+1. Output MUST be valid JSON.
+2. Do NOT output markdown.
+3. Do NOT output explanations.
+4. Do NOT output any text before or after the JSON.
+5. The JSON MUST contain EXACTLY these four keys:
+ * summary
+ * keywords
+ * topic
+ * domain{% endraw %}
+{% if file_type == 'image' or file_type == 'video' %} * scene {% endif %}
+{% if file_type == 'audio' %} * speaker_count {% endif %}
+{% if file_type == 'document' %} * section_count
+ * title
+ * first_line
+{% endif %}
+{% raw %}
+6. `keywords` MUST be a JSON array of strings.
+7. If the document content is insufficient, infer the best possible answer based on context.
+8. Ensure the JSON is syntactically correct.
+{% endraw %}
+9. Output using the language {{ language }}
+{% raw %}
+Required JSON format:
+
+{
+"summary": "string",
+"keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"],
+"topic": "string",
+"domain": "string",
+{% endraw %}
+{% if file_type == 'image' or file_type == 'video' %} "scene": ["string", "string"] {% endif %}
+{% if file_type == 'document' %} "section_count": integer
+"title": "string",
+"first_line": "string"
+{% endif %}
+{% if file_type == 'audio' %} "speaker_count": integer {% endif %}
+{% raw %}
+}
+
+Now analyze the following document and return the JSON result.{% endraw %}
diff --git a/api/app/tasks.py b/api/app/tasks.py
index 6fd9c954..5e1550bd 100644
--- a/api/app/tasks.py
+++ b/api/app/tasks.py
@@ -1,6 +1,5 @@
import asyncio
-import json
-import logging
+import hashlib
import os
import re
import shutil
@@ -11,20 +10,48 @@ from datetime import datetime, timezone
from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional
-from uuid import UUID
import redis
-import requests
from redis.exceptions import RedisError
-logger = logging.getLogger(__name__)
+# Import a unified Celery instance
+from app.celery_app import celery_app
+from app.core.config import settings
+from app.core.logging_config import get_logger
+from app.core.rag.crawler.web_crawler import WebCrawler
+from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
+from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
+from app.core.rag.integrations.feishu.client import FeishuAPIClient
+from app.core.rag.integrations.feishu.models import FileInfo
+from app.core.rag.integrations.yuque.client import YuqueAPIClient
+from app.core.rag.integrations.yuque.models import YuqueDocInfo
+from app.core.rag.llm.chat_model import Base
+from app.core.rag.llm.cv_model import QWenCV
+from app.core.rag.llm.embedding_model import OpenAIEmbed
+from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
+from app.core.rag.models.chunk import DocumentChunk
+from app.core.rag.prompts.generator import question_proposal
+from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
+ ElasticSearchVectorFactory,
+)
+from app.db import get_db, get_db_context
+from app.models import Document, File, Knowledge
+from app.schemas import document_schema, file_schema
+from app.schemas.model_schema import ModelInfo
+from app.services.memory_agent_service import MemoryAgentService
+from app.services.memory_perceptual_service import MemoryPerceptualService
+from app.utils.config_utils import resolve_config_id
+from app.utils.redis_lock import RedisLock
+
+logger = get_logger(__name__)
# 模块级同步 Redis 连接池,供 Celery 任务共享使用
# 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致
# 使用连接池而非单例客户端,提供更好的并发性能和自动重连
-_sync_redis_pool: redis.ConnectionPool = None
+_sync_redis_pool: redis.ConnectionPool | None = None
-def _get_or_create_redis_pool() -> redis.ConnectionPool:
+
+def _get_or_create_redis_pool() -> redis.ConnectionPool | None:
"""获取或创建 Redis 连接池(懒初始化)"""
global _sync_redis_pool
if _sync_redis_pool is None:
@@ -47,6 +74,7 @@ def _get_or_create_redis_pool() -> redis.ConnectionPool:
return None
return _sync_redis_pool
+
def get_sync_redis_client() -> Optional[redis.StrictRedis]:
"""获取同步 Redis 客户端(使用连接池)
@@ -60,7 +88,7 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
pool = _get_or_create_redis_pool()
if pool is None:
return None
-
+
client = redis.StrictRedis(connection_pool=pool)
# 验证连接可用性
client.ping()
@@ -72,32 +100,18 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True)
return None
-# Import a unified Celery instance
-from app.celery_app import celery_app
-from app.core.config import settings
-from app.core.rag.crawler.web_crawler import WebCrawler
-from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
-from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
-from app.core.rag.integrations.feishu.client import FeishuAPIClient
-from app.core.rag.integrations.feishu.models import FileInfo
-from app.core.rag.integrations.yuque.client import YuqueAPIClient
-from app.core.rag.integrations.yuque.models import YuqueDocInfo
-from app.core.rag.llm.chat_model import Base
-from app.core.rag.llm.cv_model import QWenCV
-from app.core.rag.llm.embedding_model import OpenAIEmbed
-from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
-from app.core.rag.models.chunk import DocumentChunk
-from app.core.rag.prompts.generator import question_proposal
-from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
- ElasticSearchVectorFactory,
-)
-from app.db import get_db, get_db_context
-from app.models.document_model import Document
-from app.models.file_model import File
-from app.models.knowledge_model import Knowledge
-from app.schemas import document_schema, file_schema
-from app.services.memory_agent_service import MemoryAgentService
-from app.utils.config_utils import resolve_config_id
+
+def set_asyncio_event_loop():
+ """Set the asyncio event loop for the current thread."""
+ try:
+ loop = asyncio.get_event_loop()
+ if loop.is_closed():
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ except RuntimeError:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ return loop
@celery_app.task(name="tasks.process_item")
@@ -294,9 +308,18 @@ def parse_document(file_path: str, document_id: uuid.UUID):
vector_size = len(vts[0])
init_graphrag(task, vector_size)
- async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service,
- chat_model, embedding_model, callback, with_resolution: bool = True,
- with_community: bool = True, ) -> dict:
+ async def _run(
+ row: dict,
+ document_ids: list[str],
+ language: str,
+ parser_config: dict,
+ vector_service,
+ chat_model,
+ embedding_model,
+ callback,
+ with_resolution: bool = True,
+ with_community: bool = True
+ ) -> dict:
await trio.sleep(5) # Delay for 10 seconds
nonlocal progress_msg # Declare the use of an external progress_msg variable
result = await run_graphrag_for_kb(
@@ -329,6 +352,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
with_community=with_community,
)
)
+
try:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(sync_task)
@@ -448,6 +472,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
with_community=with_community,
)
)
+
try:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(sync_task)
@@ -1002,29 +1027,21 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
# Log but continue - will fail later with proper error
pass
- async def _run() -> str:
+ async def _run() -> dict:
with get_db_context() as db:
service = MemoryAgentService()
- return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db,
- storage_type, user_rag_memory_id)
+ return await service.read_memory(
+ end_user_id,
+ message,
+ history,
+ search_switch,
+ actual_config_id, db,
+ storage_type, user_rag_memory_id
+ )
try:
- # 使用 nest_asyncio 来避免事件循环冲突
- try:
- import nest_asyncio
- nest_asyncio.apply()
- except ImportError:
- pass
-
# 尝试获取现有事件循环,如果不存在则创建新的
- try:
- loop = asyncio.get_event_loop()
- if loop.is_closed():
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
+ loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
@@ -1056,7 +1073,8 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
-def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, user_rag_memory_id: str,
+def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str,
+ user_rag_memory_id: str,
language: str = "zh") -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService.
Args:
@@ -1073,10 +1091,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
Raises:
Exception on failure
"""
- from app.core.logging_config import get_logger
- logger = get_logger(__name__)
- logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id} (type: {type(config_id).__name__}), storage_type={storage_type}, language={language}")
+ logger.info(
+ f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
+ f"config_id={config_id} (type: {type(config_id).__name__}), "
+ f"storage_type={storage_type}, language={language}")
start_time = time.time()
# Convert config_id to UUID
@@ -1086,13 +1105,14 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
try:
with get_db_context() as db:
actual_config_id = resolve_config_id(config_id, db)
- print(100*'-')
+ print(100 * '-')
print(actual_config_id)
- print(100*'-')
+ print(100 * '-')
logger.info(
f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
except (ValueError, AttributeError) as e:
- logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
+ logger.error(
+ f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
return {
"status": "FAILURE",
"error": f"Invalid config_id format: {config_id} - {str(e)}",
@@ -1116,7 +1136,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
async def _run() -> str:
with get_db_context() as db:
logger.info(
- f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
+ f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
+ f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
service = MemoryAgentService()
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
user_rag_memory_id, language)
@@ -1124,22 +1145,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
return result
try:
- # 使用 nest_asyncio 来避免事件循环冲突
- try:
- import nest_asyncio
- nest_asyncio.apply()
- except ImportError:
- pass
-
# 尝试获取现有事件循环,如果不存在则创建新的
- try:
- loop = asyncio.get_event_loop()
- if loop.is_closed():
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
+ loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
@@ -1193,28 +1200,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
}
-def reflection_engine() -> None:
- """Empty function placeholder for timed background reflection.
-
- Intentionally left blank; replace with real reflection logic later.
- """
- import asyncio
-
- from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
-
- host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
- asyncio.run(self_reflexion(host_id))
-
-
-@celery_app.task(name="app.core.memory.agent.reflection.timer")
-def reflection_timer_task() -> None:
- """Periodic Celery task that invokes reflection_engine.
-
- Raises an exception on failure.
- """
- reflection_engine()
-
-
# unused task
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
# def check_read_service_task() -> Dict[str, str]:
@@ -1368,6 +1353,8 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
"workspace_id": workspace_id,
"elapsed_time": elapsed_time,
}
+
+
@celery_app.task(
name="app.tasks.write_all_workspaces_memory_task",
bind=True,
@@ -1391,15 +1378,12 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
start_time = time.time()
async def _run() -> Dict[str, Any]:
- from app.core.logging_config import get_api_logger
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.models.workspace_model import Workspace
from app.repositories.memory_increment_repository import write_memory_increment
from app.services.memory_storage_service import search_all
- api_logger = get_api_logger()
-
with get_db_context() as db:
try:
# 获取所有活跃的工作空间
@@ -1408,7 +1392,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
).all()
if not workspaces:
- api_logger.warning("没有找到活跃的工作空间")
+ logger.warning("没有找到活跃的工作空间")
return {
"status": "SUCCESS",
"message": "没有找到活跃的工作空间",
@@ -1416,13 +1400,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
"workspace_results": []
}
- api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
+ logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
all_workspace_results = []
# 遍历每个工作空间
for workspace in workspaces:
workspace_id = workspace.id
- api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
+ logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
try:
# 1. 查询当前workspace下的所有app(仅未删除的)
@@ -1447,7 +1431,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
"memory_increment_id": str(memory_increment.id),
"created_at": memory_increment.created_at.isoformat(),
})
- api_logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0")
+ logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0")
continue
# 2. 查询所有app下的end_user_id(去重)
@@ -1472,7 +1456,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
})
except Exception as e:
# 记录单个用户查询失败,但继续处理其他用户
- api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
+ logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
end_user_details.append({
"end_user_id": str(end_user_id),
"total": 0,
@@ -1496,13 +1480,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
"created_at": memory_increment.created_at.isoformat(),
})
- api_logger.info(
+ logger.info(
f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}"
)
except Exception as e:
db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间
- api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
+ logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
all_workspace_results.append({
"workspace_id": str(workspace_id),
"workspace_name": workspace.name,
@@ -1525,7 +1509,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
}
except Exception as e:
- api_logger.error(f"记忆增量统计任务执行失败: {str(e)}")
+ logger.error(f"记忆增量统计任务执行失败: {str(e)}")
return {
"status": "FAILURE",
"error": str(e),
@@ -1534,22 +1518,8 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
}
try:
- # 使用 nest_asyncio 来避免事件循环冲突
- try:
- import nest_asyncio
- nest_asyncio.apply()
- except ImportError:
- pass
-
# 尝试获取现有事件循环,如果不存在则创建新的
- try:
- loop = asyncio.get_event_loop()
- if loop.is_closed():
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
+ loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
@@ -1597,11 +1567,9 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
start_time = time.time()
async def _run() -> Dict[str, Any]:
- from app.core.logging_config import get_logger
from app.repositories.end_user_repository import EndUserRepository
from app.services.user_memory_service import UserMemoryService
- logger = get_logger(__name__)
logger.info("开始执行记忆缓存重新生成定时任务")
service = UserMemoryService()
@@ -1734,22 +1702,8 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
}
try:
- # 使用 nest_asyncio 来避免事件循环冲突
- try:
- import nest_asyncio
- nest_asyncio.apply()
- except ImportError:
- pass
-
# 尝试获取现有事件循环,如果不存在则创建新的
- try:
- loop = asyncio.get_event_loop()
- if loop.is_closed():
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
+ loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
@@ -1785,15 +1739,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
start_time = time.time()
async def _run() -> Dict[str, Any]:
- from app.core.logging_config import get_api_logger
from app.models.workspace_model import Workspace
from app.services.memory_reflection_service import (
MemoryReflectionService,
WorkspaceAppService,
)
- api_logger = get_api_logger()
-
with get_db_context() as db:
try:
# 获取所有工作空间
@@ -1812,7 +1763,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
# 遍历每个工作空间
for workspace in workspaces:
workspace_id = workspace.id
- api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}")
+ logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}")
try:
reflection_service = MemoryReflectionService(db)
@@ -1824,7 +1775,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
workspace_reflection_results = []
for data in result['apps_detailed_info']:
- if data['memory_configs'] == []:
+ if not data['memory_configs']:
continue
releases = data['releases']
@@ -1835,7 +1786,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str(
user['app_id']):
# 调用反思服务
- api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
+ logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
reflection_result = await reflection_service.start_reflection_from_data(
config_data=config,
@@ -1855,12 +1806,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
"reflection_results": workspace_reflection_results
})
- api_logger.info(
+ logger.info(
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
except Exception as e:
db.rollback() # Rollback failed transaction to allow next query
- api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
+ logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
all_reflection_results.append({
"workspace_id": str(workspace_id),
"error": str(e),
@@ -1879,7 +1830,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
}
except Exception as e:
- api_logger.error(f"工作空间反思任务执行失败: {str(e)}")
+ logger.error(f"工作空间反思任务执行失败: {str(e)}")
return {
"status": "FAILURE",
"error": str(e),
@@ -1888,22 +1839,8 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
}
try:
- # 使用 nest_asyncio 来避免事件循环冲突
- try:
- import nest_asyncio
- nest_asyncio.apply()
- except ImportError:
- pass
-
# 尝试获取现有事件循环,如果不存在则创建新的
- try:
- loop = asyncio.get_event_loop()
- if loop.is_closed():
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
+ loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
@@ -1944,18 +1881,16 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
start_time = time.time()
async def _run() -> Dict[str, Any]:
- from app.core.logging_config import get_api_logger
from app.services.memory_forget_service import MemoryForgetService
- api_logger = get_api_logger()
-
with get_db_context() as db:
try:
- api_logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")
+ logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")
forget_service = MemoryForgetService()
# 运行遗忘周期
+ # FIXME: MemeoryForgetService
report = await forget_service.trigger_forgetting(
db=db,
end_user_id=None, # 处理所有组
@@ -1964,7 +1899,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
duration = time.time() - start_time
- api_logger.info(
+ logger.info(
f"遗忘周期定时任务完成: "
f"融合 {report['merged_count']} 对节点, "
f"失败 {report['failed_count']} 对, "
@@ -1980,7 +1915,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
except Exception as e:
duration = time.time() - start_time
- api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
+ logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
return {
"status": "FAILED",
@@ -1997,6 +1932,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
finally:
loop.close()
+
# =============================================================================
# Long-term Memory Storage Tasks (Batched Write Strategies)
# =============================================================================
@@ -2222,9 +2158,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
start_time = time.time()
async def _run() -> Dict[str, Any]:
- from sqlalchemy import func, select
+ from sqlalchemy import select
- from app.core.logging_config import get_logger
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
from app.repositories.implicit_emotions_storage_repository import (
ImplicitEmotionsStorageRepository,
@@ -2233,7 +2168,6 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.services.implicit_memory_service import ImplicitMemoryService
- logger = get_logger(__name__)
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
total_users = 0
@@ -2267,7 +2201,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
for end_user_id in refresh_iter:
logger.info(f"开始处理用户: {end_user_id}")
user_start_time = time.time()
-
+
implicit_success = False
emotion_success = False
errors = []
@@ -2318,7 +2252,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
failed += 1
user_elapsed = time.time() - user_start_time
-
+
# 记录用户处理结果
user_result = {
"end_user_id": end_user_id,
@@ -2460,22 +2394,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
}
try:
- # 使用 nest_asyncio 来避免事件循环冲突
- try:
- import nest_asyncio
- nest_asyncio.apply()
- except ImportError:
- pass
-
# 尝试获取现有事件循环,如果不存在则创建新的
- try:
- loop = asyncio.get_event_loop()
- if loop.is_closed():
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
+ loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
@@ -2521,14 +2441,12 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
start_time = time.time()
async def _run() -> Dict[str, Any]:
- from app.core.logging_config import get_logger
from app.repositories.implicit_emotions_storage_repository import (
ImplicitEmotionsStorageRepository,
)
from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.services.implicit_memory_service import ImplicitMemoryService
- logger = get_logger(__name__)
logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")
initialized = 0
@@ -2587,20 +2505,7 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
}
try:
- try:
- import nest_asyncio
- nest_asyncio.apply()
- except ImportError:
- pass
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_closed():
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
+ loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time
@@ -2633,6 +2538,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
默认生成中文(zh)兴趣分布数据。
Args:
+ self: task object
end_user_ids: 需要检查的用户ID列表
Returns:
@@ -2641,11 +2547,9 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
start_time = time.time()
async def _run() -> Dict[str, Any]:
- from app.core.logging_config import get_logger
from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
from app.services.memory_agent_service import MemoryAgentService
- logger = get_logger(__name__)
logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")
initialized = 0
@@ -2694,20 +2598,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
}
try:
- try:
- import nest_asyncio
- nest_asyncio.apply()
- except ImportError:
- pass
-
- try:
- loop = asyncio.get_event_loop()
- if loop.is_closed():
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
+ loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time
@@ -2720,3 +2611,54 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
"elapsed_time": time.time() - start_time,
"task_id": self.request.id,
}
+
+
+@celery_app.task(
+ name="app.tasks.write_perceptual_memory",
+ bind=True,
+ ignore_result=True,
+ max_retries=0,
+ acks_late=False,
+ time_limit=3600,
+ soft_time_limit=3300,
+)
+def write_perceptual_memory(
+ self,
+ end_user_id: str,
+ model_api_config: dict,
+ file_type: str,
+ file_url: str,
+ file_message: dict
+):
+ """
+ Write perceptual memory for a user into PostgreSQL and Neo4j.
+
+ This task generates or updates the user's perceptual memory
+ in the backend databases. It is intended to be executed asynchronously
+ via Celery.
+
+ Args:
+ end_user_id (uuid.UUID): The unique identifier of the end user.
+ model_api_config (ModelInfo): API configuration for the model
+ used to generate perceptual memory.
+ file_type (str): The file type
+ file_url (url): The url of file
+ file_message (dict): The file message containing details about the file
+ to be processed.
+
+ Returns:
+ None
+ """
+ file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
+ set_asyncio_event_loop()
+ with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
+ model_info = ModelInfo(**model_api_config)
+ with get_db_context() as db:
+ memory_perceptual_service = MemoryPerceptualService(db)
+ return asyncio.run(memory_perceptual_service.generate_perceptual_memory(
+ end_user_id,
+ model_info,
+ file_type,
+ file_url,
+ file_message,
+ ))
diff --git a/api/app/utils/redis_lock.py b/api/app/utils/redis_lock.py
new file mode 100644
index 00000000..99f62d84
--- /dev/null
+++ b/api/app/utils/redis_lock.py
@@ -0,0 +1,61 @@
+import redis
+import uuid
+import time
+
+UNLOCK_SCRIPT = """
+if redis.call("get", KEYS[1]) == ARGV[1] then
+ return redis.call("del", KEYS[1])
+else
+ return 0
+end
+"""
+
+
+class RedisLock:
+ def __init__(
+ self,
+ key: str,
+ redis_client: redis.StrictRedis,
+ expire: int = 60,
+ retry_interval: float = 0.1,
+ timeout: float = 30
+
+ ):
+ self.key = key
+ self.expire = expire
+ self.value = str(uuid.uuid4())
+ self._locked = False
+ self.retry_interval = retry_interval
+ self.timeout = timeout
+ self.redis_client = redis_client
+
+ def acquire(self) -> bool:
+ start = time.time()
+ while True:
+ ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True)
+ if ok:
+ self._locked = True
+ return True
+ if time.time() - start >= self.timeout:
+ return False
+ time.sleep(self.retry_interval)
+
+ def release(self):
+ if not self._locked:
+ return
+ self.redis_client.eval(
+ UNLOCK_SCRIPT,
+ 1,
+ self.key,
+ self.value
+ )
+ self._locked = False
+
+ def __enter__(self):
+ ok = self.acquire()
+ if not ok:
+ raise RuntimeError(f"Get redis lock timeout: {self.key}")
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.release()