From b71bc1f875f1b50d5fca46da0dabf1be7b3c1e5c Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 13 Mar 2026 13:33:58 +0800 Subject: [PATCH] feat(multimodel): support multimodal memory display and improve code style --- api/app/aioRedis.py | 3 +- api/app/celery_app.py | 23 +- api/app/celery_worker.py | 2 +- api/app/controllers/app_controller.py | 15 +- .../agent/langgraph_graph/nodes/data_nodes.py | 15 +- .../agent/langgraph_graph/routing/routers.py | 13 +- .../langgraph_graph/routing/write_router.py | 115 ++--- .../agent/langgraph_graph/tools/tool.py | 108 ++--- .../agent/langgraph_graph/tools/write_tool.py | 45 +- .../agent/langgraph_graph/write_graph.py | 37 +- api/app/core/memory/agent/utils/llm_tools.py | 16 +- .../triplet_extraction.py | 18 +- .../memory/utils/prompt/template_render.py | 16 +- api/app/core/models/base.py | 32 +- api/app/core/workflow/nodes/base_node.py | 11 +- api/app/core/workflow/nodes/if_else/node.py | 22 +- api/app/core/workflow/nodes/knowledge/node.py | 6 + api/app/core/workflow/nodes/llm/node.py | 59 ++- .../nodes/parameter_extractor/node.py | 8 + api/app/models/memory_perceptual_model.py | 11 + .../memory_perceptual_repository.py | 13 +- api/app/schemas/memory_perceptual_schema.py | 2 - api/app/schemas/model_schema.py | 11 + api/app/services/app_chat_service.py | 36 +- api/app/services/draft_run_service.py | 32 +- api/app/services/memory_agent_service.py | 2 +- api/app/services/memory_perceptual_service.py | 114 ++++- api/app/services/multimodal_service.py | 129 +++--- .../prompt/perceptual_summary_system.jinja2 | 53 +++ api/app/tasks.py | 392 ++++++++---------- api/app/utils/redis_lock.py | 61 +++ 31 files changed, 877 insertions(+), 543 deletions(-) create mode 100644 api/app/services/prompt/perceptual_summary_system.jinja2 create mode 100644 api/app/utils/redis_lock.py diff --git a/api/app/aioRedis.py b/api/app/aioRedis.py index f758dd15..aac2aa84 100644 --- a/api/app/aioRedis.py +++ b/api/app/aioRedis.py @@ -1,10 +1,11 @@ -import os import asyncio import json import logging from typing import Dict, Any, Optional + import redis.asyncio as redis from redis.asyncio import ConnectionPool + from app.core.config import settings # 设置日志记录器 diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 21ee291d..60c22855 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -62,7 +62,7 @@ celery_app.conf.update( task_serializer='json', accept_content=['json'], result_serializer='json', - + # 时区 timezone='Asia/Shanghai', enable_utc=False, @@ -70,43 +70,44 @@ celery_app.conf.update( # 任务追踪 task_track_started=True, task_ignore_result=False, - + # 超时设置 task_time_limit=3600, # 60分钟硬超时 task_soft_time_limit=3000, # 50分钟软超时 - + # Worker 设置 (per-worker settings are in docker-compose command line) worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution - + # 结果过期时间 result_expires=3600, # 结果保存1小时 - + # 任务确认设置 task_acks_late=True, task_reject_on_worker_lost=True, worker_disable_rate_limits=True, - + # FLower setting worker_send_task_events=True, task_send_sent_event=True, - + # task routing task_routes={ # Memory tasks → memory_tasks queue (threads worker) 'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'}, 'app.core.memory.agent.read_message': {'queue': 'memory_tasks'}, 'app.core.memory.agent.write_message': {'queue': 'memory_tasks'}, - + 'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'}, + # Long-term storage tasks → memory_tasks queue (batched write strategies) 'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'}, 'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'}, 'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'}, - + # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, 'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'}, - + # Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker) 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, 'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'}, @@ -131,7 +132,7 @@ implicit_emotions_update_schedule = crontab( minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE, ) -#构建定时任务配置 +# 构建定时任务配置 beat_schedule_config = { "run-workspace-reflection": { "task": "app.tasks.workspace_reflection_task", diff --git a/api/app/celery_worker.py b/api/app/celery_worker.py index 7d3ee686..4ea4fee1 100644 --- a/api/app/celery_worker.py +++ b/api/app/celery_worker.py @@ -13,4 +13,4 @@ logger.info("Celery worker logging initialized") # 导入任务模块以注册任务 import app.tasks -__all__ = ['celery_app'] \ No newline at end of file +__all__ = ['celery_app'] diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 5d1551d8..31451a7d 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -808,6 +808,15 @@ async def draft_run_compare( raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) service._validate_app_accessible(app, workspace_id) + if payload.user_id is None: + end_user_repo = EndUserRepository(db) + new_end_user = end_user_repo.get_or_create_end_user( + app_id=app_id, + other_id=str(current_user.id), + original_user_id=str(current_user.id) # Save original user_id to other_id + ) + payload.user_id = str(new_end_user.id) + # 2. 获取 Agent 配置 from sqlalchemy import select from app.models import AgentConfig @@ -853,6 +862,8 @@ async def draft_run_compare( "conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id }) + + # 流式返回 if payload.stream: async def event_generator(): @@ -864,7 +875,7 @@ async def draft_run_compare( message=payload.message, workspace_id=workspace_id, conversation_id=payload.conversation_id, - user_id=payload.user_id or str(current_user.id), + user_id=payload.user_id, variables=payload.variables, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id, @@ -895,7 +906,7 @@ async def draft_run_compare( message=payload.message, workspace_id=workspace_id, conversation_id=payload.conversation_id, - user_id=payload.user_id or str(current_user.id), + user_id=payload.user_id, variables=payload.variables, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id, diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py index 829f26c4..ca08db76 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py @@ -4,13 +4,13 @@ from app.core.memory.agent.utils.llm_tools import ReadState, WriteState def content_input_node(state: ReadState) -> ReadState: """ Start node - Extract content and maintain state information - + Extracts the content from the first message in the state and returns it as the data field while preserving all other state information. - + Args: state: ReadState containing messages and other state data - + Returns: ReadState: Updated state with extracted content in data field """ @@ -19,19 +19,20 @@ def content_input_node(state: ReadState) -> ReadState: # Return content and maintain all state information return {"data": content} + def content_input_write(state: WriteState) -> WriteState: """ Start node - Extract content and maintain state information for write operations - + Extracts the content from the first message in the state for write operations. - + Args: state: WriteState containing messages and other state data - + Returns: WriteState: Updated state with extracted content in data field """ content = state['messages'][0].content if state.get('messages') else '' # Return content and maintain all state information - return {"data": content} \ No newline at end of file + return {"data": content} diff --git a/api/app/core/memory/agent/langgraph_graph/routing/routers.py b/api/app/core/memory/agent/langgraph_graph/routing/routers.py index 004e03b3..d6ca3333 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/routers.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/routers.py @@ -1,13 +1,13 @@ - from typing import Literal from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState - logger = get_agent_logger(__name__) counter = COUNTState(limit=3) -def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]: + + +def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]: """ Determine routing based on search_switch value. @@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa return 'Input_Summary' return 'Split_The_Problem' # 默认情况 + def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: """ Determine routing based on search_switch value. @@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: elif search_switch == '1': return 'Retrieve_Summary' return 'Retrieve_Summary' # Default based on business logic + + def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: - status=state.get('verify', '')['status'] + status = state.get('verify', '')['status'] # loop_count = counter.get_total() if "success" in status: # counter.reset() @@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co # if loop_count < 2: # Maximum loop count is 3 # return "content_input" # else: - # counter.reset() + # counter.reset() return "Summary_fails" else: # Add default return value to avoid returning None diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index ddb6ca3e..6176caf5 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -2,32 +2,32 @@ import json import os from app.core.logging_config import get_agent_logger -from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage - +from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import count_store +from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context, get_db +from app.db import get_db_context from app.repositories.memory_short_repository import LongTermMemoryRepository from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_konwledges_server import write_rag from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task from app.utils.config_utils import resolve_config_id + logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') + async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): """ Write messages to RAG storage system - + Combines user and AI messages into a single string format and stores them in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval. - + Args: end_user_id: User identifier for the conversation user_message: User's input message content @@ -38,14 +38,24 @@ async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory combined_message = f"user: {user_message}\nassistant: {ai_message}" await write_rag(end_user_id, combined_message, user_rag_memory_id) logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') -async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, - actual_config_id, long_term_messages=[]): + + +async def write( + storage_type, + end_user_id, + user_message, + ai_message, + user_rag_memory_id, + actual_end_user_id, + actual_config_id, + long_term_messages=None +): """ Write memory with structured message support - + Handles memory writing operations for different storage types (Neo4j/RAG). Supports both individual message pairs and batch long-term message processing. - + Args: storage_type: Storage type identifier ("neo4j" or "rag") end_user_id: Terminal user identifier @@ -55,7 +65,7 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me actual_end_user_id: Actual user identifier for storage actual_config_id: Configuration identifier long_term_messages: Optional list of structured messages for batch processing - + Logic explanation: - RAG mode: Combines user_message and ai_message into string format, maintains original logic - Neo4j mode: Uses structured message lists @@ -64,8 +74,9 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me 3. Each message is converted to independent Chunk, preserving speaker field """ - db = next(get_db()) - try: + if long_term_messages is None: + long_term_messages = [] + with get_db_context() as db: actual_config_id = resolve_config_id(actual_config_id, db) # Neo4j mode: Use structured message lists structured_messages = [] @@ -105,17 +116,16 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") write_status = get_task_memory_write_result(str(write_id)) logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') - finally: - db.close() -async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): + +async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope): """ Save long-term memory data to database - + Handles the storage of long-term memory data based on different strategies (chunk-based or aggregate-based) and manages the transition from short-term to long-term memory storage. - + Args: long_term_messages: Long-term message data to be saved actual_config_id: Configuration identifier for memory settings @@ -126,13 +136,12 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type, with get_db_context() as db_session: repo = LongTermMemoryRepository(db_session) - from app.core.memory.agent.utils.redis_tool import write_store result = write_store.get_session_by_userid(end_user_id) - if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: data = await format_parsing(result, "dict") chunk_data = data[:scope] - if len(chunk_data)==scope: + if len(chunk_data) == scope: repo.upsert(end_user_id, chunk_data) logger.info(f'---------写入短长期-----------') else: @@ -142,22 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type, logger.info(f'写入短长期:') - """Window-based dialogue processing""" -async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): + + +async def window_dialogue(end_user_id, langchain_messages, memory_config, scope): """ Process dialogue based on window size and write to Neo4j - + Manages conversation data based on a sliding window approach. When the window reaches the specified scope size, it triggers long-term memory storage to Neo4j. - + Args: end_user_id: Terminal user identifier memory_config: Memory configuration object containing settings langchain_messages: Original message data list scope: Window size determining when to trigger long-term storage """ - scope=scope + scope = scope is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_id is not False: is_end_user_id = count_store.get_sessions_count(end_user_id)[0] @@ -174,42 +184,53 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): config_id = memory_config.config_id else: config_id = memory_config - - await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, - config_id, formatted_messages) + + await write( + AgentMemory_Long_Term.STORAGE_NEO4J, + end_user_id, + "", + "", + None, + end_user_id, + config_id, + formatted_messages + ) count_store.update_sessions_count(end_user_id, 1, langchain_messages) else: count_store.save_sessions_count(end_user_id, 1, langchain_messages) """Time-based memory processing""" -async def memory_long_term_storage(end_user_id,memory_config,time): + + +async def memory_long_term_storage(end_user_id, memory_config, time): """ Process memory storage based on time intervals and write to Neo4j - + Retrieves Redis data based on time intervals and writes it to Neo4j for long-term storage. This function handles time-based memory consolidation. - + Args: end_user_id: Terminal user identifier memory_config: Memory configuration object containing settings time: Time interval for data retrieval """ long_time_data = write_store.find_user_recent_sessions(end_user_id, time) - format_messages = (long_time_data) - messages=[] - memory_config=memory_config.config_id + format_messages = long_time_data + messages = [] + memory_config = memory_config.config_id for i in format_messages: - message=json.loads(i['Query']) - messages+= message - if format_messages!=[]: + message = json.loads(i['Query']) + messages += message + if format_messages: await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, memory_config, messages) -"""Aggregation judgment processing""" + + async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: """ Aggregation judgment function: determine if input sentence and historical messages describe the same event - + Uses LLM-based analysis to determine whether new messages should be aggregated with existing historical data or stored as separate events. This helps optimize memory storage and retrieval. @@ -217,11 +238,11 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config end_user_id: Terminal user identifier ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] memory_config: Memory configuration object containing LLM settings - + Returns: dict: Aggregation judgment result containing is_same_event flag and processed output """ - + history = None try: # 1. Get historical session data (using new method) result = write_store.get_all_sessions_by_end_user_id(end_user_id) @@ -255,7 +276,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config output_value = structured.output if isinstance(output_value, list): output_value = [ - {"role": msg.role, "content": msg.content} + {"role": msg.role, "content": msg.content} for msg in output_value ] @@ -268,16 +289,16 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config await write("neo4j", end_user_id, "", "", None, end_user_id, memory_config.config_id, output_value) return result_dict - + except Exception as e: print(f"[aggregate_judgment] 发生错误: {e}") import traceback traceback.print_exc() - + return { "is_same_event": False, "output": ori_messages, "messages": ori_messages, "history": history if 'history' in locals() else [], "error": str(e) - } \ No newline at end of file + } diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py index bee77ddf..9bd2b2cf 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -2,26 +2,25 @@ import asyncio import json from datetime import datetime, timedelta - from langchain.tools import tool from pydantic import BaseModel, Field - from app.core.memory.src.search import ( search_by_temporal, search_by_keyword_temporal, ) + def extract_tool_message_content(response): """ Extract ToolMessage content and tool names from agent response - + Parses agent response messages to extract tool execution results and metadata. Handles JSON parsing and provides structured access to tool output data. - + Args: response: Agent response dictionary containing messages - + Returns: dict: Dictionary containing tool_name and parsed content, or None if no tool message found - tool_name: Name of the executed tool @@ -61,10 +60,10 @@ def extract_tool_message_content(response): class TimeRetrievalInput(BaseModel): """ Input schema for time retrieval tool - + Defines the expected input parameters for time-based retrieval operations. Used for validation and documentation of tool parameters. - + Attributes: context: User input query content for search end_user_id: Group ID for filtering search results, defaults to test user @@ -72,25 +71,26 @@ class TimeRetrievalInput(BaseModel): context: str = Field(description="用户输入的查询内容") end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果") + def create_time_retrieval_tool(end_user_id: str): """ Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range - + Creates a specialized time-based retrieval tool that searches for statements within specified time ranges. Includes field cleaning functionality to remove unnecessary metadata from search results. - + Args: end_user_id: User identifier for scoping search results - + Returns: function: Configured TimeRetrievalWithGroupId tool function """ - + def clean_temporal_result_fields(data): """ Clean unnecessary fields from temporal search results and modify structure - + Removes metadata fields that are not needed for end-user consumption and restructures the response format for better usability. @@ -102,10 +102,10 @@ def create_time_retrieval_tool(end_user_id: str): """ # List of fields to filter out fields_to_remove = { - 'id', 'apply_id', 'user_id', 'chunk_id', 'created_at', + 'id', 'apply_id', 'user_id', 'chunk_id', 'created_at', 'valid_at', 'invalid_at', 'statement_ids' } - + if isinstance(data, dict): cleaned = {} for key, value in data.items(): @@ -126,15 +126,16 @@ def create_time_retrieval_tool(end_user_id: str): return [clean_temporal_result_fields(item) for item in data] else: return data - + @tool - def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str: + def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, + end_user_id_param: str = None, clean_output: bool = True) -> str: """ Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields - + Performs time-based search operations with automatic metadata filtering. Supports flexible date range specification and provides clean, user-friendly output. - + Explicit parameters: - context: Query context content - start_date: Start time (optional, format: YYYY-MM-DD) @@ -142,10 +143,11 @@ def create_time_retrieval_tool(end_user_id: str): - end_user_id_param: Group ID (optional, overrides default group ID) - clean_output: Whether to clean metadata fields from output - end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d") - + Returns: str: JSON formatted search results with temporal data """ + async def _async_search(): # Use passed parameters or default values actual_end_user_id = end_user_id_param or end_user_id @@ -167,18 +169,19 @@ def create_time_retrieval_tool(end_user_id: str): cleaned_results = results return json.dumps(cleaned_results, ensure_ascii=False, indent=2) - + return asyncio.run(_async_search()) @tool - def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str: + def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, + clean_output: bool = True) -> str: """ Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields - + Performs combined keyword and temporal search operations with automatic metadata filtering. Provides more targeted search results by combining content relevance with time-based filtering. - + Explicit parameters: - context: Query content for keyword matching - days_back: Number of days to search backwards, default 7 days @@ -186,10 +189,11 @@ def create_time_retrieval_tool(end_user_id: str): - end_date: End time (optional, format: YYYY-MM-DD) - clean_output: Whether to clean metadata fields from output - end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d") - + Returns: str: JSON formatted search results combining keyword and temporal data """ + async def _async_search(): actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d") @@ -212,29 +216,29 @@ def create_time_retrieval_tool(end_user_id: str): return json.dumps(cleaned_results, ensure_ascii=False, indent=2) return asyncio.run(_async_search()) - + return TimeRetrievalWithGroupId def create_hybrid_retrieval_tool_async(memory_config, **search_params): """ Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields - + Creates an advanced hybrid search tool that combines multiple search strategies (keyword, vector, hybrid) with automatic result cleaning and formatting. Args: memory_config: Memory configuration object containing LLM and search settings **search_params: Search parameters including end_user_id, limit, include, etc. - + Returns: function: Configured HybridSearch tool function with async capabilities """ - + def clean_result_fields(data): """ Recursively clean unnecessary fields from results - + Removes metadata fields that are not needed for end-user consumption, improving readability and reducing response size. @@ -247,11 +251,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): # List of fields to filter out # TODO: fact_summary functionality temporarily disabled, will be enabled after future development fields_to_remove = { - 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', - 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', - 'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary" + 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', + 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', + 'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary" } - + if isinstance(data, dict): # Clean dictionary cleaned = {} @@ -265,7 +269,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): else: # Return other types directly return data - + @tool async def HybridSearch( context: str, @@ -279,7 +283,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): ) -> str: """ Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields - + Provides comprehensive search capabilities combining multiple search strategies with intelligent result ranking and automatic metadata filtering for clean output. @@ -292,7 +296,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): use_forgetting_rerank: Whether to use forgetting-based reranking use_llm_rerank: Whether to use LLM-based reranking clean_output: Whether to clean metadata fields from output - + Returns: str: JSON formatted comprehensive search results """ @@ -329,9 +333,9 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): "search_type": search_type, "results": cleaned_results } - + return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str) - + except Exception as e: error_result = { "error": f"混合检索失败: {str(e)}", @@ -340,35 +344,36 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): "timestamp": datetime.now().isoformat() } return json.dumps(error_result, ensure_ascii=False, indent=2) - + return HybridSearch def create_hybrid_retrieval_tool_sync(memory_config, **search_params): """ Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields - + Creates a synchronous wrapper around the async hybrid search functionality, making it compatible with synchronous tool execution environments. Args: memory_config: Memory configuration object containing search settings **search_params: Search parameters for configuration - + Returns: function: Configured HybridSearchSync tool function """ + @tool def HybridSearchSync( - context: str, - search_type: str = "hybrid", - limit: int = 10, - end_user_id: str = None, - clean_output: bool = True + context: str, + search_type: str = "hybrid", + limit: int = 10, + end_user_id: str = None, + clean_output: bool = True ) -> str: """ Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields - + Provides the same hybrid search capabilities as the async version but in a synchronous execution context. Automatically handles async-to-sync conversion. @@ -378,10 +383,11 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params): limit: Result quantity limit end_user_id: Group ID for filtering search results clean_output: Whether to clean metadata fields from output - + Returns: str: JSON formatted search results """ + async def _async_search(): # Create async tool and execute async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params) @@ -392,7 +398,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params): "end_user_id": end_user_id, "clean_output": clean_output }) - + return asyncio.run(_async_search()) - - return HybridSearchSync \ No newline at end of file + + return HybridSearchSync diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py index aa5e09a6..e11a2085 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py @@ -1,10 +1,12 @@ import json from langchain_core.messages import HumanMessage, AIMessage -async def format_parsing(messages: list,type:str='string'): + + +async def format_parsing(messages: list, type: str = 'string'): """ Format and parse message lists into different output types - + Processes message lists from storage and converts them into either string format or dictionary format based on the specified type parameter. Handles JSON parsing and role-based message organization. @@ -19,8 +21,8 @@ async def format_parsing(messages: list,type:str='string'): - 'dict': List of dictionaries mapping user messages to AI responses """ result = [] - user=[] - ai=[] + user = [] + ai = [] for message in messages: hstory_messages = message['messages'] @@ -30,37 +32,38 @@ async def format_parsing(messages: list,type:str='string'): role = content['role'] content = content['content'] if type == "string": - if role == 'human' or role=="user": + if role == 'human' or role == "user": content = '用户:' + content else: content = 'AI:' + content result.append(content) - if type == "dict" : - if role == 'human' or role=="user": - user.append( content) + if type == "dict": + if role == 'human' or role == "user": + user.append(content) else: ai.append(content) if type == "dict": - for key,values in zip(user,ai): - result.append({key:values}) + for key, values in zip(user, ai): + result.append({key: values}) return result + async def messages_parse(messages: list | dict): """ Parse messages from storage format into user-AI conversation pairs - + Extracts and organizes conversation data from stored message format, separating user and AI messages and pairing them for database storage. - + Args: messages: List or dictionary containing stored message data with Query fields - + Returns: list: List of dictionaries containing user-AI message pairs for database storage """ - user=[] - ai=[] - database=[] + user = [] + ai = [] + database = [] for message in messages: Query = message['Query'] Query = json.loads(Query) @@ -72,20 +75,20 @@ async def messages_parse(messages: list | dict): ai.append(data['content']) for key, values in zip(user, ai): database.append({key, values}) - return database + return database -async def agent_chat_messages(user_content,ai_content): +async def agent_chat_messages(user_content, ai_content): """ Create structured chat message format for agent conversations - + Formats user and AI content into a standardized message structure suitable for agent processing and storage. Creates role-based message objects. - + Args: user_content: User's message content string ai_content: AI's response content string - + Returns: list: List of structured message dictionaries with role and content fields """ diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 15009955..bf3c6597 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_config_service import MemoryConfigService - warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) @@ -42,13 +41,15 @@ async def make_write_graph(): yield graph -async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): + +async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '', + end_user_id: str = '', scope: int = 6): """ Handle long-term memory storage with different strategies - - Supports multiple storage strategies including chunk-based, time-based, + + Supports multiple storage strategies including chunk-based, time-based, and aggregate judgment approaches for long-term memory persistence. - + Args: long_term_type: Storage strategy type ('chunk', 'time', 'aggregate') langchain_messages: List of messages to store @@ -56,9 +57,10 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ end_user_id: User group identifier scope: Scope parameter for chunk-based storage (default: 6) """ - from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment + from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ + aggregate_judgment from app.core.memory.agent.utils.redis_tool import write_store - write_store.save_session_write(end_user_id, (langchain_messages)) + write_store.save_session_write(end_user_id, langchain_messages) # 获取数据库会话 with get_db_context() as db_session: config_service = MemoryConfigService(db_session) @@ -66,25 +68,24 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ config_id=memory_config, # 改为整数 service_name="MemoryAgentService" ) - if long_term_type==AgentMemory_Long_Term.STRATEGY_CHUNK: + if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK: '''Strategy 1: Dialogue window with 6 rounds of conversation''' - await window_dialogue(end_user_id,langchain_messages,memory_config,scope) - if long_term_type==AgentMemory_Long_Term.STRATEGY_TIME: + await window_dialogue(end_user_id, langchain_messages, memory_config, scope) + if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME: """Time-based strategy""" - await memory_long_term_storage(end_user_id, memory_config,AgentMemory_Long_Term.TIME_SCOPE) - if long_term_type==AgentMemory_Long_Term.STRATEGY_AGGREGATE: + await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE) + if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE: """Strategy 3: Aggregate judgment""" await aggregate_judgment(end_user_id, langchain_messages, memory_config) - -async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id): +async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id): """ Write long-term memory with different storage types - + Handles both RAG-based storage and traditional memory storage approaches. For traditional storage, uses chunk-based strategy with paired user-AI messages. - + Args: storage_type: Type of storage (RAG or traditional) end_user_id: User group identifier @@ -95,7 +96,7 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_ """ from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save - from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages + from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages if storage_type == AgentMemory_Long_Term.STORAGE_RAG: await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) else: @@ -128,4 +129,4 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_ # # if __name__ == "__main__": # import asyncio -# asyncio.run(main()) \ No newline at end of file +# asyncio.run(main()) diff --git a/api/app/core/memory/agent/utils/llm_tools.py b/api/app/core/memory/agent/utils/llm_tools.py index 1c183422..ea8add48 100644 --- a/api/app/core/memory/agent/utils/llm_tools.py +++ b/api/app/core/memory/agent/utils/llm_tools.py @@ -8,10 +8,11 @@ from langgraph.graph import add_messages PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3]) + class WriteState(TypedDict): - ''' + """ Langgrapg Writing TypedDict - ''' + """ messages: Annotated[list[AnyMessage], add_messages] end_user_id: str errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] @@ -20,6 +21,7 @@ class WriteState(TypedDict): data: str language: str # 语言类型 ("zh" 中文, "en" 英文) + class ReadState(TypedDict): """ LangGraph 工作流状态定义 @@ -43,18 +45,20 @@ class ReadState(TypedDict): config_id: str data: str # 新增字段用于传递内容 spit_data: dict # 新增字段用于传递问题分解结果 - problem_extension:dict + problem_extension: dict storage_type: str user_rag_memory_id: str llm_id: str embedding_id: str memory_config: object # 新增字段用于传递内存配置对象 - retrieve:dict + retrieve: dict RetrieveSummary: dict InputSummary: dict verify: dict SummaryFails: dict summary: dict + + class COUNTState: """ 工作流对话检索内容计数器 @@ -99,6 +103,7 @@ class COUNTState: self.total = 0 print("[COUNTState] 已重置为 0") + def deduplicate_entries(entries): seen = set() deduped = [] @@ -109,6 +114,7 @@ def deduplicate_entries(entries): deduped.append(entry) return deduped + def merge_to_key_value_pairs(data, query_key, result_key): grouped = defaultdict(list) for item in data: @@ -142,4 +148,4 @@ def convert_extended_question_to_question(data): return [convert_extended_question_to_question(item) for item in data] else: # 其他类型直接返回 - return data \ No newline at end of file + return data diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py index 024e320a..147ed777 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py @@ -5,7 +5,7 @@ from typing import List, Dict, Optional from app.core.logging_config import get_memory_logger from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt -from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤 +from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤 from app.core.memory.models.triplet_models import TripletExtractionResponse from app.core.memory.models.message_models import DialogData, Statement from app.core.memory.models.ontology_extraction_models import OntologyTypeList @@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger logger = get_memory_logger(__name__) - class TripletExtractor: """Extracts knowledge triplets and entities from statements using LLM""" def __init__( - self, - llm_client: OpenAIClient, - ontology_types: Optional[OntologyTypeList] = None, - language: str = "zh"): + self, + llm_client: OpenAIClient, + ontology_types: Optional[OntologyTypeList] = None, + language: str = "zh" + ): """Initialize the TripletExtractor with an LLM client Args: @@ -65,7 +65,8 @@ class TripletExtractor: # Create messages for LLM messages = [ - {"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."}, + {"role": "system", + "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."}, {"role": "user", "content": prompt_content} ] @@ -116,7 +117,8 @@ class TripletExtractor: logger.error(f"Error processing statement: {e}", exc_info=True) return TripletExtractionResponse(triplets=[], entities=[]) - async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]: + async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[ + str, TripletExtractionResponse]: """Extract triplets and entities from statements Args: diff --git a/api/app/core/memory/utils/prompt/template_render.py b/api/app/core/memory/utils/prompt/template_render.py index 68e0ffe4..4df8d55b 100644 --- a/api/app/core/memory/utils/prompt/template_render.py +++ b/api/app/core/memory/utils/prompt/template_render.py @@ -2,15 +2,15 @@ import os from jinja2 import Environment, FileSystemLoader from typing import List, Dict, Any - # Setup Jinja2 environment prompt_dir = os.path.join(os.path.dirname(__file__), "prompts") prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) + async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any, baseline: str = "TIME", - memory_verify: bool = False,quality_assessment:bool = False, - statement_databasets: List[str] = [],language_type:str = "zh") -> str: + memory_verify: bool = False, quality_assessment: bool = False, + statement_databasets=None, language_type: str = "zh") -> str: """ Renders the evaluate prompt using the evaluate_optimized.jinja2 template. @@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any, Returns: Rendered prompt content as string """ + if statement_databasets is None: + statement_databasets = [] template = prompt_env.get_template("evaluate.jinja2") # Convert Pydantic model to JSON schema if needed @@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any, async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False, - statement_databasets: List[str] = [],language_type:str = "zh") -> str: + statement_databasets=None, language_type: str = "zh") -> str: """ Renders the reflexion prompt using the reflexion_optimized.jinja2 template. @@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s Returns: Rendered prompt content as a string. """ + if statement_databasets is None: + statement_databasets = [] template = prompt_env.get_template("reflexion.jinja2") # Convert Pydantic model to JSON schema if needed @@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s json_schema = schema rendered_prompt = template.render(data=data, json_schema=json_schema, - baseline=baseline,memory_verify=memory_verify, - statement_databasets=statement_databasets,language_type=language_type) + baseline=baseline, memory_verify=memory_verify, + statement_databasets=statement_databasets, language_type=language_type) return rendered_prompt diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index dba6717d..4a453c6b 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -1,23 +1,19 @@ from __future__ import annotations -import asyncio import os -import time -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, TypeVar +from typing import Any, Dict, Optional, TypeVar + +from langchain_aws import ChatBedrock +from langchain_community.chat_models import ChatTongyi +from langchain_core.embeddings import Embeddings +from langchain_core.language_models import BaseLLM +from langchain_ollama import OllamaLLM +from langchain_openai import ChatOpenAI, OpenAI +from pydantic import BaseModel, Field -import httpx from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.models.models_model import ModelProvider, ModelType -from langchain_community.document_compressors import JinaRerank -from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.embeddings import Embeddings -from langchain_core.language_models import BaseLanguageModel, BaseLLM -from langchain_core.outputs import Generation, LLMResult -from langchain_core.retrievers import BaseRetriever -from langchain_core.runnables import RunnableSerializable -from pydantic import BaseModel, Field T = TypeVar("T") @@ -163,25 +159,17 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy # dashscope 的 omni 模型使用 OpenAI 兼容模式 if provider == ModelProvider.DASHSCOPE and config.is_omni: - from langchain_openai import ChatOpenAI return ChatOpenAI - - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: if type == ModelType.LLM: - from langchain_openai import OpenAI return OpenAI elif type == ModelType.CHAT: - from langchain_openai import ChatOpenAI return ChatOpenAI elif provider == ModelProvider.DASHSCOPE: - from langchain_community.chat_models import ChatTongyi return ChatTongyi elif provider == ModelProvider.OLLAMA: - from langchain_ollama import OllamaLLM return OllamaLLM elif provider == ModelProvider.BEDROCK: - from langchain_aws import ChatBedrock, ChatBedrockConverse - return ChatBedrock else: raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 39c7887b..0e3fecee 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -16,6 +16,7 @@ from app.core.workflow.variable.base_variable import VariableType, FileObject from app.db import get_db_read from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy from app.schemas import FileInput +from app.schemas.model_schema import ModelInfo from app.services.multimodal_service import MultimodalService logger = logging.getLogger(__name__) @@ -620,11 +621,12 @@ class BaseNode(ABC): @staticmethod async def process_message( - provider: str, - is_omni: bool, + api_config: ModelInfo, content: str | dict | FileObject, + end_user_id: str, enable_file=False ) -> list | str | None: + provider = api_config.provider if isinstance(content, dict): content = FileObject( type=content.get("type"), @@ -643,7 +645,7 @@ class BaseNode(ABC): if content.content_cache.get(provider): return content.content_cache[provider] with get_db_read() as db: - multimodel_service = MultimodalService(db, provider, is_omni=is_omni) + multimodel_service = MultimodalService(db, api_config=api_config) file_obj = FileInput( type=content.type, url=content.url, @@ -653,7 +655,8 @@ class BaseNode(ABC): ) file_obj.set_content(content.get_content()) message = await multimodel_service.process_files( - [file_obj] + end_user_id, + [file_obj], ) content.set_content(file_obj.get_content()) if message: diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 29f7085b..7e98efab 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -5,7 +5,7 @@ from typing import Any from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode -from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator +from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance from app.core.workflow.variable.base_variable import VariableType @@ -23,6 +23,26 @@ class IfElseNode(BaseNode): "output": VariableType.STRING } + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: + result = [] + for case in self.typed_config.cases: + expressions = [] + for expression in case.expressions: + expressions.append({ + "left": self.get_variable(expression.left, variable_pool, strict=False), + "right": expression.right + if expression.input_type == ValueInputType.CONSTANT + else self.get_variable(expression.right, variable_pool, strict=False), + "operator": expression.operator, + }) + result.append({ + "expressions": expressions, + "logical_operator": case.logical_operator, + }) + return { + "cases": result + } + @staticmethod def _evaluate(operator, instance: CompareOperatorInstance) -> Any: match operator: diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 696298eb..14f789a9 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -30,6 +30,12 @@ class KnowledgeRetrievalNode(BaseNode): "output": VariableType.ARRAY_STRING } + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: + return { + "query": self._render_template(self.typed_config.query, variable_pool), + "knowledge_bases": [kb_config.model_dump(mode="json") for kb_config in self.typed_config.knowledge_bases], + } + @staticmethod def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType): """ diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 186c204f..b293d1f4 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -20,6 +20,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig from app.core.workflow.variable.base_variable import VariableType from app.db import get_db_context from app.models import ModelType +from app.schemas.model_schema import ModelInfo from app.services.model_service import ModelConfigService logger = logging.getLogger(__name__) @@ -113,12 +114,15 @@ class LLMNode(BaseNode): # 在 Session 关闭前提取所有需要的数据 api_config = self.model_balance(config) - model_name = api_config.model_name - provider = api_config.provider - api_key = api_config.api_key - api_base = api_config.api_base - is_omni = api_config.is_omni - model_type = config.type + model_info = ModelInfo( + model_name=api_config.model_name, + model_type=ModelType(config.type), + api_key=api_config.api_key, + api_base=api_config.api_base, + provider=api_config.provider, + is_omni=api_config.is_omni, + capability=api_config.capability + ) # 4. 创建 LLM 实例(使用已提取的数据) # 注意:对于流式输出,需要在模型初始化时设置 streaming=True @@ -126,17 +130,18 @@ class LLMNode(BaseNode): llm = RedBearLLM( RedBearModelConfig( - model_name=model_name, - provider=provider, - api_key=api_key, - base_url=api_base, + model_name=model_info.model_name, + provider=model_info.provider, + api_key=model_info.api_key, + base_url=model_info.api_base, extra_params=extra_params, - is_omni=is_omni + is_omni=model_info.is_omni ), - type=ModelType(model_type) + type=model_info.model_type ) - logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") + logger.debug( + f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}") messages_config = self.typed_config.messages @@ -148,35 +153,40 @@ class LLMNode(BaseNode): content_template = msg_config.content content_template = self._render_context(content_template, variable_pool) content = self._render_template(content_template, variable_pool) - + user_id = self.get_variable("sys.user_id", variable_pool) # 根据角色创建对应的消息对象 if role == "system": messages.append({ "role": "system", - "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) + "content": await self.process_message( + model_info, + content, + user_id, + self.typed_config.vision, + ) }) elif role in ["user", "human"]: messages.append({ "role": "user", - "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) + "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) }) elif role in ["ai", "assistant"]: messages.append({ "role": "assistant", - "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) + "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) }) else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append({ "role": "user", - "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) + "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) }) if self.typed_config.vision_input and self.typed_config.vision: file_content = [] files = variable_pool.get_instance(self.typed_config.vision_input) for file in files.value: - content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision) + content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision) if content: file_content.extend(content) if messages and messages[-1]["role"] == 'user': @@ -190,14 +200,19 @@ class LLMNode(BaseNode): if isinstance(message["content"], list): file_content = [] for file in message["content"]: - content = await self.process_message(provider, is_omni, file, self.typed_config.vision) + content = await self.process_message(model_info, file, user_id, self.typed_config.vision) if content: file_content.extend(content) history_message.append( {"role": message["role"], "content": file_content} ) else: - message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision) + message["content"] = await self.process_message( + model_info, + message["content"], + user_id, + self.typed_config.vision + ) history_message.append(message) messages = messages[:-1] + history_message + messages[-1:] self.messages = messages @@ -293,7 +308,7 @@ class LLMNode(BaseNode): # 调用 LLM(流式,支持字符串或消息列表) last_meta_data = {} - async for chunk in llm.astream(self.messages, stream_usage=True): + async for chunk in llm.astream(self.messages): # 提取内容 if hasattr(chunk, 'content'): content = self.process_model_output(chunk.content) diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 700ed85f..acac09e4 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -37,6 +37,14 @@ class ParameterExtractorNode(BaseNode): } return None + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: + return { + "text": self._render_template(self.typed_config.text, variable_pool), + "prompt": self._render_template(self.typed_config.prompt, variable_pool), + "params": [param.model_dump(mode="json") for param in self.typed_config.params], + "model_id": str(self.typed_config.model_id), + } + def _output_types(self) -> dict[str, VariableType]: outputs = {} for param in self.typed_config.params: diff --git a/api/app/models/memory_perceptual_model.py b/api/app/models/memory_perceptual_model.py index cafb18d4..9fed7c5d 100644 --- a/api/app/models/memory_perceptual_model.py +++ b/api/app/models/memory_perceptual_model.py @@ -7,6 +7,7 @@ from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import JSONB from app.db import Base +from app.schemas import FileType class PerceptualType(IntEnum): @@ -15,6 +16,16 @@ class PerceptualType(IntEnum): TEXT = 3 CONVERSATION = 4 + @staticmethod + def trans_from_file_type(file_type: FileType | str): + type_map = { + FileType.IMAGE: PerceptualType.VISION, + FileType.AUDIO: PerceptualType.AUDIO, + FileType.VIDEO: PerceptualType.VISION, + FileType.DOCUMENT: PerceptualType.TEXT + } + return type_map.get(file_type, PerceptualType.TEXT) + class FileStorageService(IntEnum): LOCAL = 1 diff --git a/api/app/repositories/memory_perceptual_repository.py b/api/app/repositories/memory_perceptual_repository.py index 9fa9536e..9077af03 100644 --- a/api/app/repositories/memory_perceptual_repository.py +++ b/api/app/repositories/memory_perceptual_repository.py @@ -2,7 +2,7 @@ import uuid from datetime import datetime from typing import List, Tuple, Optional -from sqlalchemy import and_, desc +from sqlalchemy import and_, desc, select from sqlalchemy.orm import Session from app.core.logging_config import get_db_logger @@ -127,6 +127,17 @@ class MemoryPerceptualRepository: db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}") raise + def get_by_url( + self, + file_url: str + ) -> list[MemoryPerceptualModel]: + try: + stmt = select(MemoryPerceptualModel).where(MemoryPerceptualModel.file_path == file_url) + return list(self.db.execute(stmt).scalars()) + except Exception: + db_logger.error(f"Failed to query perceptual memories by file_url: file_url={file_url}") + raise + def get_by_type( self, end_user_id: uuid.UUID, diff --git a/api/app/schemas/memory_perceptual_schema.py b/api/app/schemas/memory_perceptual_schema.py index 7dfefe01..c9b741ef 100644 --- a/api/app/schemas/memory_perceptual_schema.py +++ b/api/app/schemas/memory_perceptual_schema.py @@ -1,5 +1,4 @@ import uuid -from datetime import datetime from typing import Optional from pydantic import BaseModel, Field @@ -85,7 +84,6 @@ class Semantic(BaseModel): class Content(BaseModel): - summary: str keywords: list[str] topic: str domain: str diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index 4f3878ce..058f082d 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -326,3 +326,14 @@ class ModelBaseQuery(BaseModel): is_official: Optional[bool] = Field(None, description="是否官方模型") is_deprecated: Optional[bool] = Field(None, description="是否弃用") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) + +class ModelInfo(BaseModel): + """模型信息Schema""" + model_name: str = Field(..., description="模型名称") + provider: str = Field(..., description="模型提供商") + api_key: str = Field(..., description="API密钥") + api_base: str = Field(..., description="API基础URL") + is_omni: bool = Field(default=False, description="是否为omni模型") + model_type: ModelType = Field(..., description="模型类型") + capability: List[str] = Field(default_factory=list, description="模型能力列表") + diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index f3cdde2a..9b2b2a77 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -8,25 +8,21 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List from fastapi import Depends from sqlalchemy.orm import Session -from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.langchain_agent import LangChainAgent -from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.db import get_db -from app.models import MultiAgentConfig, AgentConfig +from app.models import MultiAgentConfig, AgentConfig, ModelType from app.models import WorkflowConfig from app.repositories.tool_repository import ToolRepository from app.schemas import DraftRunRequest from app.schemas.app_schema import FileInput +from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.services.conversation_service import ConversationService -from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \ - AgentRunService -from app.services.draft_run_service import create_web_search_tool +from app.services.draft_run_service import AgentRunService from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multimodal_service import MultimodalService -from app.services.tool_service import ToolService from app.services.workflow_service import WorkflowService logger = get_business_logger() @@ -126,8 +122,17 @@ class AppChatService: # 处理多模态文件 processed_files = None if files: - multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni) - processed_files = await multimodal_service.process_files(files) + model_info = ModelInfo( + model_name=api_key_obj.model_name, + provider=api_key_obj.provider, + api_key=api_key_obj.api_key, + api_base=api_key_obj.api_base, + capability=api_key_obj.capability, + is_omni=api_key_obj.is_omni, + model_type=ModelType.LLM + ) + multimodal_service = MultimodalService(self.db, model_info) + processed_files = await multimodal_service.process_files(user_id, files) logger.info(f"处理了 {len(processed_files)} 个文件") # 调用 Agent(支持多模态) @@ -266,8 +271,17 @@ class AppChatService: # 处理多模态文件 processed_files = None if files: - multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni) - processed_files = await multimodal_service.process_files(files) + model_info = ModelInfo( + model_name=api_key_obj.model_name, + provider=api_key_obj.provider, + api_key=api_key_obj.api_key, + api_base=api_key_obj.api_base, + capability=api_key_obj.capability, + is_omni=api_key_obj.is_omni, + model_type=ModelType.LLM + ) + multimodal_service = MultimodalService(self.db, model_info) + processed_files = await multimodal_service.process_files(user_id, files) logger.info(f"处理了 {len(processed_files)} 个文件") # 流式调用 Agent(支持多模态) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index d7914db5..b3b136a1 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -23,9 +23,10 @@ from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.rag.nlp.search import knowledge_retrieval from app.db import get_db_context -from app.models import AgentConfig, ModelConfig +from app.models import AgentConfig, ModelConfig, ModelType from app.repositories.tool_repository import ToolRepository from app.schemas.app_schema import FileInput +from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services import task_service from app.services.conversation_service import ConversationService @@ -501,9 +502,18 @@ class AgentRunService: processed_files = None if files: # 获取 provider 信息 + model_info = ModelInfo( + model_name=api_key_config["model_name"], + provider=api_key_config["provider"], + api_key=api_key_config["api_key"], + api_base=api_key_config["api_base"], + capability=api_key_config["capability"], + is_omni=api_key_config["is_omni"], + model_type=ModelType.LLM + ) provider = api_key_config.get("provider", "openai") - multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False)) - processed_files = await multimodal_service.process_files(files) + multimodal_service = MultimodalService(self.db, model_info) + processed_files = await multimodal_service.process_files(user_id, files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 @@ -704,9 +714,18 @@ class AgentRunService: processed_files = None if files: # 获取 provider 信息 + model_info = ModelInfo( + model_name=api_key_config["model_name"], + provider=api_key_config["provider"], + api_key=api_key_config["api_key"], + api_base=api_key_config["api_base"], + capability=api_key_config["capability"], + is_omni=api_key_config["is_omni"], + model_type=ModelType.LLM + ) provider = api_key_config.get("provider", "openai") - multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False)) - processed_files = await multimodal_service.process_files(files) + multimodal_service = MultimodalService(self.db, model_info) + processed_files = await multimodal_service.process_files(user_id, files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 @@ -841,7 +860,8 @@ class AgentRunService: "api_key": api_key.api_key, "api_base": api_key.api_base, "api_key_id": api_key.id, - "is_omni": api_key.is_omni + "is_omni": api_key.is_omni, + "capability": api_key.capability } async def _ensure_conversation( diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index a20b968a..1e1d9e45 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -274,7 +274,7 @@ class MemoryAgentService: Args: end_user_id: Group identifier (also used as end_user_id) - message: Message to write + messages: Message to write config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index b9d96a0b..53d935fe 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -1,19 +1,27 @@ +import os import uuid from typing import Dict, Any, Optional +from urllib.parse import urlparse, unquote +import json_repair +from jinja2 import Template from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger +from app.core.models import RedBearLLM, RedBearModelConfig from app.models.memory_perceptual_model import PerceptualType, FileStorageService +from app.models.prompt_optimizer_model import RoleType from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository +from app.schemas import FileType from app.schemas.memory_perceptual_schema import ( PerceptualQuerySchema, PerceptualTimelineResponse, PerceptualMemoryItem, AudioModal, Content, VideoModal, TextModal ) +from app.schemas.model_schema import ModelInfo business_logger = get_business_logger() @@ -99,7 +107,7 @@ class MemoryPerceptualService: "keywords": content.keywords, "topic": content.topic, "domain": content.domain, - "created_time": int(memory.created_time.timestamp()*1000), + "created_time": int(memory.created_time.timestamp() * 1000), **detail } @@ -108,7 +116,8 @@ class MemoryPerceptualService: return result except Exception as e: - business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}") + business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}", + exc_info=True) raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}", BizCode.DB_ERROR) @@ -138,7 +147,7 @@ class MemoryPerceptualService: for memory in memories: meta_data = memory.meta_data or {} content = meta_data.get("content", {}) - + # 安全地提取 content 字段,提供默认值 if content: content_obj = Content(**content) @@ -149,7 +158,7 @@ class MemoryPerceptualService: topic = "Unknown" domain = "Unknown" keywords = [] - + memory_item = PerceptualMemoryItem( id=memory.id, perceptual_type=PerceptualType(memory.perceptual_type), @@ -161,7 +170,7 @@ class MemoryPerceptualService: topic=topic, domain=domain, keywords=keywords, - created_time=int(memory.created_time.timestamp()*1000), + created_time=int(memory.created_time.timestamp() * 1000), storage_service=FileStorageService(memory.storage_service), ) memory_items.append(memory_item) @@ -183,3 +192,98 @@ class MemoryPerceptualService: except Exception as e: business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}") raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR) + + async def generate_perceptual_memory( + self, + end_user_id: str, + model_config: ModelInfo, + file_type: str, + file_url: str, + file_message: dict, + ): + memories = self.repository.get_by_url(file_url) + if memories: + business_logger.info(f"Perceptual memory already exists: {file_url}") + if end_user_id not in [memory.end_user_id for memory in memories]: + business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}") + memory_cache = memories[0] + self.repository.create_perceptual_memory( + end_user_id=uuid.UUID(end_user_id), + perceptual_type=PerceptualType(memory_cache.perceptual_type), + file_path=memory_cache.file_path, + file_name=memory_cache.file_name, + file_ext=memory_cache.file_ext, + summary=memory_cache.summary, + meta_data=memory_cache.meta_data + ) + self.db.commit() + + return + llm = RedBearLLM(RedBearModelConfig( + model_name=model_config.model_name, + provider=model_config.provider, + api_key=model_config.api_key, + base_url=model_config.api_base, + is_omni=model_config.is_omni + ), type=model_config.model_type) + try: + prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') + with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: + opt_system_prompt = f.read() + rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh') + except FileNotFoundError: + raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) + messages = [ + {"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]}, + {"role": RoleType.USER.value, "content": [ + {"type": "text", "text": "Summarize the following file"}, file_message + ]} + ] + result = await llm.ainvoke(messages) + content = json_repair.repair_json(result.content, return_objects=True) + path = urlparse(file_url).path + filename = os.path.basename(path) + filename = unquote(filename) + file_ext = os.path.splitext(filename)[1] + if not file_ext: + if file_type == FileType.AUDIO: + file_ext = ".mp3" + elif file_type == FileType.VIDEO: + file_ext = ".mp4" + elif file_type == FileType.DOCUMENT: + file_ext = ".txt" + elif file_type == FileType.IMAGE: + file_ext = ".jpg" + filename += file_ext + file_content = { + "keywords": content.get("keywords", []), + "topic": content.get("topic"), + "domain": content.get("domain") + } + if file_type in [FileType.IMAGE, FileType.VIDEO]: + file_modalities = { + "scene": content.get("scene") + } + elif file_type in [FileType.DOCUMENT]: + file_modalities = { + "section_count": content.get("section_count"), + "title": content.get("title"), + "first_line": content.get("first_line") + } + else: + file_modalities = { + "speaker_count": content.get("speaker_count") + } + self.repository.create_perceptual_memory( + end_user_id=uuid.UUID(end_user_id), + perceptual_type=PerceptualType.trans_from_file_type(file_type), + file_path=file_url, + file_name=filename, + file_ext=file_ext, + summary=content.get('summary'), + meta_data={ + "content": file_content, + "modalities": file_modalities + } + ) + self.db.commit() diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index fffca2e5..935efafe 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -10,6 +10,7 @@ """ import base64 import io +import uuid from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional @@ -23,9 +24,12 @@ from app.core.config import settings from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger +from app.models import ModelApiKey from app.models.file_metadata_model import FileMetadata from app.schemas.app_schema import FileInput, FileType, TransferMethod +from app.schemas.model_schema import ModelInfo from app.services.audio_transcription_service import AudioTranscriptionService +from app.tasks import write_perceptual_memory logger = get_business_logger() @@ -39,6 +43,7 @@ DOC_MIME = [ class MultimodalFormatStrategy(ABC): """多模态格式策略基类""" + def __init__(self, file: FileInput): self.file = file @@ -95,7 +100,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy): if transcription: return { "type": "text", - "text": f"" + "text": f"" } # 通义千问音频格式:{"type": "audio", "audio": "url"} return { @@ -284,34 +289,56 @@ PROVIDER_STRATEGIES = { class MultimodalService: - """多模态文件处理服务""" + """ + Service for handling multimodal file processing. - def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, - enable_audio_transcription: bool = False, is_omni: bool = False): + Attributes: + db (Session): Database session. + model_api_key (str): API key for the model provider. + provider (str): Name of the model provider. + is_omni (bool): Indicates whether the model supports full multimodal capability. + capability (list): Capability configuration of the model. + audio_api_key (str | None): API key used for audio transcription. + enable_audio_transcription (bool): Whether audio transcription is enabled. + """ + + def __init__( + self, + db: Session, + api_config: ModelInfo | None = None, + audio_api_key: Optional[str] = None, + enable_audio_transcription: bool = False, + ): """ - 初始化多模态服务 - + Initialize the multimodal service. + Args: - db: 数据库会话 - provider: 模型提供商(dashscope, bedrock, anthropic, openai 等) - api_key: API 密钥(用于音频转文本) - enable_audio_transcription: 是否启用音频转文本 - is_omni: 是否为 Omni 模型(dashscope 的 omni 模型需要使用 OpenAI 兼容格式) + db (Session): Database session. + api_config (ModelApiKey | None): Model API configuration. + audio_api_key (str | None): API key for audio transcription. + enable_audio_transcription (bool): Enable audio transcription. """ self.db = db - self.provider = provider.lower() - self.api_key = api_key + self.api_config = api_config + if self.api_config is not None: + self.model_api_key = api_config.api_key + self.provider = api_config.provider.lower() + self.is_omni = api_config.is_omni + self.capability = api_config.capability + self.audio_api_key = audio_api_key self.enable_audio_transcription = enable_audio_transcription - self.is_omni = is_omni async def process_files( self, - files: Optional[List[FileInput]] + end_user_id: uuid.UUID | str, + files: Optional[List[FileInput]], + ) -> List[Dict[str, Any]]: """ 处理文件列表,返回 LLM 可用的格式 Args: + end_user_id: 用户ID files: 文件输入列表 Returns: @@ -319,6 +346,8 @@ class MultimodalService: """ if not files: return [] + if isinstance(end_user_id, uuid.UUID): + end_user_id = str(end_user_id) # 获取对应的策略 # dashscope 的 omni 模型使用 OpenAI 兼容格式 @@ -333,19 +362,25 @@ class MultimodalService: result = [] for idx, file in enumerate(files): strategy = strategy_class(file) + if not file.url: + file.url = await self.get_file_url(file) try: - if file.type == FileType.IMAGE: + if file.type == FileType.IMAGE and "vision" in self.capability: content = await self._process_image(file, strategy) result.append(content) + self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.DOCUMENT: content = await self._process_document(file, strategy) result.append(content) - elif file.type == FileType.AUDIO: + self.write_perceptual_memory(end_user_id, file.type, file.url, content) + elif file.type == FileType.AUDIO and "audio" in self.capability: content = await self._process_audio(file, strategy) result.append(content) - elif file.type == FileType.VIDEO: + self.write_perceptual_memory(end_user_id, file.type, file.url, content) + elif file.type == FileType.VIDEO and "video" in self.capability: content = await self._process_video(file, strategy) result.append(content) + self.write_perceptual_memory(end_user_id, file.type, file.url, content) else: logger.warning(f"不支持的文件类型: {file.type}") except Exception as e: @@ -355,7 +390,8 @@ class MultimodalService: "file_index": idx, "file_type": file.type, "error": str(e) - } + }, + exc_info=True ) # 继续处理其他文件,不中断整个流程 result.append({ @@ -366,6 +402,17 @@ class MultimodalService: logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result + def write_perceptual_memory( + self, + end_user_id: str, + file_type: str, + file_url: str, + file_message: dict + ): + """写入感知记忆""" + if end_user_id and self.api_config: + write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message) + async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理图片文件 @@ -387,43 +434,6 @@ class MultimodalService: "text": f"[图片处理失败: {str(e)}]" } - @staticmethod - async def _download_and_encode_image(url: str) -> tuple[str, str]: - """ - 下载图片并转换为 base64 - - Args: - url: 图片 URL - - Returns: - tuple: (base64_data, media_type) - """ - from mimetypes import guess_type - - # 下载图片 - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(url) - response.raise_for_status() - - # 获取图片数据 - image_data = response.content - - # 确定 media type - content_type = response.headers.get("content-type") - if content_type and content_type.startswith("image/"): - media_type = content_type - else: - # 从 URL 推断 - guessed_type, _ = guess_type(url) - media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg" - - # 转换为 base64 - base64_data = base64.b64encode(image_data).decode("utf-8") - - logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") - - return base64_data, media_type - async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理文档文件(PDF、Word 等) @@ -436,7 +446,6 @@ class MultimodalService: Dict: 根据 provider 返回不同格式的文档内容 """ if file.transfer_method == TransferMethod.REMOTE_URL: - # 远程文档暂不支持提取 return { "type": "text", "text": f"\n{await self._extract_document_text(file)}\n" @@ -471,12 +480,12 @@ class MultimodalService: # 如果启用音频转文本且有 API Key transcription = None - if self.enable_audio_transcription and self.api_key: + if self.enable_audio_transcription and self.audio_api_key: logger.info(f"开始音频转文本: {url}") if self.provider == "dashscope": - transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key) + transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.audio_api_key) elif self.provider == "openai": - transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key) + transcription = await AudioTranscriptionService.transcribe_openai(url, self.audio_api_key) else: logger.warning(f"Provider {self.provider} 不支持音频转文本") diff --git a/api/app/services/prompt/perceptual_summary_system.jinja2 b/api/app/services/prompt/perceptual_summary_system.jinja2 new file mode 100644 index 00000000..ee5d3eb5 --- /dev/null +++ b/api/app/services/prompt/perceptual_summary_system.jinja2 @@ -0,0 +1,53 @@ +{% raw %}You are a professional information extraction system. + +Your task is to analyze the provided document content and generate structured metadata. + +Extract the following fields: + +* **summary**: A concise summary of the document in 2–4 sentences. +* **keywords**: 5–10 important keywords or key phrases that best represent the document. This field MUST be a JSON array of strings. +* **topic**: The primary topic of the document expressed as a short phrase (3–8 words). +* **domain**: The broader knowledge domain or field the document belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.). + +STRICT RULES: + +1. Output MUST be valid JSON. +2. Do NOT output markdown. +3. Do NOT output explanations. +4. Do NOT output any text before or after the JSON. +5. The JSON MUST contain EXACTLY these four keys: + * summary + * keywords + * topic + * domain{% endraw %} +{% if file_type == 'image' or file_type == 'video' %} * scene {% endif %} +{% if file_type == 'audio' %} * speaker_count {% endif %} +{% if file_type == 'document' %} * section_count + * title + * first_line +{% endif %} +{% raw %} +6. `keywords` MUST be a JSON array of strings. +7. If the document content is insufficient, infer the best possible answer based on context. +8. Ensure the JSON is syntactically correct. +{% endraw %} +9. Output using the language {{ language }} +{% raw %} +Required JSON format: + +{ +"summary": "string", +"keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"], +"topic": "string", +"domain": "string", +{% endraw %} +{% if file_type == 'image' or file_type == 'video' %} "scene": ["string", "string"] {% endif %} +{% if file_type == 'document' %} "section_count": integer +"title": "string", +"first_line": "string" +{% endif %} +{% if file_type == 'audio' %} "speaker_count": integer {% endif %} +{% raw %} +} + +Now analyze the following document and return the JSON result.{% endraw %} diff --git a/api/app/tasks.py b/api/app/tasks.py index 6fd9c954..5e1550bd 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,6 +1,5 @@ import asyncio -import json -import logging +import hashlib import os import re import shutil @@ -11,20 +10,48 @@ from datetime import datetime, timezone from math import ceil from pathlib import Path from typing import Any, Dict, List, Optional -from uuid import UUID import redis -import requests from redis.exceptions import RedisError -logger = logging.getLogger(__name__) +# Import a unified Celery instance +from app.celery_app import celery_app +from app.core.config import settings +from app.core.logging_config import get_logger +from app.core.rag.crawler.web_crawler import WebCrawler +from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb +from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache +from app.core.rag.integrations.feishu.client import FeishuAPIClient +from app.core.rag.integrations.feishu.models import FileInfo +from app.core.rag.integrations.yuque.client import YuqueAPIClient +from app.core.rag.integrations.yuque.models import YuqueDocInfo +from app.core.rag.llm.chat_model import Base +from app.core.rag.llm.cv_model import QWenCV +from app.core.rag.llm.embedding_model import OpenAIEmbed +from app.core.rag.llm.sequence2txt_model import QWenSeq2txt +from app.core.rag.models.chunk import DocumentChunk +from app.core.rag.prompts.generator import question_proposal +from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( + ElasticSearchVectorFactory, +) +from app.db import get_db, get_db_context +from app.models import Document, File, Knowledge +from app.schemas import document_schema, file_schema +from app.schemas.model_schema import ModelInfo +from app.services.memory_agent_service import MemoryAgentService +from app.services.memory_perceptual_service import MemoryPerceptualService +from app.utils.config_utils import resolve_config_id +from app.utils.redis_lock import RedisLock + +logger = get_logger(__name__) # 模块级同步 Redis 连接池,供 Celery 任务共享使用 # 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致 # 使用连接池而非单例客户端,提供更好的并发性能和自动重连 -_sync_redis_pool: redis.ConnectionPool = None +_sync_redis_pool: redis.ConnectionPool | None = None -def _get_or_create_redis_pool() -> redis.ConnectionPool: + +def _get_or_create_redis_pool() -> redis.ConnectionPool | None: """获取或创建 Redis 连接池(懒初始化)""" global _sync_redis_pool if _sync_redis_pool is None: @@ -47,6 +74,7 @@ def _get_or_create_redis_pool() -> redis.ConnectionPool: return None return _sync_redis_pool + def get_sync_redis_client() -> Optional[redis.StrictRedis]: """获取同步 Redis 客户端(使用连接池) @@ -60,7 +88,7 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]: pool = _get_or_create_redis_pool() if pool is None: return None - + client = redis.StrictRedis(connection_pool=pool) # 验证连接可用性 client.ping() @@ -72,32 +100,18 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]: logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True) return None -# Import a unified Celery instance -from app.celery_app import celery_app -from app.core.config import settings -from app.core.rag.crawler.web_crawler import WebCrawler -from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb -from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache -from app.core.rag.integrations.feishu.client import FeishuAPIClient -from app.core.rag.integrations.feishu.models import FileInfo -from app.core.rag.integrations.yuque.client import YuqueAPIClient -from app.core.rag.integrations.yuque.models import YuqueDocInfo -from app.core.rag.llm.chat_model import Base -from app.core.rag.llm.cv_model import QWenCV -from app.core.rag.llm.embedding_model import OpenAIEmbed -from app.core.rag.llm.sequence2txt_model import QWenSeq2txt -from app.core.rag.models.chunk import DocumentChunk -from app.core.rag.prompts.generator import question_proposal -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( - ElasticSearchVectorFactory, -) -from app.db import get_db, get_db_context -from app.models.document_model import Document -from app.models.file_model import File -from app.models.knowledge_model import Knowledge -from app.schemas import document_schema, file_schema -from app.services.memory_agent_service import MemoryAgentService -from app.utils.config_utils import resolve_config_id + +def set_asyncio_event_loop(): + """Set the asyncio event loop for the current thread.""" + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop @celery_app.task(name="tasks.process_item") @@ -294,9 +308,18 @@ def parse_document(file_path: str, document_id: uuid.UUID): vector_size = len(vts[0]) init_graphrag(task, vector_size) - async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service, - chat_model, embedding_model, callback, with_resolution: bool = True, - with_community: bool = True, ) -> dict: + async def _run( + row: dict, + document_ids: list[str], + language: str, + parser_config: dict, + vector_service, + chat_model, + embedding_model, + callback, + with_resolution: bool = True, + with_community: bool = True + ) -> dict: await trio.sleep(5) # Delay for 10 seconds nonlocal progress_msg # Declare the use of an external progress_msg variable result = await run_graphrag_for_kb( @@ -329,6 +352,7 @@ def parse_document(file_path: str, document_id: uuid.UUID): with_community=with_community, ) ) + try: with ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(sync_task) @@ -448,6 +472,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): with_community=with_community, ) ) + try: with ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(sync_task) @@ -1002,29 +1027,21 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s # Log but continue - will fail later with proper error pass - async def _run() -> str: + async def _run() -> dict: with get_db_context() as db: service = MemoryAgentService() - return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, - storage_type, user_rag_memory_id) + return await service.read_memory( + end_user_id, + message, + history, + search_switch, + actual_config_id, db, + storage_type, user_rag_memory_id + ) try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time @@ -1056,7 +1073,8 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, user_rag_memory_id: str, +def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, + user_rag_memory_id: str, language: str = "zh") -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. Args: @@ -1073,10 +1091,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s Raises: Exception on failure """ - from app.core.logging_config import get_logger - logger = get_logger(__name__) - logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id} (type: {type(config_id).__name__}), storage_type={storage_type}, language={language}") + logger.info( + f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, " + f"config_id={config_id} (type: {type(config_id).__name__}), " + f"storage_type={storage_type}, language={language}") start_time = time.time() # Convert config_id to UUID @@ -1086,13 +1105,14 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s try: with get_db_context() as db: actual_config_id = resolve_config_id(config_id, db) - print(100*'-') + print(100 * '-') print(actual_config_id) - print(100*'-') + print(100 * '-') logger.info( f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})") except (ValueError, AttributeError) as e: - logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") + logger.error( + f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") return { "status": "FAILURE", "error": f"Invalid config_id format: {config_id} - {str(e)}", @@ -1116,7 +1136,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s async def _run() -> str: with get_db_context() as db: logger.info( - f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") + f"[CELERY WRITE] Executing MemoryAgentService.write_memory " + f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id, language) @@ -1124,22 +1145,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s return result try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time @@ -1193,28 +1200,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s } -def reflection_engine() -> None: - """Empty function placeholder for timed background reflection. - - Intentionally left blank; replace with real reflection logic later. - """ - import asyncio - - from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion - - host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122") - asyncio.run(self_reflexion(host_id)) - - -@celery_app.task(name="app.core.memory.agent.reflection.timer") -def reflection_timer_task() -> None: - """Periodic Celery task that invokes reflection_engine. - - Raises an exception on failure. - """ - reflection_engine() - - # unused task # @celery_app.task(name="app.core.memory.agent.health.check_read_service") # def check_read_service_task() -> Dict[str, str]: @@ -1368,6 +1353,8 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: "workspace_id": workspace_id, "elapsed_time": elapsed_time, } + + @celery_app.task( name="app.tasks.write_all_workspaces_memory_task", bind=True, @@ -1391,15 +1378,12 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: start_time = time.time() async def _run() -> Dict[str, Any]: - from app.core.logging_config import get_api_logger from app.models.app_model import App from app.models.end_user_model import EndUser from app.models.workspace_model import Workspace from app.repositories.memory_increment_repository import write_memory_increment from app.services.memory_storage_service import search_all - api_logger = get_api_logger() - with get_db_context() as db: try: # 获取所有活跃的工作空间 @@ -1408,7 +1392,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: ).all() if not workspaces: - api_logger.warning("没有找到活跃的工作空间") + logger.warning("没有找到活跃的工作空间") return { "status": "SUCCESS", "message": "没有找到活跃的工作空间", @@ -1416,13 +1400,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: "workspace_results": [] } - api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量") + logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量") all_workspace_results = [] # 遍历每个工作空间 for workspace in workspaces: workspace_id = workspace.id - api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})") + logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})") try: # 1. 查询当前workspace下的所有app(仅未删除的) @@ -1447,7 +1431,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: "memory_increment_id": str(memory_increment.id), "created_at": memory_increment.created_at.isoformat(), }) - api_logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0") + logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0") continue # 2. 查询所有app下的end_user_id(去重) @@ -1472,7 +1456,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: }) except Exception as e: # 记录单个用户查询失败,但继续处理其他用户 - api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}") + logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}") end_user_details.append({ "end_user_id": str(end_user_id), "total": 0, @@ -1496,13 +1480,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: "created_at": memory_increment.created_at.isoformat(), }) - api_logger.info( + logger.info( f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}" ) except Exception as e: db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间 - api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}") + logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}") all_workspace_results.append({ "workspace_id": str(workspace_id), "workspace_name": workspace.name, @@ -1525,7 +1509,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: } except Exception as e: - api_logger.error(f"记忆增量统计任务执行失败: {str(e)}") + logger.error(f"记忆增量统计任务执行失败: {str(e)}") return { "status": "FAILURE", "error": str(e), @@ -1534,22 +1518,8 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time @@ -1597,11 +1567,9 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: start_time = time.time() async def _run() -> Dict[str, Any]: - from app.core.logging_config import get_logger from app.repositories.end_user_repository import EndUserRepository from app.services.user_memory_service import UserMemoryService - logger = get_logger(__name__) logger.info("开始执行记忆缓存重新生成定时任务") service = UserMemoryService() @@ -1734,22 +1702,8 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time @@ -1785,15 +1739,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]: start_time = time.time() async def _run() -> Dict[str, Any]: - from app.core.logging_config import get_api_logger from app.models.workspace_model import Workspace from app.services.memory_reflection_service import ( MemoryReflectionService, WorkspaceAppService, ) - api_logger = get_api_logger() - with get_db_context() as db: try: # 获取所有工作空间 @@ -1812,7 +1763,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: # 遍历每个工作空间 for workspace in workspaces: workspace_id = workspace.id - api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}") + logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}") try: reflection_service = MemoryReflectionService(db) @@ -1824,7 +1775,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: workspace_reflection_results = [] for data in result['apps_detailed_info']: - if data['memory_configs'] == []: + if not data['memory_configs']: continue releases = data['releases'] @@ -1835,7 +1786,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str( user['app_id']): # 调用反思服务 - api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") + logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") reflection_result = await reflection_service.start_reflection_from_data( config_data=config, @@ -1855,12 +1806,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]: "reflection_results": workspace_reflection_results }) - api_logger.info( + logger.info( f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") except Exception as e: db.rollback() # Rollback failed transaction to allow next query - api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") + logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") all_reflection_results.append({ "workspace_id": str(workspace_id), "error": str(e), @@ -1879,7 +1830,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: } except Exception as e: - api_logger.error(f"工作空间反思任务执行失败: {str(e)}") + logger.error(f"工作空间反思任务执行失败: {str(e)}") return { "status": "FAILURE", "error": str(e), @@ -1888,22 +1839,8 @@ def workspace_reflection_task(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time @@ -1944,18 +1881,16 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di start_time = time.time() async def _run() -> Dict[str, Any]: - from app.core.logging_config import get_api_logger from app.services.memory_forget_service import MemoryForgetService - api_logger = get_api_logger() - with get_db_context() as db: try: - api_logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}") + logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}") forget_service = MemoryForgetService() # 运行遗忘周期 + # FIXME: MemeoryForgetService report = await forget_service.trigger_forgetting( db=db, end_user_id=None, # 处理所有组 @@ -1964,7 +1899,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di duration = time.time() - start_time - api_logger.info( + logger.info( f"遗忘周期定时任务完成: " f"融合 {report['merged_count']} 对节点, " f"失败 {report['failed_count']} 对, " @@ -1980,7 +1915,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di except Exception as e: duration = time.time() - start_time - api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True) + logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True) return { "status": "FAILED", @@ -1997,6 +1932,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di finally: loop.close() + # ============================================================================= # Long-term Memory Storage Tasks (Batched Write Strategies) # ============================================================================= @@ -2222,9 +2158,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: start_time = time.time() async def _run() -> Dict[str, Any]: - from sqlalchemy import func, select + from sqlalchemy import select - from app.core.logging_config import get_logger from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage from app.repositories.implicit_emotions_storage_repository import ( ImplicitEmotionsStorageRepository, @@ -2233,7 +2168,6 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: from app.services.emotion_analytics_service import EmotionAnalyticsService from app.services.implicit_memory_service import ImplicitMemoryService - logger = get_logger(__name__) logger.info("开始执行隐性记忆和情绪数据更新定时任务") total_users = 0 @@ -2267,7 +2201,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: for end_user_id in refresh_iter: logger.info(f"开始处理用户: {end_user_id}") user_start_time = time.time() - + implicit_success = False emotion_success = False errors = [] @@ -2318,7 +2252,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: failed += 1 user_elapsed = time.time() - user_start_time - + # 记录用户处理结果 user_result = { "end_user_id": end_user_id, @@ -2460,22 +2394,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time @@ -2521,14 +2441,12 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, start_time = time.time() async def _run() -> Dict[str, Any]: - from app.core.logging_config import get_logger from app.repositories.implicit_emotions_storage_repository import ( ImplicitEmotionsStorageRepository, ) from app.services.emotion_analytics_service import EmotionAnalyticsService from app.services.implicit_memory_service import ImplicitMemoryService - logger = get_logger(__name__) logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}") initialized = 0 @@ -2587,20 +2505,7 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, } try: - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) result["elapsed_time"] = time.time() - start_time @@ -2633,6 +2538,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ 默认生成中文(zh)兴趣分布数据。 Args: + self: task object end_user_ids: 需要检查的用户ID列表 Returns: @@ -2641,11 +2547,9 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ start_time = time.time() async def _run() -> Dict[str, Any]: - from app.core.logging_config import get_logger from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE from app.services.memory_agent_service import MemoryAgentService - logger = get_logger(__name__) logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}") initialized = 0 @@ -2694,20 +2598,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ } try: - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) result["elapsed_time"] = time.time() - start_time @@ -2720,3 +2611,54 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ "elapsed_time": time.time() - start_time, "task_id": self.request.id, } + + +@celery_app.task( + name="app.tasks.write_perceptual_memory", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=3600, + soft_time_limit=3300, +) +def write_perceptual_memory( + self, + end_user_id: str, + model_api_config: dict, + file_type: str, + file_url: str, + file_message: dict +): + """ + Write perceptual memory for a user into PostgreSQL and Neo4j. + + This task generates or updates the user's perceptual memory + in the backend databases. It is intended to be executed asynchronously + via Celery. + + Args: + end_user_id (uuid.UUID): The unique identifier of the end user. + model_api_config (ModelInfo): API configuration for the model + used to generate perceptual memory. + file_type (str): The file type + file_url (url): The url of file + file_message (dict): The file message containing details about the file + to be processed. + + Returns: + None + """ + file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest() + set_asyncio_event_loop() + with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()): + model_info = ModelInfo(**model_api_config) + with get_db_context() as db: + memory_perceptual_service = MemoryPerceptualService(db) + return asyncio.run(memory_perceptual_service.generate_perceptual_memory( + end_user_id, + model_info, + file_type, + file_url, + file_message, + )) diff --git a/api/app/utils/redis_lock.py b/api/app/utils/redis_lock.py new file mode 100644 index 00000000..99f62d84 --- /dev/null +++ b/api/app/utils/redis_lock.py @@ -0,0 +1,61 @@ +import redis +import uuid +import time + +UNLOCK_SCRIPT = """ +if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("del", KEYS[1]) +else + return 0 +end +""" + + +class RedisLock: + def __init__( + self, + key: str, + redis_client: redis.StrictRedis, + expire: int = 60, + retry_interval: float = 0.1, + timeout: float = 30 + + ): + self.key = key + self.expire = expire + self.value = str(uuid.uuid4()) + self._locked = False + self.retry_interval = retry_interval + self.timeout = timeout + self.redis_client = redis_client + + def acquire(self) -> bool: + start = time.time() + while True: + ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True) + if ok: + self._locked = True + return True + if time.time() - start >= self.timeout: + return False + time.sleep(self.retry_interval) + + def release(self): + if not self._locked: + return + self.redis_client.eval( + UNLOCK_SCRIPT, + 1, + self.key, + self.value + ) + self._locked = False + + def __enter__(self): + ok = self.acquire() + if not ok: + raise RuntimeError(f"Get redis lock timeout: {self.key}") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.release()