Merge pull request #551 from SuanmoSuanyangTechnology/feature/multimodel_file
feat(multimodel): support multimodal memory display and improve code style
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio import ConnectionPool
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# 设置日志记录器
|
||||
|
||||
@@ -62,7 +62,7 @@ celery_app.conf.update(
|
||||
task_serializer='json',
|
||||
accept_content=['json'],
|
||||
result_serializer='json',
|
||||
|
||||
|
||||
# 时区
|
||||
timezone='Asia/Shanghai',
|
||||
enable_utc=False,
|
||||
@@ -70,43 +70,44 @@ celery_app.conf.update(
|
||||
# 任务追踪
|
||||
task_track_started=True,
|
||||
task_ignore_result=False,
|
||||
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=3600, # 60分钟硬超时
|
||||
task_soft_time_limit=3000, # 50分钟软超时
|
||||
|
||||
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
|
||||
|
||||
# 结果过期时间
|
||||
result_expires=3600, # 结果保存1小时
|
||||
|
||||
|
||||
# 任务确认设置
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_disable_rate_limits=True,
|
||||
|
||||
|
||||
# FLower setting
|
||||
worker_send_task_events=True,
|
||||
task_send_sent_event=True,
|
||||
|
||||
|
||||
# task routing
|
||||
task_routes={
|
||||
# Memory tasks → memory_tasks queue (threads worker)
|
||||
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
@@ -131,7 +132,7 @@ implicit_emotions_update_schedule = crontab(
|
||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||
)
|
||||
|
||||
#构建定时任务配置
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
|
||||
@@ -13,4 +13,4 @@ logger.info("Celery worker logging initialized")
|
||||
# 导入任务模块以注册任务
|
||||
import app.tasks
|
||||
|
||||
__all__ = ['celery_app']
|
||||
__all__ = ['celery_app']
|
||||
|
||||
@@ -808,6 +808,15 @@ async def draft_run_compare(
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
other_id=str(current_user.id),
|
||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
from sqlalchemy import select
|
||||
from app.models import AgentConfig
|
||||
@@ -853,6 +862,8 @@ async def draft_run_compare(
|
||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||
})
|
||||
|
||||
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
@@ -864,7 +875,7 @@ async def draft_run_compare(
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
@@ -895,7 +906,7 @@ async def draft_run_compare(
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
|
||||
@@ -4,13 +4,13 @@ from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||
def content_input_node(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Start node - Extract content and maintain state information
|
||||
|
||||
|
||||
Extracts the content from the first message in the state and returns it
|
||||
as the data field while preserving all other state information.
|
||||
|
||||
|
||||
Args:
|
||||
state: ReadState containing messages and other state data
|
||||
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with extracted content in data field
|
||||
"""
|
||||
@@ -19,19 +19,20 @@ def content_input_node(state: ReadState) -> ReadState:
|
||||
# Return content and maintain all state information
|
||||
return {"data": content}
|
||||
|
||||
|
||||
def content_input_write(state: WriteState) -> WriteState:
|
||||
"""
|
||||
Start node - Extract content and maintain state information for write operations
|
||||
|
||||
|
||||
Extracts the content from the first message in the state for write operations.
|
||||
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages and other state data
|
||||
|
||||
|
||||
Returns:
|
||||
WriteState: Updated state with extracted content in data field
|
||||
"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# Return content and maintain all state information
|
||||
return {"data": content}
|
||||
return {"data": content}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
counter = COUNTState(limit=3)
|
||||
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
|
||||
|
||||
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
status=state.get('verify', '')['status']
|
||||
status = state.get('verify', '')['status']
|
||||
# loop_count = counter.get_total()
|
||||
if "success" in status:
|
||||
# counter.reset()
|
||||
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
|
||||
# if loop_count < 2: # Maximum loop count is 3
|
||||
# return "content_input"
|
||||
# else:
|
||||
# counter.reset()
|
||||
# counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Add default return value to avoid returning None
|
||||
|
||||
@@ -2,32 +2,32 @@ import json
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context, get_db
|
||||
from app.db import get_db_context
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
"""
|
||||
Write messages to RAG storage system
|
||||
|
||||
|
||||
Combines user and AI messages into a single string format and stores them
|
||||
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
|
||||
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for the conversation
|
||||
user_message: User's input message content
|
||||
@@ -38,14 +38,24 @@ async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||
actual_config_id, long_term_messages=[]):
|
||||
|
||||
|
||||
async def write(
|
||||
storage_type,
|
||||
end_user_id,
|
||||
user_message,
|
||||
ai_message,
|
||||
user_rag_memory_id,
|
||||
actual_end_user_id,
|
||||
actual_config_id,
|
||||
long_term_messages=None
|
||||
):
|
||||
"""
|
||||
Write memory with structured message support
|
||||
|
||||
|
||||
Handles memory writing operations for different storage types (Neo4j/RAG).
|
||||
Supports both individual message pairs and batch long-term message processing.
|
||||
|
||||
|
||||
Args:
|
||||
storage_type: Storage type identifier ("neo4j" or "rag")
|
||||
end_user_id: Terminal user identifier
|
||||
@@ -55,7 +65,7 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
|
||||
actual_end_user_id: Actual user identifier for storage
|
||||
actual_config_id: Configuration identifier
|
||||
long_term_messages: Optional list of structured messages for batch processing
|
||||
|
||||
|
||||
Logic explanation:
|
||||
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
|
||||
- Neo4j mode: Uses structured message lists
|
||||
@@ -64,8 +74,9 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
|
||||
3. Each message is converted to independent Chunk, preserving speaker field
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
if long_term_messages is None:
|
||||
long_term_messages = []
|
||||
with get_db_context() as db:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j mode: Use structured message lists
|
||||
structured_messages = []
|
||||
@@ -105,17 +116,16 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||
|
||||
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
|
||||
"""
|
||||
Save long-term memory data to database
|
||||
|
||||
|
||||
Handles the storage of long-term memory data based on different strategies
|
||||
(chunk-based or aggregate-based) and manages the transition from short-term
|
||||
to long-term memory storage.
|
||||
|
||||
|
||||
Args:
|
||||
long_term_messages: Long-term message data to be saved
|
||||
actual_config_id: Configuration identifier for memory settings
|
||||
@@ -126,13 +136,12 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data)==scope:
|
||||
if len(chunk_data) == scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
@@ -142,22 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
|
||||
"""Window-based dialogue processing"""
|
||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
|
||||
|
||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||
"""
|
||||
Process dialogue based on window size and write to Neo4j
|
||||
|
||||
|
||||
Manages conversation data based on a sliding window approach. When the window
|
||||
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
|
||||
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
langchain_messages: Original message data list
|
||||
scope: Window size determining when to trigger long-term storage
|
||||
"""
|
||||
scope=scope
|
||||
scope = scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
@@ -174,42 +184,53 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
config_id, formatted_messages)
|
||||
|
||||
await write(
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||
end_user_id,
|
||||
"",
|
||||
"",
|
||||
None,
|
||||
end_user_id,
|
||||
config_id,
|
||||
formatted_messages
|
||||
)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""Time-based memory processing"""
|
||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
|
||||
|
||||
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||
"""
|
||||
Process memory storage based on time intervals and write to Neo4j
|
||||
|
||||
|
||||
Retrieves Redis data based on time intervals and writes it to Neo4j for
|
||||
long-term storage. This function handles time-based memory consolidation.
|
||||
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
time: Time interval for data retrieval
|
||||
"""
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
format_messages = (long_time_data)
|
||||
messages=[]
|
||||
memory_config=memory_config.config_id
|
||||
format_messages = long_time_data
|
||||
messages = []
|
||||
memory_config = memory_config.config_id
|
||||
for i in format_messages:
|
||||
message=json.loads(i['Query'])
|
||||
messages+= message
|
||||
if format_messages!=[]:
|
||||
message = json.loads(i['Query'])
|
||||
messages += message
|
||||
if format_messages:
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
memory_config, messages)
|
||||
"""Aggregation judgment processing"""
|
||||
|
||||
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||
"""
|
||||
Aggregation judgment function: determine if input sentence and historical messages describe the same event
|
||||
|
||||
|
||||
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
|
||||
historical data or stored as separate events. This helps optimize memory storage and retrieval.
|
||||
|
||||
@@ -217,11 +238,11 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
end_user_id: Terminal user identifier
|
||||
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: Memory configuration object containing LLM settings
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Aggregation judgment result containing is_same_event flag and processed output
|
||||
"""
|
||||
|
||||
history = None
|
||||
try:
|
||||
# 1. Get historical session data (using new method)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
@@ -255,7 +276,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
output_value = structured.output
|
||||
if isinstance(output_value, list):
|
||||
output_value = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in output_value
|
||||
]
|
||||
|
||||
@@ -268,16 +289,16 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||
memory_config.config_id, output_value)
|
||||
return result_dict
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
"output": ori_messages,
|
||||
"messages": ori_messages,
|
||||
"history": history if 'history' in locals() else [],
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,26 +2,25 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from app.core.memory.src.search import (
|
||||
search_by_temporal,
|
||||
search_by_keyword_temporal,
|
||||
)
|
||||
|
||||
|
||||
def extract_tool_message_content(response):
|
||||
"""
|
||||
Extract ToolMessage content and tool names from agent response
|
||||
|
||||
|
||||
Parses agent response messages to extract tool execution results and metadata.
|
||||
Handles JSON parsing and provides structured access to tool output data.
|
||||
|
||||
|
||||
Args:
|
||||
response: Agent response dictionary containing messages
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
|
||||
- tool_name: Name of the executed tool
|
||||
@@ -61,10 +60,10 @@ def extract_tool_message_content(response):
|
||||
class TimeRetrievalInput(BaseModel):
|
||||
"""
|
||||
Input schema for time retrieval tool
|
||||
|
||||
|
||||
Defines the expected input parameters for time-based retrieval operations.
|
||||
Used for validation and documentation of tool parameters.
|
||||
|
||||
|
||||
Attributes:
|
||||
context: User input query content for search
|
||||
end_user_id: Group ID for filtering search results, defaults to test user
|
||||
@@ -72,25 +71,26 @@ class TimeRetrievalInput(BaseModel):
|
||||
context: str = Field(description="用户输入的查询内容")
|
||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
|
||||
def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
|
||||
|
||||
|
||||
Creates a specialized time-based retrieval tool that searches for statements within
|
||||
specified time ranges. Includes field cleaning functionality to remove unnecessary
|
||||
metadata from search results.
|
||||
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for scoping search results
|
||||
|
||||
|
||||
Returns:
|
||||
function: Configured TimeRetrievalWithGroupId tool function
|
||||
"""
|
||||
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
"""
|
||||
Clean unnecessary fields from temporal search results and modify structure
|
||||
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption and
|
||||
restructures the response format for better usability.
|
||||
|
||||
@@ -102,10 +102,10 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
# List of fields to filter out
|
||||
fields_to_remove = {
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'valid_at', 'invalid_at', 'statement_ids'
|
||||
}
|
||||
|
||||
|
||||
if isinstance(data, dict):
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
@@ -126,15 +126,16 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return [clean_temporal_result_fields(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
|
||||
end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
|
||||
Performs time-based search operations with automatic metadata filtering. Supports
|
||||
flexible date range specification and provides clean, user-friendly output.
|
||||
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query context content
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
@@ -142,10 +143,11 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
- end_user_id_param: Group ID (optional, overrides default group ID)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results with temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# Use passed parameters or default values
|
||||
actual_end_user_id = end_user_id_param or end_user_id
|
||||
@@ -167,18 +169,19 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
cleaned_results = results
|
||||
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
@tool
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
|
||||
clean_output: bool = True) -> str:
|
||||
"""
|
||||
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
|
||||
Performs combined keyword and temporal search operations with automatic metadata
|
||||
filtering. Provides more targeted search results by combining content relevance
|
||||
with time-based filtering.
|
||||
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query content for keyword matching
|
||||
- days_back: Number of days to search backwards, default 7 days
|
||||
@@ -186,10 +189,11 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results combining keyword and temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
||||
@@ -212,29 +216,29 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
|
||||
return TimeRetrievalWithGroupId
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"""
|
||||
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
|
||||
|
||||
|
||||
Creates an advanced hybrid search tool that combines multiple search strategies
|
||||
(keyword, vector, hybrid) with automatic result cleaning and formatting.
|
||||
|
||||
Args:
|
||||
memory_config: Memory configuration object containing LLM and search settings
|
||||
**search_params: Search parameters including end_user_id, limit, include, etc.
|
||||
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearch tool function with async capabilities
|
||||
"""
|
||||
|
||||
|
||||
def clean_result_fields(data):
|
||||
"""
|
||||
Recursively clean unnecessary fields from results
|
||||
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption,
|
||||
improving readability and reducing response size.
|
||||
|
||||
@@ -247,11 +251,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
# List of fields to filter out
|
||||
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
||||
}
|
||||
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Clean dictionary
|
||||
cleaned = {}
|
||||
@@ -265,7 +269,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
else:
|
||||
# Return other types directly
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
async def HybridSearch(
|
||||
context: str,
|
||||
@@ -279,7 +283,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
) -> str:
|
||||
"""
|
||||
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
|
||||
|
||||
|
||||
Provides comprehensive search capabilities combining multiple search strategies
|
||||
with intelligent result ranking and automatic metadata filtering for clean output.
|
||||
|
||||
@@ -292,7 +296,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
use_forgetting_rerank: Whether to use forgetting-based reranking
|
||||
use_llm_rerank: Whether to use LLM-based reranking
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
|
||||
Returns:
|
||||
str: JSON formatted comprehensive search results
|
||||
"""
|
||||
@@ -329,9 +333,9 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"search_type": search_type,
|
||||
"results": cleaned_results
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_result = {
|
||||
"error": f"混合检索失败: {str(e)}",
|
||||
@@ -340,35 +344,36 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
return json.dumps(error_result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return HybridSearch
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"""
|
||||
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
|
||||
|
||||
|
||||
Creates a synchronous wrapper around the async hybrid search functionality,
|
||||
making it compatible with synchronous tool execution environments.
|
||||
|
||||
Args:
|
||||
memory_config: Memory configuration object containing search settings
|
||||
**search_params: Search parameters for configuration
|
||||
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearchSync tool function
|
||||
"""
|
||||
|
||||
@tool
|
||||
def HybridSearchSync(
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
|
||||
Provides the same hybrid search capabilities as the async version but in a
|
||||
synchronous execution context. Automatically handles async-to-sync conversion.
|
||||
|
||||
@@ -378,10 +383,11 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# Create async tool and execute
|
||||
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
||||
@@ -392,7 +398,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"end_user_id": end_user_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
async def format_parsing(messages: list,type:str='string'):
|
||||
|
||||
|
||||
async def format_parsing(messages: list, type: str = 'string'):
|
||||
"""
|
||||
Format and parse message lists into different output types
|
||||
|
||||
|
||||
Processes message lists from storage and converts them into either string format
|
||||
or dictionary format based on the specified type parameter. Handles JSON parsing
|
||||
and role-based message organization.
|
||||
@@ -19,8 +21,8 @@ async def format_parsing(messages: list,type:str='string'):
|
||||
- 'dict': List of dictionaries mapping user messages to AI responses
|
||||
"""
|
||||
result = []
|
||||
user=[]
|
||||
ai=[]
|
||||
user = []
|
||||
ai = []
|
||||
|
||||
for message in messages:
|
||||
hstory_messages = message['messages']
|
||||
@@ -30,37 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
|
||||
role = content['role']
|
||||
content = content['content']
|
||||
if type == "string":
|
||||
if role == 'human' or role=="user":
|
||||
if role == 'human' or role == "user":
|
||||
content = '用户:' + content
|
||||
else:
|
||||
content = 'AI:' + content
|
||||
result.append(content)
|
||||
if type == "dict" :
|
||||
if role == 'human' or role=="user":
|
||||
user.append( content)
|
||||
if type == "dict":
|
||||
if role == 'human' or role == "user":
|
||||
user.append(content)
|
||||
else:
|
||||
ai.append(content)
|
||||
if type == "dict":
|
||||
for key,values in zip(user,ai):
|
||||
result.append({key:values})
|
||||
for key, values in zip(user, ai):
|
||||
result.append({key: values})
|
||||
return result
|
||||
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
"""
|
||||
Parse messages from storage format into user-AI conversation pairs
|
||||
|
||||
|
||||
Extracts and organizes conversation data from stored message format,
|
||||
separating user and AI messages and pairing them for database storage.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List or dictionary containing stored message data with Query fields
|
||||
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing user-AI message pairs for database storage
|
||||
"""
|
||||
user=[]
|
||||
ai=[]
|
||||
database=[]
|
||||
user = []
|
||||
ai = []
|
||||
database = []
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
@@ -72,20 +75,20 @@ async def messages_parse(messages: list | dict):
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content,ai_content):
|
||||
async def agent_chat_messages(user_content, ai_content):
|
||||
"""
|
||||
Create structured chat message format for agent conversations
|
||||
|
||||
|
||||
Formats user and AI content into a standardized message structure suitable
|
||||
for agent processing and storage. Creates role-based message objects.
|
||||
|
||||
|
||||
Args:
|
||||
user_content: User's message content string
|
||||
ai_content: AI's response content string
|
||||
|
||||
|
||||
Returns:
|
||||
list: List of structured message dictionaries with role and content fields
|
||||
"""
|
||||
|
||||
@@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -42,13 +41,15 @@ async def make_write_graph():
|
||||
|
||||
yield graph
|
||||
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
|
||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
||||
end_user_id: str = '', scope: int = 6):
|
||||
"""
|
||||
Handle long-term memory storage with different strategies
|
||||
|
||||
Supports multiple storage strategies including chunk-based, time-based,
|
||||
|
||||
Supports multiple storage strategies including chunk-based, time-based,
|
||||
and aggregate judgment approaches for long-term memory persistence.
|
||||
|
||||
|
||||
Args:
|
||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||
langchain_messages: List of messages to store
|
||||
@@ -56,9 +57,10 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
|
||||
end_user_id: User group identifier
|
||||
scope: Scope parameter for chunk-based storage (default: 6)
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||
write_store.save_session_write(end_user_id, langchain_messages)
|
||||
# 获取数据库会话
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
@@ -66,25 +68,24 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type==AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type==AgentMemory_Long_Term.STRATEGY_TIME:
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||
"""Time-based strategy"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,AgentMemory_Long_Term.TIME_SCOPE)
|
||||
if long_term_type==AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
"""Strategy 3: Aggregate judgment"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
|
||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
|
||||
"""
|
||||
Write long-term memory with different storage types
|
||||
|
||||
|
||||
Handles both RAG-based storage and traditional memory storage approaches.
|
||||
For traditional storage, uses chunk-based strategy with paired user-AI messages.
|
||||
|
||||
|
||||
Args:
|
||||
storage_type: Type of storage (RAG or traditional)
|
||||
end_user_id: User group identifier
|
||||
@@ -95,7 +96,7 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
else:
|
||||
@@ -128,4 +129,4 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
# asyncio.run(main())
|
||||
|
||||
@@ -8,10 +8,11 @@ from langgraph.graph import add_messages
|
||||
|
||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||
|
||||
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
"""
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
"""
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
end_user_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
@@ -20,6 +21,7 @@ class WriteState(TypedDict):
|
||||
data: str
|
||||
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
||||
|
||||
|
||||
class ReadState(TypedDict):
|
||||
"""
|
||||
LangGraph 工作流状态定义
|
||||
@@ -43,18 +45,20 @@ class ReadState(TypedDict):
|
||||
config_id: str
|
||||
data: str # 新增字段用于传递内容
|
||||
spit_data: dict # 新增字段用于传递问题分解结果
|
||||
problem_extension:dict
|
||||
problem_extension: dict
|
||||
storage_type: str
|
||||
user_rag_memory_id: str
|
||||
llm_id: str
|
||||
embedding_id: str
|
||||
memory_config: object # 新增字段用于传递内存配置对象
|
||||
retrieve:dict
|
||||
retrieve: dict
|
||||
RetrieveSummary: dict
|
||||
InputSummary: dict
|
||||
verify: dict
|
||||
SummaryFails: dict
|
||||
summary: dict
|
||||
|
||||
|
||||
class COUNTState:
|
||||
"""
|
||||
工作流对话检索内容计数器
|
||||
@@ -99,6 +103,7 @@ class COUNTState:
|
||||
self.total = 0
|
||||
print("[COUNTState] 已重置为 0")
|
||||
|
||||
|
||||
def deduplicate_entries(entries):
|
||||
seen = set()
|
||||
deduped = []
|
||||
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
|
||||
deduped.append(entry)
|
||||
return deduped
|
||||
|
||||
|
||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||
grouped = defaultdict(list)
|
||||
for item in data:
|
||||
@@ -142,4 +148,4 @@ def convert_extended_question_to_question(data):
|
||||
return [convert_extended_question_to_question(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
return data
|
||||
return data
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Dict, Optional
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
@@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
class TripletExtractor:
|
||||
"""Extracts knowledge triplets and entities from statements using LLM"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"):
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"
|
||||
):
|
||||
"""Initialize the TripletExtractor with an LLM client
|
||||
|
||||
Args:
|
||||
@@ -65,7 +65,8 @@ class TripletExtractor:
|
||||
|
||||
# Create messages for LLM
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "system",
|
||||
"content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "user", "content": prompt_content}
|
||||
]
|
||||
|
||||
@@ -116,7 +117,8 @@ class TripletExtractor:
|
||||
logger.error(f"Error processing statement: {e}", exc_info=True)
|
||||
return TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]:
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[
|
||||
str, TripletExtractionResponse]:
|
||||
"""Extract triplets and entities from statements
|
||||
|
||||
Args:
|
||||
|
||||
@@ -2,15 +2,15 @@ import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
# Setup Jinja2 environment
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
baseline: str = "TIME",
|
||||
memory_verify: bool = False,quality_assessment:bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
memory_verify: bool = False, quality_assessment: bool = False,
|
||||
statement_databasets=None, language_type: str = "zh") -> str:
|
||||
"""
|
||||
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
|
||||
|
||||
@@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
if statement_databasets is None:
|
||||
statement_databasets = []
|
||||
template = prompt_env.get_template("evaluate.jinja2")
|
||||
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
@@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
statement_databasets=None, language_type: str = "zh") -> str:
|
||||
"""
|
||||
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
||||
|
||||
@@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
||||
Returns:
|
||||
Rendered prompt content as a string.
|
||||
"""
|
||||
if statement_databasets is None:
|
||||
statement_databasets = []
|
||||
template = prompt_env.get_template("reflexion.jinja2")
|
||||
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
@@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
||||
json_schema = schema
|
||||
|
||||
rendered_prompt = template.render(data=data, json_schema=json_schema,
|
||||
baseline=baseline,memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets,language_type=language_type)
|
||||
baseline=baseline, memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets, language_type=language_type)
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -1,23 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
|
||||
from langchain_aws import ChatBedrock
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_openai import ChatOpenAI, OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import httpx
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLanguageModel, BaseLLM
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -163,25 +159,17 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
if type == ModelType.LLM:
|
||||
from langchain_openai import OpenAI
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaLLM
|
||||
return OllamaLLM
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
from langchain_aws import ChatBedrock, ChatBedrockConverse
|
||||
|
||||
return ChatBedrock
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
|
||||
from app.schemas import FileInput
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -620,11 +621,12 @@ class BaseNode(ABC):
|
||||
|
||||
@staticmethod
|
||||
async def process_message(
|
||||
provider: str,
|
||||
is_omni: bool,
|
||||
api_config: ModelInfo,
|
||||
content: str | dict | FileObject,
|
||||
end_user_id: str,
|
||||
enable_file=False
|
||||
) -> list | str | None:
|
||||
provider = api_config.provider
|
||||
if isinstance(content, dict):
|
||||
content = FileObject(
|
||||
type=content.get("type"),
|
||||
@@ -643,7 +645,7 @@ class BaseNode(ABC):
|
||||
if content.content_cache.get(provider):
|
||||
return content.content_cache[provider]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||
file_obj = FileInput(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
@@ -653,7 +655,8 @@ class BaseNode(ABC):
|
||||
)
|
||||
file_obj.set_content(content.get_content())
|
||||
message = await multimodel_service.process_files(
|
||||
[file_obj]
|
||||
end_user_id,
|
||||
[file_obj],
|
||||
)
|
||||
content.set_content(file_obj.get_content())
|
||||
if message:
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
@@ -23,6 +23,26 @@ class IfElseNode(BaseNode):
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
result = []
|
||||
for case in self.typed_config.cases:
|
||||
expressions = []
|
||||
for expression in case.expressions:
|
||||
expressions.append({
|
||||
"left": self.get_variable(expression.left, variable_pool, strict=False),
|
||||
"right": expression.right
|
||||
if expression.input_type == ValueInputType.CONSTANT
|
||||
else self.get_variable(expression.right, variable_pool, strict=False),
|
||||
"operator": expression.operator,
|
||||
})
|
||||
result.append({
|
||||
"expressions": expressions,
|
||||
"logical_operator": case.logical_operator,
|
||||
})
|
||||
return {
|
||||
"cases": result
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||
match operator:
|
||||
|
||||
@@ -30,6 +30,12 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
"output": VariableType.ARRAY_STRING
|
||||
}
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
return {
|
||||
"query": self._render_template(self.typed_config.query, variable_pool),
|
||||
"knowledge_bases": [kb_config.model_dump(mode="json") for kb_config in self.typed_config.knowledge_bases],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
||||
"""
|
||||
|
||||
@@ -20,6 +20,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -113,12 +114,15 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 在 Session 关闭前提取所有需要的数据
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
model_info = ModelInfo(
|
||||
model_name=api_config.model_name,
|
||||
model_type=ModelType(config.type),
|
||||
api_key=api_config.api_key,
|
||||
api_base=api_config.api_base,
|
||||
provider=api_config.provider,
|
||||
is_omni=api_config.is_omni,
|
||||
capability=api_config.capability
|
||||
)
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||
@@ -126,17 +130,18 @@ class LLMNode(BaseNode):
|
||||
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
model_name=model_info.model_name,
|
||||
provider=model_info.provider,
|
||||
api_key=model_info.api_key,
|
||||
base_url=model_info.api_base,
|
||||
extra_params=extra_params,
|
||||
is_omni=is_omni
|
||||
is_omni=model_info.is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
type=model_info.model_type
|
||||
)
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
logger.debug(
|
||||
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
@@ -148,35 +153,40 @@ class LLMNode(BaseNode):
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, variable_pool)
|
||||
content = self._render_template(content_template, variable_pool)
|
||||
|
||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(
|
||||
model_info,
|
||||
content,
|
||||
user_id,
|
||||
self.typed_config.vision,
|
||||
)
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||
for file in files.value:
|
||||
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
@@ -190,14 +200,19 @@ class LLMNode(BaseNode):
|
||||
if isinstance(message["content"], list):
|
||||
file_content = []
|
||||
for file in message["content"]:
|
||||
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
history_message.append(
|
||||
{"role": message["role"], "content": file_content}
|
||||
)
|
||||
else:
|
||||
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
|
||||
message["content"] = await self.process_message(
|
||||
model_info,
|
||||
message["content"],
|
||||
user_id,
|
||||
self.typed_config.vision
|
||||
)
|
||||
history_message.append(message)
|
||||
messages = messages[:-1] + history_message + messages[-1:]
|
||||
self.messages = messages
|
||||
@@ -293,7 +308,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
last_meta_data = {}
|
||||
async for chunk in llm.astream(self.messages, stream_usage=True):
|
||||
async for chunk in llm.astream(self.messages):
|
||||
# 提取内容
|
||||
if hasattr(chunk, 'content'):
|
||||
content = self.process_model_output(chunk.content)
|
||||
|
||||
@@ -37,6 +37,14 @@ class ParameterExtractorNode(BaseNode):
|
||||
}
|
||||
return None
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
return {
|
||||
"text": self._render_template(self.typed_config.text, variable_pool),
|
||||
"prompt": self._render_template(self.typed_config.prompt, variable_pool),
|
||||
"params": [param.model_dump(mode="json") for param in self.typed_config.params],
|
||||
"model_id": str(self.typed_config.model_id),
|
||||
}
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
outputs = {}
|
||||
for param in self.typed_config.params:
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from app.db import Base
|
||||
from app.schemas import FileType
|
||||
|
||||
|
||||
class PerceptualType(IntEnum):
|
||||
@@ -15,6 +16,16 @@ class PerceptualType(IntEnum):
|
||||
TEXT = 3
|
||||
CONVERSATION = 4
|
||||
|
||||
@staticmethod
|
||||
def trans_from_file_type(file_type: FileType | str):
|
||||
type_map = {
|
||||
FileType.IMAGE: PerceptualType.VISION,
|
||||
FileType.AUDIO: PerceptualType.AUDIO,
|
||||
FileType.VIDEO: PerceptualType.VISION,
|
||||
FileType.DOCUMENT: PerceptualType.TEXT
|
||||
}
|
||||
return type_map.get(file_type, PerceptualType.TEXT)
|
||||
|
||||
|
||||
class FileStorageService(IntEnum):
|
||||
LOCAL = 1
|
||||
|
||||
@@ -2,7 +2,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from sqlalchemy import and_, desc
|
||||
from sqlalchemy import and_, desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
@@ -127,6 +127,17 @@ class MemoryPerceptualRepository:
|
||||
db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_by_url(
|
||||
self,
|
||||
file_url: str
|
||||
) -> list[MemoryPerceptualModel]:
|
||||
try:
|
||||
stmt = select(MemoryPerceptualModel).where(MemoryPerceptualModel.file_path == file_url)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
except Exception:
|
||||
db_logger.error(f"Failed to query perceptual memories by file_url: file_url={file_url}")
|
||||
raise
|
||||
|
||||
def get_by_type(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -85,7 +84,6 @@ class Semantic(BaseModel):
|
||||
|
||||
|
||||
class Content(BaseModel):
|
||||
summary: str
|
||||
keywords: list[str]
|
||||
topic: str
|
||||
domain: str
|
||||
|
||||
@@ -326,3 +326,14 @@ class ModelBaseQuery(BaseModel):
|
||||
is_official: Optional[bool] = Field(None, description="是否官方模型")
|
||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""模型信息Schema"""
|
||||
model_name: str = Field(..., description="模型名称")
|
||||
provider: str = Field(..., description="模型提供商")
|
||||
api_key: str = Field(..., description="API密钥")
|
||||
api_base: str = Field(..., description="API基础URL")
|
||||
is_omni: bool = Field(default=False, description="是否为omni模型")
|
||||
model_type: ModelType = Field(..., description="模型类型")
|
||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
||||
|
||||
|
||||
@@ -8,25 +8,21 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db
|
||||
from app.models import MultiAgentConfig, AgentConfig
|
||||
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
||||
from app.models import WorkflowConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
|
||||
AgentRunService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.tool_service import ToolService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -126,8 +122,17 @@ class AppChatService:
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
api_key=api_key_obj.api_key,
|
||||
api_base=api_key_obj.api_base,
|
||||
capability=api_key_obj.capability,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 调用 Agent(支持多模态)
|
||||
@@ -266,8 +271,17 @@ class AppChatService:
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
api_key=api_key_obj.api_key,
|
||||
api_base=api_key_obj.api_base,
|
||||
capability=api_key_obj.capability,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 流式调用 Agent(支持多模态)
|
||||
|
||||
@@ -23,9 +23,10 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.models import AgentConfig, ModelConfig, ModelType
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
from app.services.conversation_service import ConversationService
|
||||
@@ -501,9 +502,18 @@ class AgentRunService:
|
||||
processed_files = None
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_config["model_name"],
|
||||
provider=api_key_config["provider"],
|
||||
api_key=api_key_config["api_key"],
|
||||
api_base=api_key_config["api_base"],
|
||||
capability=api_key_config["capability"],
|
||||
is_omni=api_key_config["is_omni"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -704,9 +714,18 @@ class AgentRunService:
|
||||
processed_files = None
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_config["model_name"],
|
||||
provider=api_key_config["provider"],
|
||||
api_key=api_key_config["api_key"],
|
||||
api_base=api_key_config["api_base"],
|
||||
capability=api_key_config["capability"],
|
||||
is_omni=api_key_config["is_omni"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -841,7 +860,8 @@ class AgentRunService:
|
||||
"api_key": api_key.api_key,
|
||||
"api_base": api_key.api_base,
|
||||
"api_key_id": api_key.id,
|
||||
"is_omni": api_key.is_omni
|
||||
"is_omni": api_key.is_omni,
|
||||
"capability": api_key.capability
|
||||
}
|
||||
|
||||
async def _ensure_conversation(
|
||||
|
||||
@@ -274,7 +274,7 @@ class MemoryAgentService:
|
||||
|
||||
Args:
|
||||
end_user_id: Group identifier (also used as end_user_id)
|
||||
message: Message to write
|
||||
messages: Message to write
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
|
||||
@@ -1,19 +1,27 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
import json_repair
|
||||
from jinja2 import Template
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||
from app.schemas import FileType
|
||||
from app.schemas.memory_perceptual_schema import (
|
||||
PerceptualQuerySchema,
|
||||
PerceptualTimelineResponse,
|
||||
PerceptualMemoryItem,
|
||||
AudioModal, Content, VideoModal, TextModal
|
||||
)
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
@@ -99,7 +107,7 @@ class MemoryPerceptualService:
|
||||
"keywords": content.keywords,
|
||||
"topic": content.topic,
|
||||
"domain": content.domain,
|
||||
"created_time": int(memory.created_time.timestamp()*1000),
|
||||
"created_time": int(memory.created_time.timestamp() * 1000),
|
||||
**detail
|
||||
}
|
||||
|
||||
@@ -108,7 +116,8 @@ class MemoryPerceptualService:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}")
|
||||
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
|
||||
exc_info=True)
|
||||
raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
|
||||
BizCode.DB_ERROR)
|
||||
|
||||
@@ -138,7 +147,7 @@ class MemoryPerceptualService:
|
||||
for memory in memories:
|
||||
meta_data = memory.meta_data or {}
|
||||
content = meta_data.get("content", {})
|
||||
|
||||
|
||||
# 安全地提取 content 字段,提供默认值
|
||||
if content:
|
||||
content_obj = Content(**content)
|
||||
@@ -149,7 +158,7 @@ class MemoryPerceptualService:
|
||||
topic = "Unknown"
|
||||
domain = "Unknown"
|
||||
keywords = []
|
||||
|
||||
|
||||
memory_item = PerceptualMemoryItem(
|
||||
id=memory.id,
|
||||
perceptual_type=PerceptualType(memory.perceptual_type),
|
||||
@@ -161,7 +170,7 @@ class MemoryPerceptualService:
|
||||
topic=topic,
|
||||
domain=domain,
|
||||
keywords=keywords,
|
||||
created_time=int(memory.created_time.timestamp()*1000),
|
||||
created_time=int(memory.created_time.timestamp() * 1000),
|
||||
storage_service=FileStorageService(memory.storage_service),
|
||||
)
|
||||
memory_items.append(memory_item)
|
||||
@@ -183,3 +192,98 @@ class MemoryPerceptualService:
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
||||
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
||||
|
||||
async def generate_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
model_config: ModelInfo,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict,
|
||||
):
|
||||
memories = self.repository.get_by_url(file_url)
|
||||
if memories:
|
||||
business_logger.info(f"Perceptual memory already exists: {file_url}")
|
||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||
memory_cache = memories[0]
|
||||
self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||
file_path=memory_cache.file_path,
|
||||
file_name=memory_cache.file_name,
|
||||
file_ext=memory_cache.file_ext,
|
||||
summary=memory_cache.summary,
|
||||
meta_data=memory_cache.meta_data
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
return
|
||||
llm = RedBearLLM(RedBearModelConfig(
|
||||
model_name=model_config.model_name,
|
||||
provider=model_config.provider,
|
||||
api_key=model_config.api_key,
|
||||
base_url=model_config.api_base,
|
||||
is_omni=model_config.is_omni
|
||||
), type=model_config.model_type)
|
||||
try:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
|
||||
except FileNotFoundError:
|
||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||
messages = [
|
||||
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
|
||||
{"role": RoleType.USER.value, "content": [
|
||||
{"type": "text", "text": "Summarize the following file"}, file_message
|
||||
]}
|
||||
]
|
||||
result = await llm.ainvoke(messages)
|
||||
content = json_repair.repair_json(result.content, return_objects=True)
|
||||
path = urlparse(file_url).path
|
||||
filename = os.path.basename(path)
|
||||
filename = unquote(filename)
|
||||
file_ext = os.path.splitext(filename)[1]
|
||||
if not file_ext:
|
||||
if file_type == FileType.AUDIO:
|
||||
file_ext = ".mp3"
|
||||
elif file_type == FileType.VIDEO:
|
||||
file_ext = ".mp4"
|
||||
elif file_type == FileType.DOCUMENT:
|
||||
file_ext = ".txt"
|
||||
elif file_type == FileType.IMAGE:
|
||||
file_ext = ".jpg"
|
||||
filename += file_ext
|
||||
file_content = {
|
||||
"keywords": content.get("keywords", []),
|
||||
"topic": content.get("topic"),
|
||||
"domain": content.get("domain")
|
||||
}
|
||||
if file_type in [FileType.IMAGE, FileType.VIDEO]:
|
||||
file_modalities = {
|
||||
"scene": content.get("scene")
|
||||
}
|
||||
elif file_type in [FileType.DOCUMENT]:
|
||||
file_modalities = {
|
||||
"section_count": content.get("section_count"),
|
||||
"title": content.get("title"),
|
||||
"first_line": content.get("first_line")
|
||||
}
|
||||
else:
|
||||
file_modalities = {
|
||||
"speaker_count": content.get("speaker_count")
|
||||
}
|
||||
self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType.trans_from_file_type(file_type),
|
||||
file_path=file_url,
|
||||
file_name=filename,
|
||||
file_ext=file_ext,
|
||||
summary=content.get('summary'),
|
||||
meta_data={
|
||||
"content": file_content,
|
||||
"modalities": file_modalities
|
||||
}
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
"""
|
||||
import base64
|
||||
import io
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
@@ -23,9 +24,12 @@ from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models import ModelApiKey
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||
from app.tasks import write_perceptual_memory
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -39,6 +43,7 @@ DOC_MIME = [
|
||||
|
||||
class MultimodalFormatStrategy(ABC):
|
||||
"""多模态格式策略基类"""
|
||||
|
||||
def __init__(self, file: FileInput):
|
||||
self.file = file
|
||||
|
||||
@@ -95,7 +100,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
if transcription:
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
|
||||
"text": f"<audio url=\"{url}\">\ntext_transcription:{transcription}\n</audio>"
|
||||
}
|
||||
# 通义千问音频格式:{"type": "audio", "audio": "url"}
|
||||
return {
|
||||
@@ -284,34 +289,56 @@ PROVIDER_STRATEGIES = {
|
||||
|
||||
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
"""
|
||||
Service for handling multimodal file processing.
|
||||
|
||||
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
|
||||
enable_audio_transcription: bool = False, is_omni: bool = False):
|
||||
Attributes:
|
||||
db (Session): Database session.
|
||||
model_api_key (str): API key for the model provider.
|
||||
provider (str): Name of the model provider.
|
||||
is_omni (bool): Indicates whether the model supports full multimodal capability.
|
||||
capability (list): Capability configuration of the model.
|
||||
audio_api_key (str | None): API key used for audio transcription.
|
||||
enable_audio_transcription (bool): Whether audio transcription is enabled.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
api_config: ModelInfo | None = None,
|
||||
audio_api_key: Optional[str] = None,
|
||||
enable_audio_transcription: bool = False,
|
||||
):
|
||||
"""
|
||||
初始化多模态服务
|
||||
|
||||
Initialize the multimodal service.
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider: 模型提供商(dashscope, bedrock, anthropic, openai 等)
|
||||
api_key: API 密钥(用于音频转文本)
|
||||
enable_audio_transcription: 是否启用音频转文本
|
||||
is_omni: 是否为 Omni 模型(dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
|
||||
db (Session): Database session.
|
||||
api_config (ModelApiKey | None): Model API configuration.
|
||||
audio_api_key (str | None): API key for audio transcription.
|
||||
enable_audio_transcription (bool): Enable audio transcription.
|
||||
"""
|
||||
self.db = db
|
||||
self.provider = provider.lower()
|
||||
self.api_key = api_key
|
||||
self.api_config = api_config
|
||||
if self.api_config is not None:
|
||||
self.model_api_key = api_config.api_key
|
||||
self.provider = api_config.provider.lower()
|
||||
self.is_omni = api_config.is_omni
|
||||
self.capability = api_config.capability
|
||||
self.audio_api_key = audio_api_key
|
||||
self.enable_audio_transcription = enable_audio_transcription
|
||||
self.is_omni = is_omni
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
files: Optional[List[FileInput]]
|
||||
end_user_id: uuid.UUID | str,
|
||||
files: Optional[List[FileInput]],
|
||||
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理文件列表,返回 LLM 可用的格式
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
files: 文件输入列表
|
||||
|
||||
Returns:
|
||||
@@ -319,6 +346,8 @@ class MultimodalService:
|
||||
"""
|
||||
if not files:
|
||||
return []
|
||||
if isinstance(end_user_id, uuid.UUID):
|
||||
end_user_id = str(end_user_id)
|
||||
|
||||
# 获取对应的策略
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
||||
@@ -333,19 +362,25 @@ class MultimodalService:
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
strategy = strategy_class(file)
|
||||
if not file.url:
|
||||
file.url = await self.get_file_url(file)
|
||||
try:
|
||||
if file.type == FileType.IMAGE:
|
||||
if file.type == FileType.IMAGE and "vision" in self.capability:
|
||||
content = await self._process_image(file, strategy)
|
||||
result.append(content)
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
content = await self._process_document(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.AUDIO:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||
content = await self._process_audio(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.VIDEO:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.VIDEO and "video" in self.capability:
|
||||
content = await self._process_video(file, strategy)
|
||||
result.append(content)
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {file.type}")
|
||||
except Exception as e:
|
||||
@@ -355,7 +390,8 @@ class MultimodalService:
|
||||
"file_index": idx,
|
||||
"file_type": file.type,
|
||||
"error": str(e)
|
||||
}
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
# 继续处理其他文件,不中断整个流程
|
||||
result.append({
|
||||
@@ -366,6 +402,17 @@ class MultimodalService:
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||
return result
|
||||
|
||||
def write_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict
|
||||
):
|
||||
"""写入感知记忆"""
|
||||
if end_user_id and self.api_config:
|
||||
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
|
||||
|
||||
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理图片文件
|
||||
@@ -387,43 +434,6 @@ class MultimodalService:
|
||||
"text": f"[图片处理失败: {str(e)}]"
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _download_and_encode_image(url: str) -> tuple[str, str]:
|
||||
"""
|
||||
下载图片并转换为 base64
|
||||
|
||||
Args:
|
||||
url: 图片 URL
|
||||
|
||||
Returns:
|
||||
tuple: (base64_data, media_type)
|
||||
"""
|
||||
from mimetypes import guess_type
|
||||
|
||||
# 下载图片
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# 获取图片数据
|
||||
image_data = response.content
|
||||
|
||||
# 确定 media type
|
||||
content_type = response.headers.get("content-type")
|
||||
if content_type and content_type.startswith("image/"):
|
||||
media_type = content_type
|
||||
else:
|
||||
# 从 URL 推断
|
||||
guessed_type, _ = guess_type(url)
|
||||
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||
|
||||
# 转换为 base64
|
||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
return base64_data, media_type
|
||||
|
||||
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理文档文件(PDF、Word 等)
|
||||
@@ -436,7 +446,6 @@ class MultimodalService:
|
||||
Dict: 根据 provider 返回不同格式的文档内容
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
# 远程文档暂不支持提取
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||
@@ -471,12 +480,12 @@ class MultimodalService:
|
||||
|
||||
# 如果启用音频转文本且有 API Key
|
||||
transcription = None
|
||||
if self.enable_audio_transcription and self.api_key:
|
||||
if self.enable_audio_transcription and self.audio_api_key:
|
||||
logger.info(f"开始音频转文本: {url}")
|
||||
if self.provider == "dashscope":
|
||||
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
|
||||
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.audio_api_key)
|
||||
elif self.provider == "openai":
|
||||
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
|
||||
transcription = await AudioTranscriptionService.transcribe_openai(url, self.audio_api_key)
|
||||
else:
|
||||
logger.warning(f"Provider {self.provider} 不支持音频转文本")
|
||||
|
||||
|
||||
53
api/app/services/prompt/perceptual_summary_system.jinja2
Normal file
53
api/app/services/prompt/perceptual_summary_system.jinja2
Normal file
@@ -0,0 +1,53 @@
|
||||
{% raw %}You are a professional information extraction system.
|
||||
|
||||
Your task is to analyze the provided document content and generate structured metadata.
|
||||
|
||||
Extract the following fields:
|
||||
|
||||
* **summary**: A concise summary of the document in 2–4 sentences.
|
||||
* **keywords**: 5–10 important keywords or key phrases that best represent the document. This field MUST be a JSON array of strings.
|
||||
* **topic**: The primary topic of the document expressed as a short phrase (3–8 words).
|
||||
* **domain**: The broader knowledge domain or field the document belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.).
|
||||
|
||||
STRICT RULES:
|
||||
|
||||
1. Output MUST be valid JSON.
|
||||
2. Do NOT output markdown.
|
||||
3. Do NOT output explanations.
|
||||
4. Do NOT output any text before or after the JSON.
|
||||
5. The JSON MUST contain EXACTLY these four keys:
|
||||
* summary
|
||||
* keywords
|
||||
* topic
|
||||
* domain{% endraw %}
|
||||
{% if file_type == 'image' or file_type == 'video' %} * scene {% endif %}
|
||||
{% if file_type == 'audio' %} * speaker_count {% endif %}
|
||||
{% if file_type == 'document' %} * section_count
|
||||
* title
|
||||
* first_line
|
||||
{% endif %}
|
||||
{% raw %}
|
||||
6. `keywords` MUST be a JSON array of strings.
|
||||
7. If the document content is insufficient, infer the best possible answer based on context.
|
||||
8. Ensure the JSON is syntactically correct.
|
||||
{% endraw %}
|
||||
9. Output using the language {{ language }}
|
||||
{% raw %}
|
||||
Required JSON format:
|
||||
|
||||
{
|
||||
"summary": "string",
|
||||
"keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"],
|
||||
"topic": "string",
|
||||
"domain": "string",
|
||||
{% endraw %}
|
||||
{% if file_type == 'image' or file_type == 'video' %} "scene": ["string", "string"] {% endif %}
|
||||
{% if file_type == 'document' %} "section_count": integer
|
||||
"title": "string",
|
||||
"first_line": "string"
|
||||
{% endif %}
|
||||
{% if file_type == 'audio' %} "speaker_count": integer {% endif %}
|
||||
{% raw %}
|
||||
}
|
||||
|
||||
Now analyze the following document and return the JSON result.{% endraw %}
|
||||
392
api/app/tasks.py
392
api/app/tasks.py
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -11,20 +10,48 @@ from datetime import datetime, timezone
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
import requests
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Import a unified Celery instance
|
||||
from app.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.rag.crawler.web_crawler import WebCrawler
|
||||
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
|
||||
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
||||
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||
from app.core.rag.integrations.feishu.models import FileInfo
|
||||
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||
from app.core.rag.integrations.yuque.models import YuqueDocInfo
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.prompts.generator import question_proposal
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models import Document, File, Knowledge
|
||||
from app.schemas import document_schema, file_schema
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
from app.utils.redis_lock import RedisLock
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 模块级同步 Redis 连接池,供 Celery 任务共享使用
|
||||
# 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致
|
||||
# 使用连接池而非单例客户端,提供更好的并发性能和自动重连
|
||||
_sync_redis_pool: redis.ConnectionPool = None
|
||||
_sync_redis_pool: redis.ConnectionPool | None = None
|
||||
|
||||
def _get_or_create_redis_pool() -> redis.ConnectionPool:
|
||||
|
||||
def _get_or_create_redis_pool() -> redis.ConnectionPool | None:
|
||||
"""获取或创建 Redis 连接池(懒初始化)"""
|
||||
global _sync_redis_pool
|
||||
if _sync_redis_pool is None:
|
||||
@@ -47,6 +74,7 @@ def _get_or_create_redis_pool() -> redis.ConnectionPool:
|
||||
return None
|
||||
return _sync_redis_pool
|
||||
|
||||
|
||||
def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
"""获取同步 Redis 客户端(使用连接池)
|
||||
|
||||
@@ -60,7 +88,7 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
pool = _get_or_create_redis_pool()
|
||||
if pool is None:
|
||||
return None
|
||||
|
||||
|
||||
client = redis.StrictRedis(connection_pool=pool)
|
||||
# 验证连接可用性
|
||||
client.ping()
|
||||
@@ -72,32 +100,18 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
# Import a unified Celery instance
|
||||
from app.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
from app.core.rag.crawler.web_crawler import WebCrawler
|
||||
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
|
||||
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
||||
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||
from app.core.rag.integrations.feishu.models import FileInfo
|
||||
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||
from app.core.rag.integrations.yuque.models import YuqueDocInfo
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.prompts.generator import question_proposal
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models.document_model import Document
|
||||
from app.models.file_model import File
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.schemas import document_schema, file_schema
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
def set_asyncio_event_loop():
|
||||
"""Set the asyncio event loop for the current thread."""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.process_item")
|
||||
@@ -294,9 +308,18 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
vector_size = len(vts[0])
|
||||
init_graphrag(task, vector_size)
|
||||
|
||||
async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service,
|
||||
chat_model, embedding_model, callback, with_resolution: bool = True,
|
||||
with_community: bool = True, ) -> dict:
|
||||
async def _run(
|
||||
row: dict,
|
||||
document_ids: list[str],
|
||||
language: str,
|
||||
parser_config: dict,
|
||||
vector_service,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
with_resolution: bool = True,
|
||||
with_community: bool = True
|
||||
) -> dict:
|
||||
await trio.sleep(5) # Delay for 10 seconds
|
||||
nonlocal progress_msg # Declare the use of an external progress_msg variable
|
||||
result = await run_graphrag_for_kb(
|
||||
@@ -329,6 +352,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
with_community=with_community,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(sync_task)
|
||||
@@ -448,6 +472,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
with_community=with_community,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(sync_task)
|
||||
@@ -1002,29 +1027,21 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
async def _run() -> str:
|
||||
async def _run() -> dict:
|
||||
with get_db_context() as db:
|
||||
service = MemoryAgentService()
|
||||
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db,
|
||||
storage_type, user_rag_memory_id)
|
||||
return await service.read_memory(
|
||||
end_user_id,
|
||||
message,
|
||||
history,
|
||||
search_switch,
|
||||
actual_config_id, db,
|
||||
storage_type, user_rag_memory_id
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1056,7 +1073,8 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
||||
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, user_rag_memory_id: str,
|
||||
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str,
|
||||
user_rag_memory_id: str,
|
||||
language: str = "zh") -> Dict[str, Any]:
|
||||
"""Celery task to process a write message via MemoryAgentService.
|
||||
Args:
|
||||
@@ -1073,10 +1091,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
Raises:
|
||||
Exception on failure
|
||||
"""
|
||||
from app.core.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id} (type: {type(config_id).__name__}), storage_type={storage_type}, language={language}")
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
||||
f"config_id={config_id} (type: {type(config_id).__name__}), "
|
||||
f"storage_type={storage_type}, language={language}")
|
||||
start_time = time.time()
|
||||
|
||||
# Convert config_id to UUID
|
||||
@@ -1086,13 +1105,14 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
actual_config_id = resolve_config_id(config_id, db)
|
||||
print(100*'-')
|
||||
print(100 * '-')
|
||||
print(actual_config_id)
|
||||
print(100*'-')
|
||||
print(100 * '-')
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
|
||||
logger.error(
|
||||
f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": f"Invalid config_id format: {config_id} - {str(e)}",
|
||||
@@ -1116,7 +1136,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
async def _run() -> str:
|
||||
with get_db_context() as db:
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
||||
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||
service = MemoryAgentService()
|
||||
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
|
||||
user_rag_memory_id, language)
|
||||
@@ -1124,22 +1145,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
return result
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1193,28 +1200,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
}
|
||||
|
||||
|
||||
def reflection_engine() -> None:
|
||||
"""Empty function placeholder for timed background reflection.
|
||||
|
||||
Intentionally left blank; replace with real reflection logic later.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
|
||||
|
||||
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
|
||||
asyncio.run(self_reflexion(host_id))
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.reflection.timer")
|
||||
def reflection_timer_task() -> None:
|
||||
"""Periodic Celery task that invokes reflection_engine.
|
||||
|
||||
Raises an exception on failure.
|
||||
"""
|
||||
reflection_engine()
|
||||
|
||||
|
||||
# unused task
|
||||
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
||||
# def check_read_service_task() -> Dict[str, str]:
|
||||
@@ -1368,6 +1353,8 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
"workspace_id": workspace_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.write_all_workspaces_memory_task",
|
||||
bind=True,
|
||||
@@ -1391,15 +1378,12 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.repositories.memory_increment_repository import write_memory_increment
|
||||
from app.services.memory_storage_service import search_all
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有活跃的工作空间
|
||||
@@ -1408,7 +1392,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
).all()
|
||||
|
||||
if not workspaces:
|
||||
api_logger.warning("没有找到活跃的工作空间")
|
||||
logger.warning("没有找到活跃的工作空间")
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": "没有找到活跃的工作空间",
|
||||
@@ -1416,13 +1400,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"workspace_results": []
|
||||
}
|
||||
|
||||
api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
|
||||
logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
|
||||
all_workspace_results = []
|
||||
|
||||
# 遍历每个工作空间
|
||||
for workspace in workspaces:
|
||||
workspace_id = workspace.id
|
||||
api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
|
||||
logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
|
||||
|
||||
try:
|
||||
# 1. 查询当前workspace下的所有app(仅未删除的)
|
||||
@@ -1447,7 +1431,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
api_logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0")
|
||||
logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0")
|
||||
continue
|
||||
|
||||
# 2. 查询所有app下的end_user_id(去重)
|
||||
@@ -1472,7 +1456,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
})
|
||||
except Exception as e:
|
||||
# 记录单个用户查询失败,但继续处理其他用户
|
||||
api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
|
||||
logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": 0,
|
||||
@@ -1496,13 +1480,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
|
||||
api_logger.info(
|
||||
logger.info(
|
||||
f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间
|
||||
api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
|
||||
logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
|
||||
all_workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"workspace_name": workspace.name,
|
||||
@@ -1525,7 +1509,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆增量统计任务执行失败: {str(e)}")
|
||||
logger.error(f"记忆增量统计任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
@@ -1534,22 +1518,8 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1597,11 +1567,9 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.services.user_memory_service import UserMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("开始执行记忆缓存重新生成定时任务")
|
||||
|
||||
service = UserMemoryService()
|
||||
@@ -1734,22 +1702,8 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1785,15 +1739,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.services.memory_reflection_service import (
|
||||
MemoryReflectionService,
|
||||
WorkspaceAppService,
|
||||
)
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有工作空间
|
||||
@@ -1812,7 +1763,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
# 遍历每个工作空间
|
||||
for workspace in workspaces:
|
||||
workspace_id = workspace.id
|
||||
api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}")
|
||||
logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}")
|
||||
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
@@ -1824,7 +1775,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
workspace_reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['memory_configs'] == []:
|
||||
if not data['memory_configs']:
|
||||
continue
|
||||
|
||||
releases = data['releases']
|
||||
@@ -1835,7 +1786,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str(
|
||||
user['app_id']):
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_reflection_from_data(
|
||||
config_data=config,
|
||||
@@ -1855,12 +1806,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
"reflection_results": workspace_reflection_results
|
||||
})
|
||||
|
||||
api_logger.info(
|
||||
logger.info(
|
||||
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # Rollback failed transaction to allow next query
|
||||
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"error": str(e),
|
||||
@@ -1879,7 +1830,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"工作空间反思任务执行失败: {str(e)}")
|
||||
logger.error(f"工作空间反思任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
@@ -1888,22 +1839,8 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1944,18 +1881,16 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
api_logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")
|
||||
logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")
|
||||
|
||||
forget_service = MemoryForgetService()
|
||||
|
||||
# 运行遗忘周期
|
||||
# FIXME: MemeoryForgetService
|
||||
report = await forget_service.trigger_forgetting(
|
||||
db=db,
|
||||
end_user_id=None, # 处理所有组
|
||||
@@ -1964,7 +1899,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
|
||||
duration = time.time() - start_time
|
||||
|
||||
api_logger.info(
|
||||
logger.info(
|
||||
f"遗忘周期定时任务完成: "
|
||||
f"融合 {report['merged_count']} 对节点, "
|
||||
f"失败 {report['failed_count']} 对, "
|
||||
@@ -1980,7 +1915,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
|
||||
except Exception as e:
|
||||
duration = time.time() - start_time
|
||||
api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
|
||||
logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"status": "FAILED",
|
||||
@@ -1997,6 +1932,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Long-term Memory Storage Tasks (Batched Write Strategies)
|
||||
# =============================================================================
|
||||
@@ -2222,9 +2158,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from app.repositories.implicit_emotions_storage_repository import (
|
||||
ImplicitEmotionsStorageRepository,
|
||||
@@ -2233,7 +2168,6 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
|
||||
|
||||
total_users = 0
|
||||
@@ -2267,7 +2201,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
for end_user_id in refresh_iter:
|
||||
logger.info(f"开始处理用户: {end_user_id}")
|
||||
user_start_time = time.time()
|
||||
|
||||
|
||||
implicit_success = False
|
||||
emotion_success = False
|
||||
errors = []
|
||||
@@ -2318,7 +2252,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
failed += 1
|
||||
|
||||
user_elapsed = time.time() - user_start_time
|
||||
|
||||
|
||||
# 记录用户处理结果
|
||||
user_result = {
|
||||
"end_user_id": end_user_id,
|
||||
@@ -2460,22 +2394,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -2521,14 +2441,12 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.implicit_emotions_storage_repository import (
|
||||
ImplicitEmotionsStorageRepository,
|
||||
)
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
@@ -2587,20 +2505,7 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
@@ -2633,6 +2538,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
默认生成中文(zh)兴趣分布数据。
|
||||
|
||||
Args:
|
||||
self: task object
|
||||
end_user_ids: 需要检查的用户ID列表
|
||||
|
||||
Returns:
|
||||
@@ -2641,11 +2547,9 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
@@ -2694,20 +2598,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
@@ -2720,3 +2611,54 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.write_perceptual_memory",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
)
|
||||
def write_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
model_api_config: dict,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict
|
||||
):
|
||||
"""
|
||||
Write perceptual memory for a user into PostgreSQL and Neo4j.
|
||||
|
||||
This task generates or updates the user's perceptual memory
|
||||
in the backend databases. It is intended to be executed asynchronously
|
||||
via Celery.
|
||||
|
||||
Args:
|
||||
end_user_id (uuid.UUID): The unique identifier of the end user.
|
||||
model_api_config (ModelInfo): API configuration for the model
|
||||
used to generate perceptual memory.
|
||||
file_type (str): The file type
|
||||
file_url (url): The url of file
|
||||
file_message (dict): The file message containing details about the file
|
||||
to be processed.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
|
||||
set_asyncio_event_loop()
|
||||
with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
|
||||
model_info = ModelInfo(**model_api_config)
|
||||
with get_db_context() as db:
|
||||
memory_perceptual_service = MemoryPerceptualService(db)
|
||||
return asyncio.run(memory_perceptual_service.generate_perceptual_memory(
|
||||
end_user_id,
|
||||
model_info,
|
||||
file_type,
|
||||
file_url,
|
||||
file_message,
|
||||
))
|
||||
|
||||
61
api/app/utils/redis_lock.py
Normal file
61
api/app/utils/redis_lock.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import redis
|
||||
import uuid
|
||||
import time
|
||||
|
||||
UNLOCK_SCRIPT = """
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("del", KEYS[1])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
"""
|
||||
|
||||
|
||||
class RedisLock:
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
redis_client: redis.StrictRedis,
|
||||
expire: int = 60,
|
||||
retry_interval: float = 0.1,
|
||||
timeout: float = 30
|
||||
|
||||
):
|
||||
self.key = key
|
||||
self.expire = expire
|
||||
self.value = str(uuid.uuid4())
|
||||
self._locked = False
|
||||
self.retry_interval = retry_interval
|
||||
self.timeout = timeout
|
||||
self.redis_client = redis_client
|
||||
|
||||
def acquire(self) -> bool:
|
||||
start = time.time()
|
||||
while True:
|
||||
ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True)
|
||||
if ok:
|
||||
self._locked = True
|
||||
return True
|
||||
if time.time() - start >= self.timeout:
|
||||
return False
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
def release(self):
|
||||
if not self._locked:
|
||||
return
|
||||
self.redis_client.eval(
|
||||
UNLOCK_SCRIPT,
|
||||
1,
|
||||
self.key,
|
||||
self.value
|
||||
)
|
||||
self._locked = False
|
||||
|
||||
def __enter__(self):
|
||||
ok = self.acquire()
|
||||
if not ok:
|
||||
raise RuntimeError(f"Get redis lock timeout: {self.key}")
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.release()
|
||||
Reference in New Issue
Block a user