feat(multimodel): support multimodal memory display and improve code style

This commit is contained in:
Eternity
2026-03-13 13:33:58 +08:00
parent cbc8714414
commit b71bc1f875
31 changed files with 877 additions and 543 deletions

View File

@@ -1,10 +1,11 @@
import os
import asyncio import asyncio
import json import json
import logging import logging
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
import redis.asyncio as redis import redis.asyncio as redis
from redis.asyncio import ConnectionPool from redis.asyncio import ConnectionPool
from app.core.config import settings from app.core.config import settings
# 设置日志记录器 # 设置日志记录器

View File

@@ -62,7 +62,7 @@ celery_app.conf.update(
task_serializer='json', task_serializer='json',
accept_content=['json'], accept_content=['json'],
result_serializer='json', result_serializer='json',
# 时区 # 时区
timezone='Asia/Shanghai', timezone='Asia/Shanghai',
enable_utc=False, enable_utc=False,
@@ -70,43 +70,44 @@ celery_app.conf.update(
# 任务追踪 # 任务追踪
task_track_started=True, task_track_started=True,
task_ignore_result=False, task_ignore_result=False,
# 超时设置 # 超时设置
task_time_limit=3600, # 60分钟硬超时 task_time_limit=3600, # 60分钟硬超时
task_soft_time_limit=3000, # 50分钟软超时 task_soft_time_limit=3000, # 50分钟软超时
# Worker 设置 (per-worker settings are in docker-compose command line) # Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
# 结果过期时间 # 结果过期时间
result_expires=3600, # 结果保存1小时 result_expires=3600, # 结果保存1小时
# 任务确认设置 # 任务确认设置
task_acks_late=True, task_acks_late=True,
task_reject_on_worker_lost=True, task_reject_on_worker_lost=True,
worker_disable_rate_limits=True, worker_disable_rate_limits=True,
# FLower setting # FLower setting
worker_send_task_events=True, worker_send_task_events=True,
task_send_sent_event=True, task_send_sent_event=True,
# task routing # task routing
task_routes={ task_routes={
# Memory tasks → memory_tasks queue (threads worker) # Memory tasks → memory_tasks queue (threads worker)
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'}, 'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'}, 'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'}, 'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
# Long-term storage tasks → memory_tasks queue (batched write strategies) # Long-term storage tasks → memory_tasks queue (batched write strategies)
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'}, 'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'}, 'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'}, 'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker) # Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'}, 'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker) # Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'}, 'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
@@ -131,7 +132,7 @@ implicit_emotions_update_schedule = crontab(
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE, minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
) )
#构建定时任务配置 # 构建定时任务配置
beat_schedule_config = { beat_schedule_config = {
"run-workspace-reflection": { "run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task", "task": "app.tasks.workspace_reflection_task",

View File

@@ -13,4 +13,4 @@ logger.info("Celery worker logging initialized")
# 导入任务模块以注册任务 # 导入任务模块以注册任务
import app.tasks import app.tasks
__all__ = ['celery_app'] __all__ = ['celery_app']

View File

@@ -808,6 +808,15 @@ async def draft_run_compare(
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
service._validate_app_accessible(app, workspace_id) service._validate_app_accessible(app, workspace_id)
if payload.user_id is None:
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id,
other_id=str(current_user.id),
original_user_id=str(current_user.id) # Save original user_id to other_id
)
payload.user_id = str(new_end_user.id)
# 2. 获取 Agent 配置 # 2. 获取 Agent 配置
from sqlalchemy import select from sqlalchemy import select
from app.models import AgentConfig from app.models import AgentConfig
@@ -853,6 +862,8 @@ async def draft_run_compare(
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id "conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
}) })
# 流式返回 # 流式返回
if payload.stream: if payload.stream:
async def event_generator(): async def event_generator():
@@ -864,7 +875,7 @@ async def draft_run_compare(
message=payload.message, message=payload.message,
workspace_id=workspace_id, workspace_id=workspace_id,
conversation_id=payload.conversation_id, conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id), user_id=payload.user_id,
variables=payload.variables, variables=payload.variables,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,
@@ -895,7 +906,7 @@ async def draft_run_compare(
message=payload.message, message=payload.message,
workspace_id=workspace_id, workspace_id=workspace_id,
conversation_id=payload.conversation_id, conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id), user_id=payload.user_id,
variables=payload.variables, variables=payload.variables,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,

View File

@@ -4,13 +4,13 @@ from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
def content_input_node(state: ReadState) -> ReadState: def content_input_node(state: ReadState) -> ReadState:
""" """
Start node - Extract content and maintain state information Start node - Extract content and maintain state information
Extracts the content from the first message in the state and returns it Extracts the content from the first message in the state and returns it
as the data field while preserving all other state information. as the data field while preserving all other state information.
Args: Args:
state: ReadState containing messages and other state data state: ReadState containing messages and other state data
Returns: Returns:
ReadState: Updated state with extracted content in data field ReadState: Updated state with extracted content in data field
""" """
@@ -19,19 +19,20 @@ def content_input_node(state: ReadState) -> ReadState:
# Return content and maintain all state information # Return content and maintain all state information
return {"data": content} return {"data": content}
def content_input_write(state: WriteState) -> WriteState: def content_input_write(state: WriteState) -> WriteState:
""" """
Start node - Extract content and maintain state information for write operations Start node - Extract content and maintain state information for write operations
Extracts the content from the first message in the state for write operations. Extracts the content from the first message in the state for write operations.
Args: Args:
state: WriteState containing messages and other state data state: WriteState containing messages and other state data
Returns: Returns:
WriteState: Updated state with extracted content in data field WriteState: Updated state with extracted content in data field
""" """
content = state['messages'][0].content if state.get('messages') else '' content = state['messages'][0].content if state.get('messages') else ''
# Return content and maintain all state information # Return content and maintain all state information
return {"data": content} return {"data": content}

View File

@@ -1,13 +1,13 @@
from typing import Literal from typing import Literal
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
counter = COUNTState(limit=3) counter = COUNTState(limit=3)
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
""" """
Determine routing based on search_switch value. Determine routing based on search_switch value.
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
return 'Input_Summary' return 'Input_Summary'
return 'Split_The_Problem' # 默认情况 return 'Split_The_Problem' # 默认情况
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
""" """
Determine routing based on search_switch value. Determine routing based on search_switch value.
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
elif search_switch == '1': elif search_switch == '1':
return 'Retrieve_Summary' return 'Retrieve_Summary'
return 'Retrieve_Summary' # Default based on business logic return 'Retrieve_Summary' # Default based on business logic
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
status=state.get('verify', '')['status'] status = state.get('verify', '')['status']
# loop_count = counter.get_total() # loop_count = counter.get_total()
if "success" in status: if "success" in status:
# counter.reset() # counter.reset()
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
# if loop_count < 2: # Maximum loop count is 3 # if loop_count < 2: # Maximum loop count is 3
# return "content_input" # return "content_input"
# else: # else:
# counter.reset() # counter.reset()
return "Summary_fails" return "Summary_fails"
else: else:
# Add default return value to avoid returning None # Add default return value to avoid returning None

View File

@@ -2,32 +2,32 @@ import json
import os import os
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context, get_db from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
""" """
Write messages to RAG storage system Write messages to RAG storage system
Combines user and AI messages into a single string format and stores them Combines user and AI messages into a single string format and stores them
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval. in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
Args: Args:
end_user_id: User identifier for the conversation end_user_id: User identifier for the conversation
user_message: User's input message content user_message: User's input message content
@@ -38,14 +38,24 @@ async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory
combined_message = f"user: {user_message}\nassistant: {ai_message}" combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id) await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
actual_config_id, long_term_messages=[]):
async def write(
storage_type,
end_user_id,
user_message,
ai_message,
user_rag_memory_id,
actual_end_user_id,
actual_config_id,
long_term_messages=None
):
""" """
Write memory with structured message support Write memory with structured message support
Handles memory writing operations for different storage types (Neo4j/RAG). Handles memory writing operations for different storage types (Neo4j/RAG).
Supports both individual message pairs and batch long-term message processing. Supports both individual message pairs and batch long-term message processing.
Args: Args:
storage_type: Storage type identifier ("neo4j" or "rag") storage_type: Storage type identifier ("neo4j" or "rag")
end_user_id: Terminal user identifier end_user_id: Terminal user identifier
@@ -55,7 +65,7 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
actual_end_user_id: Actual user identifier for storage actual_end_user_id: Actual user identifier for storage
actual_config_id: Configuration identifier actual_config_id: Configuration identifier
long_term_messages: Optional list of structured messages for batch processing long_term_messages: Optional list of structured messages for batch processing
Logic explanation: Logic explanation:
- RAG mode: Combines user_message and ai_message into string format, maintains original logic - RAG mode: Combines user_message and ai_message into string format, maintains original logic
- Neo4j mode: Uses structured message lists - Neo4j mode: Uses structured message lists
@@ -64,8 +74,9 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
3. Each message is converted to independent Chunk, preserving speaker field 3. Each message is converted to independent Chunk, preserving speaker field
""" """
db = next(get_db()) if long_term_messages is None:
try: long_term_messages = []
with get_db_context() as db:
actual_config_id = resolve_config_id(actual_config_id, db) actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j mode: Use structured message lists # Neo4j mode: Use structured message lists
structured_messages = [] structured_messages = []
@@ -105,17 +116,16 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id)) write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
""" """
Save long-term memory data to database Save long-term memory data to database
Handles the storage of long-term memory data based on different strategies Handles the storage of long-term memory data based on different strategies
(chunk-based or aggregate-based) and manages the transition from short-term (chunk-based or aggregate-based) and manages the transition from short-term
to long-term memory storage. to long-term memory storage.
Args: Args:
long_term_messages: Long-term message data to be saved long_term_messages: Long-term message data to be saved
actual_config_id: Configuration identifier for memory settings actual_config_id: Configuration identifier for memory settings
@@ -126,13 +136,12 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
with get_db_context() as db_session: with get_db_context() as db_session:
repo = LongTermMemoryRepository(db_session) repo = LongTermMemoryRepository(db_session)
from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id) result = write_store.get_session_by_userid(end_user_id)
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
data = await format_parsing(result, "dict") data = await format_parsing(result, "dict")
chunk_data = data[:scope] chunk_data = data[:scope]
if len(chunk_data)==scope: if len(chunk_data) == scope:
repo.upsert(end_user_id, chunk_data) repo.upsert(end_user_id, chunk_data)
logger.info(f'---------写入短长期-----------') logger.info(f'---------写入短长期-----------')
else: else:
@@ -142,22 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
logger.info(f'写入短长期:') logger.info(f'写入短长期:')
"""Window-based dialogue processing""" """Window-based dialogue processing"""
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
""" """
Process dialogue based on window size and write to Neo4j Process dialogue based on window size and write to Neo4j
Manages conversation data based on a sliding window approach. When the window Manages conversation data based on a sliding window approach. When the window
reaches the specified scope size, it triggers long-term memory storage to Neo4j. reaches the specified scope size, it triggers long-term memory storage to Neo4j.
Args: Args:
end_user_id: Terminal user identifier end_user_id: Terminal user identifier
memory_config: Memory configuration object containing settings memory_config: Memory configuration object containing settings
langchain_messages: Original message data list langchain_messages: Original message data list
scope: Window size determining when to trigger long-term storage scope: Window size determining when to trigger long-term storage
""" """
scope=scope scope = scope
is_end_user_id = count_store.get_sessions_count(end_user_id) is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False: if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0] is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
@@ -174,42 +184,53 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
config_id = memory_config.config_id config_id = memory_config.config_id
else: else:
config_id = memory_config config_id = memory_config
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, await write(
config_id, formatted_messages) AgentMemory_Long_Term.STORAGE_NEO4J,
end_user_id,
"",
"",
None,
end_user_id,
config_id,
formatted_messages
)
count_store.update_sessions_count(end_user_id, 1, langchain_messages) count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else: else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages) count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""Time-based memory processing""" """Time-based memory processing"""
async def memory_long_term_storage(end_user_id,memory_config,time):
async def memory_long_term_storage(end_user_id, memory_config, time):
""" """
Process memory storage based on time intervals and write to Neo4j Process memory storage based on time intervals and write to Neo4j
Retrieves Redis data based on time intervals and writes it to Neo4j for Retrieves Redis data based on time intervals and writes it to Neo4j for
long-term storage. This function handles time-based memory consolidation. long-term storage. This function handles time-based memory consolidation.
Args: Args:
end_user_id: Terminal user identifier end_user_id: Terminal user identifier
memory_config: Memory configuration object containing settings memory_config: Memory configuration object containing settings
time: Time interval for data retrieval time: Time interval for data retrieval
""" """
long_time_data = write_store.find_user_recent_sessions(end_user_id, time) long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
format_messages = (long_time_data) format_messages = long_time_data
messages=[] messages = []
memory_config=memory_config.config_id memory_config = memory_config.config_id
for i in format_messages: for i in format_messages:
message=json.loads(i['Query']) message = json.loads(i['Query'])
messages+= message messages += message
if format_messages!=[]: if format_messages:
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages) memory_config, messages)
"""Aggregation judgment processing"""
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
""" """
Aggregation judgment function: determine if input sentence and historical messages describe the same event Aggregation judgment function: determine if input sentence and historical messages describe the same event
Uses LLM-based analysis to determine whether new messages should be aggregated with existing Uses LLM-based analysis to determine whether new messages should be aggregated with existing
historical data or stored as separate events. This helps optimize memory storage and retrieval. historical data or stored as separate events. This helps optimize memory storage and retrieval.
@@ -217,11 +238,11 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
end_user_id: Terminal user identifier end_user_id: Terminal user identifier
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: Memory configuration object containing LLM settings memory_config: Memory configuration object containing LLM settings
Returns: Returns:
dict: Aggregation judgment result containing is_same_event flag and processed output dict: Aggregation judgment result containing is_same_event flag and processed output
""" """
history = None
try: try:
# 1. Get historical session data (using new method) # 1. Get historical session data (using new method)
result = write_store.get_all_sessions_by_end_user_id(end_user_id) result = write_store.get_all_sessions_by_end_user_id(end_user_id)
@@ -255,7 +276,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
output_value = structured.output output_value = structured.output
if isinstance(output_value, list): if isinstance(output_value, list):
output_value = [ output_value = [
{"role": msg.role, "content": msg.content} {"role": msg.role, "content": msg.content}
for msg in output_value for msg in output_value
] ]
@@ -268,16 +289,16 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
await write("neo4j", end_user_id, "", "", None, end_user_id, await write("neo4j", end_user_id, "", "", None, end_user_id,
memory_config.config_id, output_value) memory_config.config_id, output_value)
return result_dict return result_dict
except Exception as e: except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}") print(f"[aggregate_judgment] 发生错误: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return { return {
"is_same_event": False, "is_same_event": False,
"output": ori_messages, "output": ori_messages,
"messages": ori_messages, "messages": ori_messages,
"history": history if 'history' in locals() else [], "history": history if 'history' in locals() else [],
"error": str(e) "error": str(e)
} }

View File

@@ -2,26 +2,25 @@ import asyncio
import json import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
from langchain.tools import tool from langchain.tools import tool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.core.memory.src.search import ( from app.core.memory.src.search import (
search_by_temporal, search_by_temporal,
search_by_keyword_temporal, search_by_keyword_temporal,
) )
def extract_tool_message_content(response): def extract_tool_message_content(response):
""" """
Extract ToolMessage content and tool names from agent response Extract ToolMessage content and tool names from agent response
Parses agent response messages to extract tool execution results and metadata. Parses agent response messages to extract tool execution results and metadata.
Handles JSON parsing and provides structured access to tool output data. Handles JSON parsing and provides structured access to tool output data.
Args: Args:
response: Agent response dictionary containing messages response: Agent response dictionary containing messages
Returns: Returns:
dict: Dictionary containing tool_name and parsed content, or None if no tool message found dict: Dictionary containing tool_name and parsed content, or None if no tool message found
- tool_name: Name of the executed tool - tool_name: Name of the executed tool
@@ -61,10 +60,10 @@ def extract_tool_message_content(response):
class TimeRetrievalInput(BaseModel): class TimeRetrievalInput(BaseModel):
""" """
Input schema for time retrieval tool Input schema for time retrieval tool
Defines the expected input parameters for time-based retrieval operations. Defines the expected input parameters for time-based retrieval operations.
Used for validation and documentation of tool parameters. Used for validation and documentation of tool parameters.
Attributes: Attributes:
context: User input query content for search context: User input query content for search
end_user_id: Group ID for filtering search results, defaults to test user end_user_id: Group ID for filtering search results, defaults to test user
@@ -72,25 +71,26 @@ class TimeRetrievalInput(BaseModel):
context: str = Field(description="用户输入的查询内容") context: str = Field(description="用户输入的查询内容")
end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果") end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
def create_time_retrieval_tool(end_user_id: str): def create_time_retrieval_tool(end_user_id: str):
""" """
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
Creates a specialized time-based retrieval tool that searches for statements within Creates a specialized time-based retrieval tool that searches for statements within
specified time ranges. Includes field cleaning functionality to remove unnecessary specified time ranges. Includes field cleaning functionality to remove unnecessary
metadata from search results. metadata from search results.
Args: Args:
end_user_id: User identifier for scoping search results end_user_id: User identifier for scoping search results
Returns: Returns:
function: Configured TimeRetrievalWithGroupId tool function function: Configured TimeRetrievalWithGroupId tool function
""" """
def clean_temporal_result_fields(data): def clean_temporal_result_fields(data):
""" """
Clean unnecessary fields from temporal search results and modify structure Clean unnecessary fields from temporal search results and modify structure
Removes metadata fields that are not needed for end-user consumption and Removes metadata fields that are not needed for end-user consumption and
restructures the response format for better usability. restructures the response format for better usability.
@@ -102,10 +102,10 @@ def create_time_retrieval_tool(end_user_id: str):
""" """
# List of fields to filter out # List of fields to filter out
fields_to_remove = { fields_to_remove = {
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at', 'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
'valid_at', 'invalid_at', 'statement_ids' 'valid_at', 'invalid_at', 'statement_ids'
} }
if isinstance(data, dict): if isinstance(data, dict):
cleaned = {} cleaned = {}
for key, value in data.items(): for key, value in data.items():
@@ -126,15 +126,16 @@ def create_time_retrieval_tool(end_user_id: str):
return [clean_temporal_result_fields(item) for item in data] return [clean_temporal_result_fields(item) for item in data]
else: else:
return data return data
@tool @tool
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str: def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
end_user_id_param: str = None, clean_output: bool = True) -> str:
""" """
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
Performs time-based search operations with automatic metadata filtering. Supports Performs time-based search operations with automatic metadata filtering. Supports
flexible date range specification and provides clean, user-friendly output. flexible date range specification and provides clean, user-friendly output.
Explicit parameters: Explicit parameters:
- context: Query context content - context: Query context content
- start_date: Start time (optional, format: YYYY-MM-DD) - start_date: Start time (optional, format: YYYY-MM-DD)
@@ -142,10 +143,11 @@ def create_time_retrieval_tool(end_user_id: str):
- end_user_id_param: Group ID (optional, overrides default group ID) - end_user_id_param: Group ID (optional, overrides default group ID)
- clean_output: Whether to clean metadata fields from output - clean_output: Whether to clean metadata fields from output
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d") - end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
Returns: Returns:
str: JSON formatted search results with temporal data str: JSON formatted search results with temporal data
""" """
async def _async_search(): async def _async_search():
# Use passed parameters or default values # Use passed parameters or default values
actual_end_user_id = end_user_id_param or end_user_id actual_end_user_id = end_user_id_param or end_user_id
@@ -167,18 +169,19 @@ def create_time_retrieval_tool(end_user_id: str):
cleaned_results = results cleaned_results = results
return json.dumps(cleaned_results, ensure_ascii=False, indent=2) return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
return asyncio.run(_async_search()) return asyncio.run(_async_search())
@tool @tool
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str: def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
clean_output: bool = True) -> str:
""" """
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
Performs combined keyword and temporal search operations with automatic metadata Performs combined keyword and temporal search operations with automatic metadata
filtering. Provides more targeted search results by combining content relevance filtering. Provides more targeted search results by combining content relevance
with time-based filtering. with time-based filtering.
Explicit parameters: Explicit parameters:
- context: Query content for keyword matching - context: Query content for keyword matching
- days_back: Number of days to search backwards, default 7 days - days_back: Number of days to search backwards, default 7 days
@@ -186,10 +189,11 @@ def create_time_retrieval_tool(end_user_id: str):
- end_date: End time (optional, format: YYYY-MM-DD) - end_date: End time (optional, format: YYYY-MM-DD)
- clean_output: Whether to clean metadata fields from output - clean_output: Whether to clean metadata fields from output
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d") - end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
Returns: Returns:
str: JSON formatted search results combining keyword and temporal data str: JSON formatted search results combining keyword and temporal data
""" """
async def _async_search(): async def _async_search():
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
@@ -212,29 +216,29 @@ def create_time_retrieval_tool(end_user_id: str):
return json.dumps(cleaned_results, ensure_ascii=False, indent=2) return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
return asyncio.run(_async_search()) return asyncio.run(_async_search())
return TimeRetrievalWithGroupId return TimeRetrievalWithGroupId
def create_hybrid_retrieval_tool_async(memory_config, **search_params): def create_hybrid_retrieval_tool_async(memory_config, **search_params):
""" """
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
Creates an advanced hybrid search tool that combines multiple search strategies Creates an advanced hybrid search tool that combines multiple search strategies
(keyword, vector, hybrid) with automatic result cleaning and formatting. (keyword, vector, hybrid) with automatic result cleaning and formatting.
Args: Args:
memory_config: Memory configuration object containing LLM and search settings memory_config: Memory configuration object containing LLM and search settings
**search_params: Search parameters including end_user_id, limit, include, etc. **search_params: Search parameters including end_user_id, limit, include, etc.
Returns: Returns:
function: Configured HybridSearch tool function with async capabilities function: Configured HybridSearch tool function with async capabilities
""" """
def clean_result_fields(data): def clean_result_fields(data):
""" """
Recursively clean unnecessary fields from results Recursively clean unnecessary fields from results
Removes metadata fields that are not needed for end-user consumption, Removes metadata fields that are not needed for end-user consumption,
improving readability and reducing response size. improving readability and reducing response size.
@@ -247,11 +251,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
# List of fields to filter out # List of fields to filter out
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development # TODO: fact_summary functionality temporarily disabled, will be enabled after future development
fields_to_remove = { fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary" 'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
} }
if isinstance(data, dict): if isinstance(data, dict):
# Clean dictionary # Clean dictionary
cleaned = {} cleaned = {}
@@ -265,7 +269,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
else: else:
# Return other types directly # Return other types directly
return data return data
@tool @tool
async def HybridSearch( async def HybridSearch(
context: str, context: str,
@@ -279,7 +283,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
) -> str: ) -> str:
""" """
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
Provides comprehensive search capabilities combining multiple search strategies Provides comprehensive search capabilities combining multiple search strategies
with intelligent result ranking and automatic metadata filtering for clean output. with intelligent result ranking and automatic metadata filtering for clean output.
@@ -292,7 +296,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
use_forgetting_rerank: Whether to use forgetting-based reranking use_forgetting_rerank: Whether to use forgetting-based reranking
use_llm_rerank: Whether to use LLM-based reranking use_llm_rerank: Whether to use LLM-based reranking
clean_output: Whether to clean metadata fields from output clean_output: Whether to clean metadata fields from output
Returns: Returns:
str: JSON formatted comprehensive search results str: JSON formatted comprehensive search results
""" """
@@ -329,9 +333,9 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"search_type": search_type, "search_type": search_type,
"results": cleaned_results "results": cleaned_results
} }
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str) return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
except Exception as e: except Exception as e:
error_result = { error_result = {
"error": f"混合检索失败: {str(e)}", "error": f"混合检索失败: {str(e)}",
@@ -340,35 +344,36 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat()
} }
return json.dumps(error_result, ensure_ascii=False, indent=2) return json.dumps(error_result, ensure_ascii=False, indent=2)
return HybridSearch return HybridSearch
def create_hybrid_retrieval_tool_sync(memory_config, **search_params): def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
""" """
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
Creates a synchronous wrapper around the async hybrid search functionality, Creates a synchronous wrapper around the async hybrid search functionality,
making it compatible with synchronous tool execution environments. making it compatible with synchronous tool execution environments.
Args: Args:
memory_config: Memory configuration object containing search settings memory_config: Memory configuration object containing search settings
**search_params: Search parameters for configuration **search_params: Search parameters for configuration
Returns: Returns:
function: Configured HybridSearchSync tool function function: Configured HybridSearchSync tool function
""" """
@tool @tool
def HybridSearchSync( def HybridSearchSync(
context: str, context: str,
search_type: str = "hybrid", search_type: str = "hybrid",
limit: int = 10, limit: int = 10,
end_user_id: str = None, end_user_id: str = None,
clean_output: bool = True clean_output: bool = True
) -> str: ) -> str:
""" """
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
Provides the same hybrid search capabilities as the async version but in a Provides the same hybrid search capabilities as the async version but in a
synchronous execution context. Automatically handles async-to-sync conversion. synchronous execution context. Automatically handles async-to-sync conversion.
@@ -378,10 +383,11 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
limit: Result quantity limit limit: Result quantity limit
end_user_id: Group ID for filtering search results end_user_id: Group ID for filtering search results
clean_output: Whether to clean metadata fields from output clean_output: Whether to clean metadata fields from output
Returns: Returns:
str: JSON formatted search results str: JSON formatted search results
""" """
async def _async_search(): async def _async_search():
# Create async tool and execute # Create async tool and execute
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params) async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
@@ -392,7 +398,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"end_user_id": end_user_id, "end_user_id": end_user_id,
"clean_output": clean_output "clean_output": clean_output
}) })
return asyncio.run(_async_search()) return asyncio.run(_async_search())
return HybridSearchSync return HybridSearchSync

View File

@@ -1,10 +1,12 @@
import json import json
from langchain_core.messages import HumanMessage, AIMessage from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list,type:str='string'):
async def format_parsing(messages: list, type: str = 'string'):
""" """
Format and parse message lists into different output types Format and parse message lists into different output types
Processes message lists from storage and converts them into either string format Processes message lists from storage and converts them into either string format
or dictionary format based on the specified type parameter. Handles JSON parsing or dictionary format based on the specified type parameter. Handles JSON parsing
and role-based message organization. and role-based message organization.
@@ -19,8 +21,8 @@ async def format_parsing(messages: list,type:str='string'):
- 'dict': List of dictionaries mapping user messages to AI responses - 'dict': List of dictionaries mapping user messages to AI responses
""" """
result = [] result = []
user=[] user = []
ai=[] ai = []
for message in messages: for message in messages:
hstory_messages = message['messages'] hstory_messages = message['messages']
@@ -30,37 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
role = content['role'] role = content['role']
content = content['content'] content = content['content']
if type == "string": if type == "string":
if role == 'human' or role=="user": if role == 'human' or role == "user":
content = '用户:' + content content = '用户:' + content
else: else:
content = 'AI:' + content content = 'AI:' + content
result.append(content) result.append(content)
if type == "dict" : if type == "dict":
if role == 'human' or role=="user": if role == 'human' or role == "user":
user.append( content) user.append(content)
else: else:
ai.append(content) ai.append(content)
if type == "dict": if type == "dict":
for key,values in zip(user,ai): for key, values in zip(user, ai):
result.append({key:values}) result.append({key: values})
return result return result
async def messages_parse(messages: list | dict): async def messages_parse(messages: list | dict):
""" """
Parse messages from storage format into user-AI conversation pairs Parse messages from storage format into user-AI conversation pairs
Extracts and organizes conversation data from stored message format, Extracts and organizes conversation data from stored message format,
separating user and AI messages and pairing them for database storage. separating user and AI messages and pairing them for database storage.
Args: Args:
messages: List or dictionary containing stored message data with Query fields messages: List or dictionary containing stored message data with Query fields
Returns: Returns:
list: List of dictionaries containing user-AI message pairs for database storage list: List of dictionaries containing user-AI message pairs for database storage
""" """
user=[] user = []
ai=[] ai = []
database=[] database = []
for message in messages: for message in messages:
Query = message['Query'] Query = message['Query']
Query = json.loads(Query) Query = json.loads(Query)
@@ -72,20 +75,20 @@ async def messages_parse(messages: list | dict):
ai.append(data['content']) ai.append(data['content'])
for key, values in zip(user, ai): for key, values in zip(user, ai):
database.append({key, values}) database.append({key, values})
return database return database
async def agent_chat_messages(user_content,ai_content): async def agent_chat_messages(user_content, ai_content):
""" """
Create structured chat message format for agent conversations Create structured chat message format for agent conversations
Formats user and AI content into a standardized message structure suitable Formats user and AI content into a standardized message structure suitable
for agent processing and storage. Creates role-based message objects. for agent processing and storage. Creates role-based message objects.
Args: Args:
user_content: User's message content string user_content: User's message content string
ai_content: AI's response content string ai_content: AI's response content string
Returns: Returns:
list: List of structured message dictionaries with role and content fields list: List of structured message dictionaries with role and content fields
""" """

View File

@@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -42,13 +41,15 @@ async def make_write_graph():
yield graph yield graph
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
""" """
Handle long-term memory storage with different strategies Handle long-term memory storage with different strategies
Supports multiple storage strategies including chunk-based, time-based, Supports multiple storage strategies including chunk-based, time-based,
and aggregate judgment approaches for long-term memory persistence. and aggregate judgment approaches for long-term memory persistence.
Args: Args:
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate') long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
langchain_messages: List of messages to store langchain_messages: List of messages to store
@@ -56,9 +57,10 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
end_user_id: User group identifier end_user_id: User group identifier
scope: Scope parameter for chunk-based storage (default: 6) scope: Scope parameter for chunk-based storage (default: 6)
""" """
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, (langchain_messages)) write_store.save_session_write(end_user_id, langchain_messages)
# 获取数据库会话 # 获取数据库会话
with get_db_context() as db_session: with get_db_context() as db_session:
config_service = MemoryConfigService(db_session) config_service = MemoryConfigService(db_session)
@@ -66,25 +68,24 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
config_id=memory_config, # 改为整数 config_id=memory_config, # 改为整数
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
if long_term_type==AgentMemory_Long_Term.STRATEGY_CHUNK: if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
'''Strategy 1: Dialogue window with 6 rounds of conversation''' '''Strategy 1: Dialogue window with 6 rounds of conversation'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope) await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
if long_term_type==AgentMemory_Long_Term.STRATEGY_TIME: if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
"""Time-based strategy""" """Time-based strategy"""
await memory_long_term_storage(end_user_id, memory_config,AgentMemory_Long_Term.TIME_SCOPE) await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
if long_term_type==AgentMemory_Long_Term.STRATEGY_AGGREGATE: if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
"""Strategy 3: Aggregate judgment""" """Strategy 3: Aggregate judgment"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config) await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
""" """
Write long-term memory with different storage types Write long-term memory with different storage types
Handles both RAG-based storage and traditional memory storage approaches. Handles both RAG-based storage and traditional memory storage approaches.
For traditional storage, uses chunk-based strategy with paired user-AI messages. For traditional storage, uses chunk-based strategy with paired user-AI messages.
Args: Args:
storage_type: Type of storage (RAG or traditional) storage_type: Type of storage (RAG or traditional)
end_user_id: User group identifier end_user_id: User group identifier
@@ -95,7 +96,7 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_
""" """
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG: if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else: else:
@@ -128,4 +129,4 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_
# #
# if __name__ == "__main__": # if __name__ == "__main__":
# import asyncio # import asyncio
# asyncio.run(main()) # asyncio.run(main())

View File

@@ -8,10 +8,11 @@ from langgraph.graph import add_messages
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3]) PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
class WriteState(TypedDict): class WriteState(TypedDict):
''' """
Langgrapg Writing TypedDict Langgrapg Writing TypedDict
''' """
messages: Annotated[list[AnyMessage], add_messages] messages: Annotated[list[AnyMessage], add_messages]
end_user_id: str end_user_id: str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
@@ -20,6 +21,7 @@ class WriteState(TypedDict):
data: str data: str
language: str # 语言类型 ("zh" 中文, "en" 英文) language: str # 语言类型 ("zh" 中文, "en" 英文)
class ReadState(TypedDict): class ReadState(TypedDict):
""" """
LangGraph 工作流状态定义 LangGraph 工作流状态定义
@@ -43,18 +45,20 @@ class ReadState(TypedDict):
config_id: str config_id: str
data: str # 新增字段用于传递内容 data: str # 新增字段用于传递内容
spit_data: dict # 新增字段用于传递问题分解结果 spit_data: dict # 新增字段用于传递问题分解结果
problem_extension:dict problem_extension: dict
storage_type: str storage_type: str
user_rag_memory_id: str user_rag_memory_id: str
llm_id: str llm_id: str
embedding_id: str embedding_id: str
memory_config: object # 新增字段用于传递内存配置对象 memory_config: object # 新增字段用于传递内存配置对象
retrieve:dict retrieve: dict
RetrieveSummary: dict RetrieveSummary: dict
InputSummary: dict InputSummary: dict
verify: dict verify: dict
SummaryFails: dict SummaryFails: dict
summary: dict summary: dict
class COUNTState: class COUNTState:
""" """
工作流对话检索内容计数器 工作流对话检索内容计数器
@@ -99,6 +103,7 @@ class COUNTState:
self.total = 0 self.total = 0
print("[COUNTState] 已重置为 0") print("[COUNTState] 已重置为 0")
def deduplicate_entries(entries): def deduplicate_entries(entries):
seen = set() seen = set()
deduped = [] deduped = []
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
deduped.append(entry) deduped.append(entry)
return deduped return deduped
def merge_to_key_value_pairs(data, query_key, result_key): def merge_to_key_value_pairs(data, query_key, result_key):
grouped = defaultdict(list) grouped = defaultdict(list)
for item in data: for item in data:
@@ -142,4 +148,4 @@ def convert_extended_question_to_question(data):
return [convert_extended_question_to_question(item) for item in data] return [convert_extended_question_to_question(item) for item in data]
else: else:
# 其他类型直接返回 # 其他类型直接返回
return data return data

View File

@@ -5,7 +5,7 @@ from typing import List, Dict, Optional
from app.core.logging_config import get_memory_logger from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤 from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
from app.core.memory.models.triplet_models import TripletExtractionResponse from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.models.message_models import DialogData, Statement from app.core.memory.models.message_models import DialogData, Statement
from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.core.memory.models.ontology_extraction_models import OntologyTypeList
@@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger
logger = get_memory_logger(__name__) logger = get_memory_logger(__name__)
class TripletExtractor: class TripletExtractor:
"""Extracts knowledge triplets and entities from statements using LLM""" """Extracts knowledge triplets and entities from statements using LLM"""
def __init__( def __init__(
self, self,
llm_client: OpenAIClient, llm_client: OpenAIClient,
ontology_types: Optional[OntologyTypeList] = None, ontology_types: Optional[OntologyTypeList] = None,
language: str = "zh"): language: str = "zh"
):
"""Initialize the TripletExtractor with an LLM client """Initialize the TripletExtractor with an LLM client
Args: Args:
@@ -65,7 +65,8 @@ class TripletExtractor:
# Create messages for LLM # Create messages for LLM
messages = [ messages = [
{"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."}, {"role": "system",
"content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
{"role": "user", "content": prompt_content} {"role": "user", "content": prompt_content}
] ]
@@ -116,7 +117,8 @@ class TripletExtractor:
logger.error(f"Error processing statement: {e}", exc_info=True) logger.error(f"Error processing statement: {e}", exc_info=True)
return TripletExtractionResponse(triplets=[], entities=[]) return TripletExtractionResponse(triplets=[], entities=[])
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]: async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[
str, TripletExtractionResponse]:
"""Extract triplets and entities from statements """Extract triplets and entities from statements
Args: Args:

View File

@@ -2,15 +2,15 @@ import os
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from typing import List, Dict, Any from typing import List, Dict, Any
# Setup Jinja2 environment # Setup Jinja2 environment
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts") prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any, async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
baseline: str = "TIME", baseline: str = "TIME",
memory_verify: bool = False,quality_assessment:bool = False, memory_verify: bool = False, quality_assessment: bool = False,
statement_databasets: List[str] = [],language_type:str = "zh") -> str: statement_databasets=None, language_type: str = "zh") -> str:
""" """
Renders the evaluate prompt using the evaluate_optimized.jinja2 template. Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
@@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
Returns: Returns:
Rendered prompt content as string Rendered prompt content as string
""" """
if statement_databasets is None:
statement_databasets = []
template = prompt_env.get_template("evaluate.jinja2") template = prompt_env.get_template("evaluate.jinja2")
# Convert Pydantic model to JSON schema if needed # Convert Pydantic model to JSON schema if needed
@@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False, async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
statement_databasets: List[str] = [],language_type:str = "zh") -> str: statement_databasets=None, language_type: str = "zh") -> str:
""" """
Renders the reflexion prompt using the reflexion_optimized.jinja2 template. Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
@@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
Returns: Returns:
Rendered prompt content as a string. Rendered prompt content as a string.
""" """
if statement_databasets is None:
statement_databasets = []
template = prompt_env.get_template("reflexion.jinja2") template = prompt_env.get_template("reflexion.jinja2")
# Convert Pydantic model to JSON schema if needed # Convert Pydantic model to JSON schema if needed
@@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
json_schema = schema json_schema = schema
rendered_prompt = template.render(data=data, json_schema=json_schema, rendered_prompt = template.render(data=data, json_schema=json_schema,
baseline=baseline,memory_verify=memory_verify, baseline=baseline, memory_verify=memory_verify,
statement_databasets=statement_databasets,language_type=language_type) statement_databasets=statement_databasets, language_type=language_type)
return rendered_prompt return rendered_prompt

View File

@@ -1,23 +1,19 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import os import os
import time from typing import Any, Dict, Optional, TypeVar
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, TypeVar from langchain_aws import ChatBedrock
from langchain_community.chat_models import ChatTongyi
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLLM
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI, OpenAI
from pydantic import BaseModel, Field
import httpx
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.models.models_model import ModelProvider, ModelType from app.models.models_model import ModelProvider, ModelType
from langchain_community.document_compressors import JinaRerank
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel, BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnableSerializable
from pydantic import BaseModel, Field
T = TypeVar("T") T = TypeVar("T")
@@ -163,25 +159,17 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
# dashscope 的 omni 模型使用 OpenAI 兼容模式 # dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni: if provider == ModelProvider.DASHSCOPE and config.is_omni:
from langchain_openai import ChatOpenAI
return ChatOpenAI return ChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
if type == ModelType.LLM: if type == ModelType.LLM:
from langchain_openai import OpenAI
return OpenAI return OpenAI
elif type == ModelType.CHAT: elif type == ModelType.CHAT:
from langchain_openai import ChatOpenAI
return ChatOpenAI return ChatOpenAI
elif provider == ModelProvider.DASHSCOPE: elif provider == ModelProvider.DASHSCOPE:
from langchain_community.chat_models import ChatTongyi
return ChatTongyi return ChatTongyi
elif provider == ModelProvider.OLLAMA: elif provider == ModelProvider.OLLAMA:
from langchain_ollama import OllamaLLM
return OllamaLLM return OllamaLLM
elif provider == ModelProvider.BEDROCK: elif provider == ModelProvider.BEDROCK:
from langchain_aws import ChatBedrock, ChatBedrockConverse
return ChatBedrock return ChatBedrock
else: else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)

View File

@@ -16,6 +16,7 @@ from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read from app.db import get_db_read
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
from app.schemas import FileInput from app.schemas import FileInput
from app.schemas.model_schema import ModelInfo
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -620,11 +621,12 @@ class BaseNode(ABC):
@staticmethod @staticmethod
async def process_message( async def process_message(
provider: str, api_config: ModelInfo,
is_omni: bool,
content: str | dict | FileObject, content: str | dict | FileObject,
end_user_id: str,
enable_file=False enable_file=False
) -> list | str | None: ) -> list | str | None:
provider = api_config.provider
if isinstance(content, dict): if isinstance(content, dict):
content = FileObject( content = FileObject(
type=content.get("type"), type=content.get("type"),
@@ -643,7 +645,7 @@ class BaseNode(ABC):
if content.content_cache.get(provider): if content.content_cache.get(provider):
return content.content_cache[provider] return content.content_cache[provider]
with get_db_read() as db: with get_db_read() as db:
multimodel_service = MultimodalService(db, provider, is_omni=is_omni) multimodel_service = MultimodalService(db, api_config=api_config)
file_obj = FileInput( file_obj = FileInput(
type=content.type, type=content.type,
url=content.url, url=content.url,
@@ -653,7 +655,8 @@ class BaseNode(ABC):
) )
file_obj.set_content(content.get_content()) file_obj.set_content(content.get_content())
message = await multimodel_service.process_files( message = await multimodel_service.process_files(
[file_obj] end_user_id,
[file_obj],
) )
content.set_content(file_obj.get_content()) content.set_content(file_obj.get_content())
if message: if message:

View File

@@ -5,7 +5,7 @@ from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
@@ -23,6 +23,26 @@ class IfElseNode(BaseNode):
"output": VariableType.STRING "output": VariableType.STRING
} }
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
result = []
for case in self.typed_config.cases:
expressions = []
for expression in case.expressions:
expressions.append({
"left": self.get_variable(expression.left, variable_pool, strict=False),
"right": expression.right
if expression.input_type == ValueInputType.CONSTANT
else self.get_variable(expression.right, variable_pool, strict=False),
"operator": expression.operator,
})
result.append({
"expressions": expressions,
"logical_operator": case.logical_operator,
})
return {
"cases": result
}
@staticmethod @staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any: def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
match operator: match operator:

View File

@@ -30,6 +30,12 @@ class KnowledgeRetrievalNode(BaseNode):
"output": VariableType.ARRAY_STRING "output": VariableType.ARRAY_STRING
} }
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {
"query": self._render_template(self.typed_config.query, variable_pool),
"knowledge_bases": [kb_config.model_dump(mode="json") for kb_config in self.typed_config.knowledge_bases],
}
@staticmethod @staticmethod
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType): def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
""" """

View File

@@ -20,6 +20,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_context from app.db import get_db_context
from app.models import ModelType from app.models import ModelType
from app.schemas.model_schema import ModelInfo
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -113,12 +114,15 @@ class LLMNode(BaseNode):
# 在 Session 关闭前提取所有需要的数据 # 在 Session 关闭前提取所有需要的数据
api_config = self.model_balance(config) api_config = self.model_balance(config)
model_name = api_config.model_name model_info = ModelInfo(
provider = api_config.provider model_name=api_config.model_name,
api_key = api_config.api_key model_type=ModelType(config.type),
api_base = api_config.api_base api_key=api_config.api_key,
is_omni = api_config.is_omni api_base=api_config.api_base,
model_type = config.type provider=api_config.provider,
is_omni=api_config.is_omni,
capability=api_config.capability
)
# 4. 创建 LLM 实例(使用已提取的数据) # 4. 创建 LLM 实例(使用已提取的数据)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True # 注意:对于流式输出,需要在模型初始化时设置 streaming=True
@@ -126,17 +130,18 @@ class LLMNode(BaseNode):
llm = RedBearLLM( llm = RedBearLLM(
RedBearModelConfig( RedBearModelConfig(
model_name=model_name, model_name=model_info.model_name,
provider=provider, provider=model_info.provider,
api_key=api_key, api_key=model_info.api_key,
base_url=api_base, base_url=model_info.api_base,
extra_params=extra_params, extra_params=extra_params,
is_omni=is_omni is_omni=model_info.is_omni
), ),
type=ModelType(model_type) type=model_info.model_type
) )
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") logger.debug(
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
messages_config = self.typed_config.messages messages_config = self.typed_config.messages
@@ -148,35 +153,40 @@ class LLMNode(BaseNode):
content_template = msg_config.content content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool) content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool) content = self._render_template(content_template, variable_pool)
user_id = self.get_variable("sys.user_id", variable_pool)
# 根据角色创建对应的消息对象 # 根据角色创建对应的消息对象
if role == "system": if role == "system":
messages.append({ messages.append({
"role": "system", "role": "system",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision) "content": await self.process_message(
model_info,
content,
user_id,
self.typed_config.vision,
)
}) })
elif role in ["user", "human"]: elif role in ["user", "human"]:
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision) "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
}) })
elif role in ["ai", "assistant"]: elif role in ["ai", "assistant"]:
messages.append({ messages.append({
"role": "assistant", "role": "assistant",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision) "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
}) })
else: else:
logger.warning(f"未知的消息角色: {role},默认使用 user") logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision) "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
}) })
if self.typed_config.vision_input and self.typed_config.vision: if self.typed_config.vision_input and self.typed_config.vision:
file_content = [] file_content = []
files = variable_pool.get_instance(self.typed_config.vision_input) files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value: for file in files.value:
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision) content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
if messages and messages[-1]["role"] == 'user': if messages and messages[-1]["role"] == 'user':
@@ -190,14 +200,19 @@ class LLMNode(BaseNode):
if isinstance(message["content"], list): if isinstance(message["content"], list):
file_content = [] file_content = []
for file in message["content"]: for file in message["content"]:
content = await self.process_message(provider, is_omni, file, self.typed_config.vision) content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
history_message.append( history_message.append(
{"role": message["role"], "content": file_content} {"role": message["role"], "content": file_content}
) )
else: else:
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision) message["content"] = await self.process_message(
model_info,
message["content"],
user_id,
self.typed_config.vision
)
history_message.append(message) history_message.append(message)
messages = messages[:-1] + history_message + messages[-1:] messages = messages[:-1] + history_message + messages[-1:]
self.messages = messages self.messages = messages
@@ -293,7 +308,7 @@ class LLMNode(BaseNode):
# 调用 LLM流式支持字符串或消息列表 # 调用 LLM流式支持字符串或消息列表
last_meta_data = {} last_meta_data = {}
async for chunk in llm.astream(self.messages, stream_usage=True): async for chunk in llm.astream(self.messages):
# 提取内容 # 提取内容
if hasattr(chunk, 'content'): if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content) content = self.process_model_output(chunk.content)

View File

@@ -37,6 +37,14 @@ class ParameterExtractorNode(BaseNode):
} }
return None return None
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {
"text": self._render_template(self.typed_config.text, variable_pool),
"prompt": self._render_template(self.typed_config.prompt, variable_pool),
"params": [param.model_dump(mode="json") for param in self.typed_config.params],
"model_id": str(self.typed_config.model_id),
}
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
outputs = {} outputs = {}
for param in self.typed_config.params: for param in self.typed_config.params:

View File

@@ -7,6 +7,7 @@ from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from app.db import Base from app.db import Base
from app.schemas import FileType
class PerceptualType(IntEnum): class PerceptualType(IntEnum):
@@ -15,6 +16,16 @@ class PerceptualType(IntEnum):
TEXT = 3 TEXT = 3
CONVERSATION = 4 CONVERSATION = 4
@staticmethod
def trans_from_file_type(file_type: FileType | str):
type_map = {
FileType.IMAGE: PerceptualType.VISION,
FileType.AUDIO: PerceptualType.AUDIO,
FileType.VIDEO: PerceptualType.VISION,
FileType.DOCUMENT: PerceptualType.TEXT
}
return type_map.get(file_type, PerceptualType.TEXT)
class FileStorageService(IntEnum): class FileStorageService(IntEnum):
LOCAL = 1 LOCAL = 1

View File

@@ -2,7 +2,7 @@ import uuid
from datetime import datetime from datetime import datetime
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
from sqlalchemy import and_, desc from sqlalchemy import and_, desc, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger from app.core.logging_config import get_db_logger
@@ -127,6 +127,17 @@ class MemoryPerceptualRepository:
db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}") db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}")
raise raise
def get_by_url(
self,
file_url: str
) -> list[MemoryPerceptualModel]:
try:
stmt = select(MemoryPerceptualModel).where(MemoryPerceptualModel.file_path == file_url)
return list(self.db.execute(stmt).scalars())
except Exception:
db_logger.error(f"Failed to query perceptual memories by file_url: file_url={file_url}")
raise
def get_by_type( def get_by_type(
self, self,
end_user_id: uuid.UUID, end_user_id: uuid.UUID,

View File

@@ -1,5 +1,4 @@
import uuid import uuid
from datetime import datetime
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -85,7 +84,6 @@ class Semantic(BaseModel):
class Content(BaseModel): class Content(BaseModel):
summary: str
keywords: list[str] keywords: list[str]
topic: str topic: str
domain: str domain: str

View File

@@ -326,3 +326,14 @@ class ModelBaseQuery(BaseModel):
is_official: Optional[bool] = Field(None, description="是否官方模型") is_official: Optional[bool] = Field(None, description="是否官方模型")
is_deprecated: Optional[bool] = Field(None, description="是否弃用") is_deprecated: Optional[bool] = Field(None, description="是否弃用")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255) search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
class ModelInfo(BaseModel):
"""模型信息Schema"""
model_name: str = Field(..., description="模型名称")
provider: str = Field(..., description="模型提供商")
api_key: str = Field(..., description="API密钥")
api_base: str = Field(..., description="API基础URL")
is_omni: bool = Field(default=False, description="是否为omni模型")
model_type: ModelType = Field(..., description="模型类型")
capability: List[str] = Field(default_factory=list, description="模型能力列表")

View File

@@ -8,25 +8,21 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
from fastapi import Depends from fastapi import Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.db import get_db from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig from app.models import MultiAgentConfig, AgentConfig, ModelType
from app.models import WorkflowConfig from app.models import WorkflowConfig
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
from app.schemas import DraftRunRequest from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput from app.schemas.app_schema import FileInput
from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \ from app.services.draft_run_service import AgentRunService
AgentRunService
from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService
from app.services.workflow_service import WorkflowService from app.services.workflow_service import WorkflowService
logger = get_business_logger() logger = get_business_logger()
@@ -126,8 +122,17 @@ class AppChatService:
# 处理多模态文件 # 处理多模态文件
processed_files = None processed_files = None
if files: if files:
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni) model_info = ModelInfo(
processed_files = await multimodal_service.process_files(files) model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件") logger.info(f"处理了 {len(processed_files)} 个文件")
# 调用 Agent支持多模态 # 调用 Agent支持多模态
@@ -266,8 +271,17 @@ class AppChatService:
# 处理多模态文件 # 处理多模态文件
processed_files = None processed_files = None
if files: if files:
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni) model_info = ModelInfo(
processed_files = await multimodal_service.process_files(files) model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件") logger.info(f"处理了 {len(processed_files)} 个文件")
# 流式调用 Agent支持多模态 # 流式调用 Agent支持多模态

View File

@@ -23,9 +23,10 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context from app.db import get_db_context
from app.models import AgentConfig, ModelConfig from app.models import AgentConfig, ModelConfig, ModelType
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput from app.schemas.app_schema import FileInput
from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service from app.services import task_service
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
@@ -501,9 +502,18 @@ class AgentRunService:
processed_files = None processed_files = None
if files: if files:
# 获取 provider 信息 # 获取 provider 信息
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=ModelType.LLM
)
provider = api_key_config.get("provider", "openai") provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False)) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(files) processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}") logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索 # 7. 知识库检索
@@ -704,9 +714,18 @@ class AgentRunService:
processed_files = None processed_files = None
if files: if files:
# 获取 provider 信息 # 获取 provider 信息
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=ModelType.LLM
)
provider = api_key_config.get("provider", "openai") provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False)) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(files) processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}") logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索 # 7. 知识库检索
@@ -841,7 +860,8 @@ class AgentRunService:
"api_key": api_key.api_key, "api_key": api_key.api_key,
"api_base": api_key.api_base, "api_base": api_key.api_base,
"api_key_id": api_key.id, "api_key_id": api_key.id,
"is_omni": api_key.is_omni "is_omni": api_key.is_omni,
"capability": api_key.capability
} }
async def _ensure_conversation( async def _ensure_conversation(

View File

@@ -274,7 +274,7 @@ class MemoryAgentService:
Args: Args:
end_user_id: Group identifier (also used as end_user_id) end_user_id: Group identifier (also used as end_user_id)
message: Message to write messages: Message to write
config_id: Configuration ID from database config_id: Configuration ID from database
db: SQLAlchemy database session db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag) storage_type: Storage type (neo4j or rag)

View File

@@ -1,19 +1,27 @@
import os
import uuid import uuid
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from urllib.parse import urlparse, unquote
import json_repair
from jinja2 import Template
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.memory_perceptual_model import PerceptualType, FileStorageService from app.models.memory_perceptual_model import PerceptualType, FileStorageService
from app.models.prompt_optimizer_model import RoleType
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
from app.schemas import FileType
from app.schemas.memory_perceptual_schema import ( from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema, PerceptualQuerySchema,
PerceptualTimelineResponse, PerceptualTimelineResponse,
PerceptualMemoryItem, PerceptualMemoryItem,
AudioModal, Content, VideoModal, TextModal AudioModal, Content, VideoModal, TextModal
) )
from app.schemas.model_schema import ModelInfo
business_logger = get_business_logger() business_logger = get_business_logger()
@@ -99,7 +107,7 @@ class MemoryPerceptualService:
"keywords": content.keywords, "keywords": content.keywords,
"topic": content.topic, "topic": content.topic,
"domain": content.domain, "domain": content.domain,
"created_time": int(memory.created_time.timestamp()*1000), "created_time": int(memory.created_time.timestamp() * 1000),
**detail **detail
} }
@@ -108,7 +116,8 @@ class MemoryPerceptualService:
return result return result
except Exception as e: except Exception as e:
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}") business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
exc_info=True)
raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}", raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
BizCode.DB_ERROR) BizCode.DB_ERROR)
@@ -138,7 +147,7 @@ class MemoryPerceptualService:
for memory in memories: for memory in memories:
meta_data = memory.meta_data or {} meta_data = memory.meta_data or {}
content = meta_data.get("content", {}) content = meta_data.get("content", {})
# 安全地提取 content 字段,提供默认值 # 安全地提取 content 字段,提供默认值
if content: if content:
content_obj = Content(**content) content_obj = Content(**content)
@@ -149,7 +158,7 @@ class MemoryPerceptualService:
topic = "Unknown" topic = "Unknown"
domain = "Unknown" domain = "Unknown"
keywords = [] keywords = []
memory_item = PerceptualMemoryItem( memory_item = PerceptualMemoryItem(
id=memory.id, id=memory.id,
perceptual_type=PerceptualType(memory.perceptual_type), perceptual_type=PerceptualType(memory.perceptual_type),
@@ -161,7 +170,7 @@ class MemoryPerceptualService:
topic=topic, topic=topic,
domain=domain, domain=domain,
keywords=keywords, keywords=keywords,
created_time=int(memory.created_time.timestamp()*1000), created_time=int(memory.created_time.timestamp() * 1000),
storage_service=FileStorageService(memory.storage_service), storage_service=FileStorageService(memory.storage_service),
) )
memory_items.append(memory_item) memory_items.append(memory_item)
@@ -183,3 +192,98 @@ class MemoryPerceptualService:
except Exception as e: except Exception as e:
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}") business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR) raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
async def generate_perceptual_memory(
self,
end_user_id: str,
model_config: ModelInfo,
file_type: str,
file_url: str,
file_message: dict,
):
memories = self.repository.get_by_url(file_url)
if memories:
business_logger.info(f"Perceptual memory already exists: {file_url}")
if end_user_id not in [memory.end_user_id for memory in memories]:
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0]
self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path,
file_name=memory_cache.file_name,
file_ext=memory_cache.file_ext,
summary=memory_cache.summary,
meta_data=memory_cache.meta_data
)
self.db.commit()
return
llm = RedBearLLM(RedBearModelConfig(
model_name=model_config.model_name,
provider=model_config.provider,
api_key=model_config.api_key,
base_url=model_config.api_base,
is_omni=model_config.is_omni
), type=model_config.model_type)
try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
messages = [
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
{"role": RoleType.USER.value, "content": [
{"type": "text", "text": "Summarize the following file"}, file_message
]}
]
result = await llm.ainvoke(messages)
content = json_repair.repair_json(result.content, return_objects=True)
path = urlparse(file_url).path
filename = os.path.basename(path)
filename = unquote(filename)
file_ext = os.path.splitext(filename)[1]
if not file_ext:
if file_type == FileType.AUDIO:
file_ext = ".mp3"
elif file_type == FileType.VIDEO:
file_ext = ".mp4"
elif file_type == FileType.DOCUMENT:
file_ext = ".txt"
elif file_type == FileType.IMAGE:
file_ext = ".jpg"
filename += file_ext
file_content = {
"keywords": content.get("keywords", []),
"topic": content.get("topic"),
"domain": content.get("domain")
}
if file_type in [FileType.IMAGE, FileType.VIDEO]:
file_modalities = {
"scene": content.get("scene")
}
elif file_type in [FileType.DOCUMENT]:
file_modalities = {
"section_count": content.get("section_count"),
"title": content.get("title"),
"first_line": content.get("first_line")
}
else:
file_modalities = {
"speaker_count": content.get("speaker_count")
}
self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType.trans_from_file_type(file_type),
file_path=file_url,
file_name=filename,
file_ext=file_ext,
summary=content.get('summary'),
meta_data={
"content": file_content,
"modalities": file_modalities
}
)
self.db.commit()

View File

@@ -10,6 +10,7 @@
""" """
import base64 import base64
import io import io
import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
@@ -23,9 +24,12 @@ from app.core.config import settings
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.models import ModelApiKey
from app.models.file_metadata_model import FileMetadata from app.models.file_metadata_model import FileMetadata
from app.schemas.app_schema import FileInput, FileType, TransferMethod from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.schemas.model_schema import ModelInfo
from app.services.audio_transcription_service import AudioTranscriptionService from app.services.audio_transcription_service import AudioTranscriptionService
from app.tasks import write_perceptual_memory
logger = get_business_logger() logger = get_business_logger()
@@ -39,6 +43,7 @@ DOC_MIME = [
class MultimodalFormatStrategy(ABC): class MultimodalFormatStrategy(ABC):
"""多模态格式策略基类""" """多模态格式策略基类"""
def __init__(self, file: FileInput): def __init__(self, file: FileInput):
self.file = file self.file = file
@@ -95,7 +100,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
if transcription: if transcription:
return { return {
"type": "text", "type": "text",
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>" "text": f"<audio url=\"{url}\">\ntext_transcription:{transcription}\n</audio>"
} }
# 通义千问音频格式:{"type": "audio", "audio": "url"} # 通义千问音频格式:{"type": "audio", "audio": "url"}
return { return {
@@ -284,34 +289,56 @@ PROVIDER_STRATEGIES = {
class MultimodalService: class MultimodalService:
"""多模态文件处理服务""" """
Service for handling multimodal file processing.
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, Attributes:
enable_audio_transcription: bool = False, is_omni: bool = False): db (Session): Database session.
model_api_key (str): API key for the model provider.
provider (str): Name of the model provider.
is_omni (bool): Indicates whether the model supports full multimodal capability.
capability (list): Capability configuration of the model.
audio_api_key (str | None): API key used for audio transcription.
enable_audio_transcription (bool): Whether audio transcription is enabled.
"""
def __init__(
self,
db: Session,
api_config: ModelInfo | None = None,
audio_api_key: Optional[str] = None,
enable_audio_transcription: bool = False,
):
""" """
初始化多模态服务 Initialize the multimodal service.
Args: Args:
db: 数据库会话 db (Session): Database session.
provider: 模型提供商dashscope, bedrock, anthropic, openai 等) api_config (ModelApiKey | None): Model API configuration.
api_key: API 密钥(用于音频转文本) audio_api_key (str | None): API key for audio transcription.
enable_audio_transcription: 是否启用音频转文本 enable_audio_transcription (bool): Enable audio transcription.
is_omni: 是否为 Omni 模型dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
""" """
self.db = db self.db = db
self.provider = provider.lower() self.api_config = api_config
self.api_key = api_key if self.api_config is not None:
self.model_api_key = api_config.api_key
self.provider = api_config.provider.lower()
self.is_omni = api_config.is_omni
self.capability = api_config.capability
self.audio_api_key = audio_api_key
self.enable_audio_transcription = enable_audio_transcription self.enable_audio_transcription = enable_audio_transcription
self.is_omni = is_omni
async def process_files( async def process_files(
self, self,
files: Optional[List[FileInput]] end_user_id: uuid.UUID | str,
files: Optional[List[FileInput]],
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
处理文件列表,返回 LLM 可用的格式 处理文件列表,返回 LLM 可用的格式
Args: Args:
end_user_id: 用户ID
files: 文件输入列表 files: 文件输入列表
Returns: Returns:
@@ -319,6 +346,8 @@ class MultimodalService:
""" """
if not files: if not files:
return [] return []
if isinstance(end_user_id, uuid.UUID):
end_user_id = str(end_user_id)
# 获取对应的策略 # 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式 # dashscope 的 omni 模型使用 OpenAI 兼容格式
@@ -333,19 +362,25 @@ class MultimodalService:
result = [] result = []
for idx, file in enumerate(files): for idx, file in enumerate(files):
strategy = strategy_class(file) strategy = strategy_class(file)
if not file.url:
file.url = await self.get_file_url(file)
try: try:
if file.type == FileType.IMAGE: if file.type == FileType.IMAGE and "vision" in self.capability:
content = await self._process_image(file, strategy) content = await self._process_image(file, strategy)
result.append(content) result.append(content)
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.DOCUMENT: elif file.type == FileType.DOCUMENT:
content = await self._process_document(file, strategy) content = await self._process_document(file, strategy)
result.append(content) result.append(content)
elif file.type == FileType.AUDIO: self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.AUDIO and "audio" in self.capability:
content = await self._process_audio(file, strategy) content = await self._process_audio(file, strategy)
result.append(content) result.append(content)
elif file.type == FileType.VIDEO: self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.VIDEO and "video" in self.capability:
content = await self._process_video(file, strategy) content = await self._process_video(file, strategy)
result.append(content) result.append(content)
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
else: else:
logger.warning(f"不支持的文件类型: {file.type}") logger.warning(f"不支持的文件类型: {file.type}")
except Exception as e: except Exception as e:
@@ -355,7 +390,8 @@ class MultimodalService:
"file_index": idx, "file_index": idx,
"file_type": file.type, "file_type": file.type,
"error": str(e) "error": str(e)
} },
exc_info=True
) )
# 继续处理其他文件,不中断整个流程 # 继续处理其他文件,不中断整个流程
result.append({ result.append({
@@ -366,6 +402,17 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}") logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result return result
def write_perceptual_memory(
self,
end_user_id: str,
file_type: str,
file_url: str,
file_message: dict
):
"""写入感知记忆"""
if end_user_id and self.api_config:
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]: async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
""" """
处理图片文件 处理图片文件
@@ -387,43 +434,6 @@ class MultimodalService:
"text": f"[图片处理失败: {str(e)}]" "text": f"[图片处理失败: {str(e)}]"
} }
@staticmethod
async def _download_and_encode_image(url: str) -> tuple[str, str]:
"""
下载图片并转换为 base64
Args:
url: 图片 URL
Returns:
tuple: (base64_data, media_type)
"""
from mimetypes import guess_type
# 下载图片
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
# 获取图片数据
image_data = response.content
# 确定 media type
content_type = response.headers.get("content-type")
if content_type and content_type.startswith("image/"):
media_type = content_type
else:
# 从 URL 推断
guessed_type, _ = guess_type(url)
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
# 转换为 base64
base64_data = base64.b64encode(image_data).decode("utf-8")
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return base64_data, media_type
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]: async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
""" """
处理文档文件PDF、Word 等) 处理文档文件PDF、Word 等)
@@ -436,7 +446,6 @@ class MultimodalService:
Dict: 根据 provider 返回不同格式的文档内容 Dict: 根据 provider 返回不同格式的文档内容
""" """
if file.transfer_method == TransferMethod.REMOTE_URL: if file.transfer_method == TransferMethod.REMOTE_URL:
# 远程文档暂不支持提取
return { return {
"type": "text", "type": "text",
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>" "text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
@@ -471,12 +480,12 @@ class MultimodalService:
# 如果启用音频转文本且有 API Key # 如果启用音频转文本且有 API Key
transcription = None transcription = None
if self.enable_audio_transcription and self.api_key: if self.enable_audio_transcription and self.audio_api_key:
logger.info(f"开始音频转文本: {url}") logger.info(f"开始音频转文本: {url}")
if self.provider == "dashscope": if self.provider == "dashscope":
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key) transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.audio_api_key)
elif self.provider == "openai": elif self.provider == "openai":
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key) transcription = await AudioTranscriptionService.transcribe_openai(url, self.audio_api_key)
else: else:
logger.warning(f"Provider {self.provider} 不支持音频转文本") logger.warning(f"Provider {self.provider} 不支持音频转文本")

View File

@@ -0,0 +1,53 @@
{% raw %}You are a professional information extraction system.
Your task is to analyze the provided document content and generate structured metadata.
Extract the following fields:
* **summary**: A concise summary of the document in 24 sentences.
* **keywords**: 510 important keywords or key phrases that best represent the document. This field MUST be a JSON array of strings.
* **topic**: The primary topic of the document expressed as a short phrase (38 words).
* **domain**: The broader knowledge domain or field the document belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.).
STRICT RULES:
1. Output MUST be valid JSON.
2. Do NOT output markdown.
3. Do NOT output explanations.
4. Do NOT output any text before or after the JSON.
5. The JSON MUST contain EXACTLY these four keys:
* summary
* keywords
* topic
* domain{% endraw %}
{% if file_type == 'image' or file_type == 'video' %} * scene {% endif %}
{% if file_type == 'audio' %} * speaker_count {% endif %}
{% if file_type == 'document' %} * section_count
* title
* first_line
{% endif %}
{% raw %}
6. `keywords` MUST be a JSON array of strings.
7. If the document content is insufficient, infer the best possible answer based on context.
8. Ensure the JSON is syntactically correct.
{% endraw %}
9. Output using the language {{ language }}
{% raw %}
Required JSON format:
{
"summary": "string",
"keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"],
"topic": "string",
"domain": "string",
{% endraw %}
{% if file_type == 'image' or file_type == 'video' %} "scene": ["string", "string"] {% endif %}
{% if file_type == 'document' %} "section_count": integer
"title": "string",
"first_line": "string"
{% endif %}
{% if file_type == 'audio' %} "speaker_count": integer {% endif %}
{% raw %}
}
Now analyze the following document and return the JSON result.{% endraw %}

View File

@@ -1,6 +1,5 @@
import asyncio import asyncio
import json import hashlib
import logging
import os import os
import re import re
import shutil import shutil
@@ -11,20 +10,48 @@ from datetime import datetime, timezone
from math import ceil from math import ceil
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from uuid import UUID
import redis import redis
import requests
from redis.exceptions import RedisError from redis.exceptions import RedisError
logger = logging.getLogger(__name__) # Import a unified Celery instance
from app.celery_app import celery_app
from app.core.config import settings
from app.core.logging_config import get_logger
from app.core.rag.crawler.web_crawler import WebCrawler
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
from app.core.rag.integrations.feishu.client import FeishuAPIClient
from app.core.rag.integrations.feishu.models import FileInfo
from app.core.rag.integrations.yuque.client import YuqueAPIClient
from app.core.rag.integrations.yuque.models import YuqueDocInfo
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.prompts.generator import question_proposal
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchVectorFactory,
)
from app.db import get_db, get_db_context
from app.models import Document, File, Knowledge
from app.schemas import document_schema, file_schema
from app.schemas.model_schema import ModelInfo
from app.services.memory_agent_service import MemoryAgentService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.utils.config_utils import resolve_config_id
from app.utils.redis_lock import RedisLock
logger = get_logger(__name__)
# 模块级同步 Redis 连接池,供 Celery 任务共享使用 # 模块级同步 Redis 连接池,供 Celery 任务共享使用
# 连接 CELERY_BACKEND DB与 write_message:last_done 时间戳写入保持一致 # 连接 CELERY_BACKEND DB与 write_message:last_done 时间戳写入保持一致
# 使用连接池而非单例客户端,提供更好的并发性能和自动重连 # 使用连接池而非单例客户端,提供更好的并发性能和自动重连
_sync_redis_pool: redis.ConnectionPool = None _sync_redis_pool: redis.ConnectionPool | None = None
def _get_or_create_redis_pool() -> redis.ConnectionPool:
def _get_or_create_redis_pool() -> redis.ConnectionPool | None:
"""获取或创建 Redis 连接池(懒初始化)""" """获取或创建 Redis 连接池(懒初始化)"""
global _sync_redis_pool global _sync_redis_pool
if _sync_redis_pool is None: if _sync_redis_pool is None:
@@ -47,6 +74,7 @@ def _get_or_create_redis_pool() -> redis.ConnectionPool:
return None return None
return _sync_redis_pool return _sync_redis_pool
def get_sync_redis_client() -> Optional[redis.StrictRedis]: def get_sync_redis_client() -> Optional[redis.StrictRedis]:
"""获取同步 Redis 客户端(使用连接池) """获取同步 Redis 客户端(使用连接池)
@@ -60,7 +88,7 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
pool = _get_or_create_redis_pool() pool = _get_or_create_redis_pool()
if pool is None: if pool is None:
return None return None
client = redis.StrictRedis(connection_pool=pool) client = redis.StrictRedis(connection_pool=pool)
# 验证连接可用性 # 验证连接可用性
client.ping() client.ping()
@@ -72,32 +100,18 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True) logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True)
return None return None
# Import a unified Celery instance
from app.celery_app import celery_app def set_asyncio_event_loop():
from app.core.config import settings """Set the asyncio event loop for the current thread."""
from app.core.rag.crawler.web_crawler import WebCrawler try:
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb loop = asyncio.get_event_loop()
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache if loop.is_closed():
from app.core.rag.integrations.feishu.client import FeishuAPIClient loop = asyncio.new_event_loop()
from app.core.rag.integrations.feishu.models import FileInfo asyncio.set_event_loop(loop)
from app.core.rag.integrations.yuque.client import YuqueAPIClient except RuntimeError:
from app.core.rag.integrations.yuque.models import YuqueDocInfo loop = asyncio.new_event_loop()
from app.core.rag.llm.chat_model import Base asyncio.set_event_loop(loop)
from app.core.rag.llm.cv_model import QWenCV return loop
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.prompts.generator import question_proposal
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchVectorFactory,
)
from app.db import get_db, get_db_context
from app.models.document_model import Document
from app.models.file_model import File
from app.models.knowledge_model import Knowledge
from app.schemas import document_schema, file_schema
from app.services.memory_agent_service import MemoryAgentService
from app.utils.config_utils import resolve_config_id
@celery_app.task(name="tasks.process_item") @celery_app.task(name="tasks.process_item")
@@ -294,9 +308,18 @@ def parse_document(file_path: str, document_id: uuid.UUID):
vector_size = len(vts[0]) vector_size = len(vts[0])
init_graphrag(task, vector_size) init_graphrag(task, vector_size)
async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service, async def _run(
chat_model, embedding_model, callback, with_resolution: bool = True, row: dict,
with_community: bool = True, ) -> dict: document_ids: list[str],
language: str,
parser_config: dict,
vector_service,
chat_model,
embedding_model,
callback,
with_resolution: bool = True,
with_community: bool = True
) -> dict:
await trio.sleep(5) # Delay for 10 seconds await trio.sleep(5) # Delay for 10 seconds
nonlocal progress_msg # Declare the use of an external progress_msg variable nonlocal progress_msg # Declare the use of an external progress_msg variable
result = await run_graphrag_for_kb( result = await run_graphrag_for_kb(
@@ -329,6 +352,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
with_community=with_community, with_community=with_community,
) )
) )
try: try:
with ThreadPoolExecutor(max_workers=1) as executor: with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(sync_task) future = executor.submit(sync_task)
@@ -448,6 +472,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
with_community=with_community, with_community=with_community,
) )
) )
try: try:
with ThreadPoolExecutor(max_workers=1) as executor: with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(sync_task) future = executor.submit(sync_task)
@@ -1002,29 +1027,21 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
# Log but continue - will fail later with proper error # Log but continue - will fail later with proper error
pass pass
async def _run() -> str: async def _run() -> dict:
with get_db_context() as db: with get_db_context() as db:
service = MemoryAgentService() service = MemoryAgentService()
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, return await service.read_memory(
storage_type, user_rag_memory_id) end_user_id,
message,
history,
search_switch,
actual_config_id, db,
storage_type, user_rag_memory_id
)
try: try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的 # 尝试获取现有事件循环,如果不存在则创建新的
try: loop = set_asyncio_event_loop()
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
@@ -1056,7 +1073,8 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
@celery_app.task(name="app.core.memory.agent.write_message", bind=True) @celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, user_rag_memory_id: str, def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str,
user_rag_memory_id: str,
language: str = "zh") -> Dict[str, Any]: language: str = "zh") -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService. """Celery task to process a write message via MemoryAgentService.
Args: Args:
@@ -1073,10 +1091,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
Raises: Raises:
Exception on failure Exception on failure
""" """
from app.core.logging_config import get_logger
logger = get_logger(__name__)
logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id} (type: {type(config_id).__name__}), storage_type={storage_type}, language={language}") logger.info(
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
f"config_id={config_id} (type: {type(config_id).__name__}), "
f"storage_type={storage_type}, language={language}")
start_time = time.time() start_time = time.time()
# Convert config_id to UUID # Convert config_id to UUID
@@ -1086,13 +1105,14 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
try: try:
with get_db_context() as db: with get_db_context() as db:
actual_config_id = resolve_config_id(config_id, db) actual_config_id = resolve_config_id(config_id, db)
print(100*'-') print(100 * '-')
print(actual_config_id) print(actual_config_id)
print(100*'-') print(100 * '-')
logger.info( logger.info(
f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})") f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
except (ValueError, AttributeError) as e: except (ValueError, AttributeError) as e:
logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") logger.error(
f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
return { return {
"status": "FAILURE", "status": "FAILURE",
"error": f"Invalid config_id format: {config_id} - {str(e)}", "error": f"Invalid config_id format: {config_id} - {str(e)}",
@@ -1116,7 +1136,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
async def _run() -> str: async def _run() -> str:
with get_db_context() as db: with get_db_context() as db:
logger.info( logger.info(
f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
service = MemoryAgentService() service = MemoryAgentService()
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
user_rag_memory_id, language) user_rag_memory_id, language)
@@ -1124,22 +1145,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
return result return result
try: try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的 # 尝试获取现有事件循环,如果不存在则创建新的
try: loop = set_asyncio_event_loop()
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
@@ -1193,28 +1200,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
} }
def reflection_engine() -> None:
"""Empty function placeholder for timed background reflection.
Intentionally left blank; replace with real reflection logic later.
"""
import asyncio
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
asyncio.run(self_reflexion(host_id))
@celery_app.task(name="app.core.memory.agent.reflection.timer")
def reflection_timer_task() -> None:
"""Periodic Celery task that invokes reflection_engine.
Raises an exception on failure.
"""
reflection_engine()
# unused task # unused task
# @celery_app.task(name="app.core.memory.agent.health.check_read_service") # @celery_app.task(name="app.core.memory.agent.health.check_read_service")
# def check_read_service_task() -> Dict[str, str]: # def check_read_service_task() -> Dict[str, str]:
@@ -1368,6 +1353,8 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
"workspace_id": workspace_id, "workspace_id": workspace_id,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
} }
@celery_app.task( @celery_app.task(
name="app.tasks.write_all_workspaces_memory_task", name="app.tasks.write_all_workspaces_memory_task",
bind=True, bind=True,
@@ -1391,15 +1378,12 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
start_time = time.time() start_time = time.time()
async def _run() -> Dict[str, Any]: async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_api_logger
from app.models.app_model import App from app.models.app_model import App
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser
from app.models.workspace_model import Workspace from app.models.workspace_model import Workspace
from app.repositories.memory_increment_repository import write_memory_increment from app.repositories.memory_increment_repository import write_memory_increment
from app.services.memory_storage_service import search_all from app.services.memory_storage_service import search_all
api_logger = get_api_logger()
with get_db_context() as db: with get_db_context() as db:
try: try:
# 获取所有活跃的工作空间 # 获取所有活跃的工作空间
@@ -1408,7 +1392,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
).all() ).all()
if not workspaces: if not workspaces:
api_logger.warning("没有找到活跃的工作空间") logger.warning("没有找到活跃的工作空间")
return { return {
"status": "SUCCESS", "status": "SUCCESS",
"message": "没有找到活跃的工作空间", "message": "没有找到活跃的工作空间",
@@ -1416,13 +1400,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
"workspace_results": [] "workspace_results": []
} }
api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量") logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
all_workspace_results = [] all_workspace_results = []
# 遍历每个工作空间 # 遍历每个工作空间
for workspace in workspaces: for workspace in workspaces:
workspace_id = workspace.id workspace_id = workspace.id
api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})") logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
try: try:
# 1. 查询当前workspace下的所有app仅未删除的 # 1. 查询当前workspace下的所有app仅未删除的
@@ -1447,7 +1431,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
"memory_increment_id": str(memory_increment.id), "memory_increment_id": str(memory_increment.id),
"created_at": memory_increment.created_at.isoformat(), "created_at": memory_increment.created_at.isoformat(),
}) })
api_logger.info(f"工作空间 {workspace.name} 没有应用记录总量为0") logger.info(f"工作空间 {workspace.name} 没有应用记录总量为0")
continue continue
# 2. 查询所有app下的end_user_id去重 # 2. 查询所有app下的end_user_id去重
@@ -1472,7 +1456,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
}) })
except Exception as e: except Exception as e:
# 记录单个用户查询失败,但继续处理其他用户 # 记录单个用户查询失败,但继续处理其他用户
api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}") logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
end_user_details.append({ end_user_details.append({
"end_user_id": str(end_user_id), "end_user_id": str(end_user_id),
"total": 0, "total": 0,
@@ -1496,13 +1480,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
"created_at": memory_increment.created_at.isoformat(), "created_at": memory_increment.created_at.isoformat(),
}) })
api_logger.info( logger.info(
f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}" f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}"
) )
except Exception as e: except Exception as e:
db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间 db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间
api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}") logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
all_workspace_results.append({ all_workspace_results.append({
"workspace_id": str(workspace_id), "workspace_id": str(workspace_id),
"workspace_name": workspace.name, "workspace_name": workspace.name,
@@ -1525,7 +1509,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
} }
except Exception as e: except Exception as e:
api_logger.error(f"记忆增量统计任务执行失败: {str(e)}") logger.error(f"记忆增量统计任务执行失败: {str(e)}")
return { return {
"status": "FAILURE", "status": "FAILURE",
"error": str(e), "error": str(e),
@@ -1534,22 +1518,8 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
} }
try: try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的 # 尝试获取现有事件循环,如果不存在则创建新的
try: loop = set_asyncio_event_loop()
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
@@ -1597,11 +1567,9 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
start_time = time.time() start_time = time.time()
async def _run() -> Dict[str, Any]: async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
from app.services.user_memory_service import UserMemoryService from app.services.user_memory_service import UserMemoryService
logger = get_logger(__name__)
logger.info("开始执行记忆缓存重新生成定时任务") logger.info("开始执行记忆缓存重新生成定时任务")
service = UserMemoryService() service = UserMemoryService()
@@ -1734,22 +1702,8 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
} }
try: try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的 # 尝试获取现有事件循环,如果不存在则创建新的
try: loop = set_asyncio_event_loop()
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
@@ -1785,15 +1739,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
start_time = time.time() start_time = time.time()
async def _run() -> Dict[str, Any]: async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_api_logger
from app.models.workspace_model import Workspace from app.models.workspace_model import Workspace
from app.services.memory_reflection_service import ( from app.services.memory_reflection_service import (
MemoryReflectionService, MemoryReflectionService,
WorkspaceAppService, WorkspaceAppService,
) )
api_logger = get_api_logger()
with get_db_context() as db: with get_db_context() as db:
try: try:
# 获取所有工作空间 # 获取所有工作空间
@@ -1812,7 +1763,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
# 遍历每个工作空间 # 遍历每个工作空间
for workspace in workspaces: for workspace in workspaces:
workspace_id = workspace.id workspace_id = workspace.id
api_logger.info(f"开始处理工作空间反思workspace_id: {workspace_id}") logger.info(f"开始处理工作空间反思workspace_id: {workspace_id}")
try: try:
reflection_service = MemoryReflectionService(db) reflection_service = MemoryReflectionService(db)
@@ -1824,7 +1775,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
workspace_reflection_results = [] workspace_reflection_results = []
for data in result['apps_detailed_info']: for data in result['apps_detailed_info']:
if data['memory_configs'] == []: if not data['memory_configs']:
continue continue
releases = data['releases'] releases = data['releases']
@@ -1835,7 +1786,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str( if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str(
user['app_id']): user['app_id']):
# 调用反思服务 # 调用反思服务
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config['config_id']}") logger.info(f"为用户 {user['id']} 启动反思config_id: {config['config_id']}")
reflection_result = await reflection_service.start_reflection_from_data( reflection_result = await reflection_service.start_reflection_from_data(
config_data=config, config_data=config,
@@ -1855,12 +1806,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
"reflection_results": workspace_reflection_results "reflection_results": workspace_reflection_results
}) })
api_logger.info( logger.info(
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
except Exception as e: except Exception as e:
db.rollback() # Rollback failed transaction to allow next query db.rollback() # Rollback failed transaction to allow next query
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
all_reflection_results.append({ all_reflection_results.append({
"workspace_id": str(workspace_id), "workspace_id": str(workspace_id),
"error": str(e), "error": str(e),
@@ -1879,7 +1830,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
} }
except Exception as e: except Exception as e:
api_logger.error(f"工作空间反思任务执行失败: {str(e)}") logger.error(f"工作空间反思任务执行失败: {str(e)}")
return { return {
"status": "FAILURE", "status": "FAILURE",
"error": str(e), "error": str(e),
@@ -1888,22 +1839,8 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
} }
try: try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的 # 尝试获取现有事件循环,如果不存在则创建新的
try: loop = set_asyncio_event_loop()
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
@@ -1944,18 +1881,16 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
start_time = time.time() start_time = time.time()
async def _run() -> Dict[str, Any]: async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_api_logger
from app.services.memory_forget_service import MemoryForgetService from app.services.memory_forget_service import MemoryForgetService
api_logger = get_api_logger()
with get_db_context() as db: with get_db_context() as db:
try: try:
api_logger.info(f"开始执行遗忘周期定时任务config_id: {config_id}") logger.info(f"开始执行遗忘周期定时任务config_id: {config_id}")
forget_service = MemoryForgetService() forget_service = MemoryForgetService()
# 运行遗忘周期 # 运行遗忘周期
# FIXME: MemeoryForgetService
report = await forget_service.trigger_forgetting( report = await forget_service.trigger_forgetting(
db=db, db=db,
end_user_id=None, # 处理所有组 end_user_id=None, # 处理所有组
@@ -1964,7 +1899,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
duration = time.time() - start_time duration = time.time() - start_time
api_logger.info( logger.info(
f"遗忘周期定时任务完成: " f"遗忘周期定时任务完成: "
f"融合 {report['merged_count']} 对节点, " f"融合 {report['merged_count']} 对节点, "
f"失败 {report['failed_count']} 对, " f"失败 {report['failed_count']} 对, "
@@ -1980,7 +1915,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
except Exception as e: except Exception as e:
duration = time.time() - start_time duration = time.time() - start_time
api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True) logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
return { return {
"status": "FAILED", "status": "FAILED",
@@ -1997,6 +1932,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
finally: finally:
loop.close() loop.close()
# ============================================================================= # =============================================================================
# Long-term Memory Storage Tasks (Batched Write Strategies) # Long-term Memory Storage Tasks (Batched Write Strategies)
# ============================================================================= # =============================================================================
@@ -2222,9 +2158,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
start_time = time.time() start_time = time.time()
async def _run() -> Dict[str, Any]: async def _run() -> Dict[str, Any]:
from sqlalchemy import func, select from sqlalchemy import select
from app.core.logging_config import get_logger
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
from app.repositories.implicit_emotions_storage_repository import ( from app.repositories.implicit_emotions_storage_repository import (
ImplicitEmotionsStorageRepository, ImplicitEmotionsStorageRepository,
@@ -2233,7 +2168,6 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
from app.services.emotion_analytics_service import EmotionAnalyticsService from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.services.implicit_memory_service import ImplicitMemoryService from app.services.implicit_memory_service import ImplicitMemoryService
logger = get_logger(__name__)
logger.info("开始执行隐性记忆和情绪数据更新定时任务") logger.info("开始执行隐性记忆和情绪数据更新定时任务")
total_users = 0 total_users = 0
@@ -2267,7 +2201,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
for end_user_id in refresh_iter: for end_user_id in refresh_iter:
logger.info(f"开始处理用户: {end_user_id}") logger.info(f"开始处理用户: {end_user_id}")
user_start_time = time.time() user_start_time = time.time()
implicit_success = False implicit_success = False
emotion_success = False emotion_success = False
errors = [] errors = []
@@ -2318,7 +2252,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
failed += 1 failed += 1
user_elapsed = time.time() - user_start_time user_elapsed = time.time() - user_start_time
# 记录用户处理结果 # 记录用户处理结果
user_result = { user_result = {
"end_user_id": end_user_id, "end_user_id": end_user_id,
@@ -2460,22 +2394,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
} }
try: try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的 # 尝试获取现有事件循环,如果不存在则创建新的
try: loop = set_asyncio_event_loop()
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
@@ -2521,14 +2441,12 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
start_time = time.time() start_time = time.time()
async def _run() -> Dict[str, Any]: async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.repositories.implicit_emotions_storage_repository import ( from app.repositories.implicit_emotions_storage_repository import (
ImplicitEmotionsStorageRepository, ImplicitEmotionsStorageRepository,
) )
from app.services.emotion_analytics_service import EmotionAnalyticsService from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.services.implicit_memory_service import ImplicitMemoryService from app.services.implicit_memory_service import ImplicitMemoryService
logger = get_logger(__name__)
logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}") logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")
initialized = 0 initialized = 0
@@ -2587,20 +2505,7 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
} }
try: try:
try: loop = set_asyncio_event_loop()
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time result["elapsed_time"] = time.time() - start_time
@@ -2633,6 +2538,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
默认生成中文zh兴趣分布数据。 默认生成中文zh兴趣分布数据。
Args: Args:
self: task object
end_user_ids: 需要检查的用户ID列表 end_user_ids: 需要检查的用户ID列表
Returns: Returns:
@@ -2641,11 +2547,9 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
start_time = time.time() start_time = time.time()
async def _run() -> Dict[str, Any]: async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
logger = get_logger(__name__)
logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}") logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")
initialized = 0 initialized = 0
@@ -2694,20 +2598,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
} }
try: try:
try: loop = set_asyncio_event_loop()
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time result["elapsed_time"] = time.time() - start_time
@@ -2720,3 +2611,54 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
"elapsed_time": time.time() - start_time, "elapsed_time": time.time() - start_time,
"task_id": self.request.id, "task_id": self.request.id,
} }
@celery_app.task(
name="app.tasks.write_perceptual_memory",
bind=True,
ignore_result=True,
max_retries=0,
acks_late=False,
time_limit=3600,
soft_time_limit=3300,
)
def write_perceptual_memory(
self,
end_user_id: str,
model_api_config: dict,
file_type: str,
file_url: str,
file_message: dict
):
"""
Write perceptual memory for a user into PostgreSQL and Neo4j.
This task generates or updates the user's perceptual memory
in the backend databases. It is intended to be executed asynchronously
via Celery.
Args:
end_user_id (uuid.UUID): The unique identifier of the end user.
model_api_config (ModelInfo): API configuration for the model
used to generate perceptual memory.
file_type (str): The file type
file_url (url): The url of file
file_message (dict): The file message containing details about the file
to be processed.
Returns:
None
"""
file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
set_asyncio_event_loop()
with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
model_info = ModelInfo(**model_api_config)
with get_db_context() as db:
memory_perceptual_service = MemoryPerceptualService(db)
return asyncio.run(memory_perceptual_service.generate_perceptual_memory(
end_user_id,
model_info,
file_type,
file_url,
file_message,
))

View File

@@ -0,0 +1,61 @@
import redis
import uuid
import time
UNLOCK_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
"""
class RedisLock:
def __init__(
self,
key: str,
redis_client: redis.StrictRedis,
expire: int = 60,
retry_interval: float = 0.1,
timeout: float = 30
):
self.key = key
self.expire = expire
self.value = str(uuid.uuid4())
self._locked = False
self.retry_interval = retry_interval
self.timeout = timeout
self.redis_client = redis_client
def acquire(self) -> bool:
start = time.time()
while True:
ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True)
if ok:
self._locked = True
return True
if time.time() - start >= self.timeout:
return False
time.sleep(self.retry_interval)
def release(self):
if not self._locked:
return
self.redis_client.eval(
UNLOCK_SCRIPT,
1,
self.key,
self.value
)
self._locked = False
def __enter__(self):
ok = self.acquire()
if not ok:
raise RuntimeError(f"Get redis lock timeout: {self.key}")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.release()