Merge branch 'develop' into release/v0.2.7

This commit is contained in:
Ke Sun
2026-03-16 15:47:00 +08:00
committed by GitHub
155 changed files with 13164 additions and 1796 deletions

View File

@@ -162,6 +162,44 @@ class Settings:
# This controls the language used for memory summary titles and other generated content
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
# ========================================================================
# Internationalization (i18n) Configuration
# ========================================================================
# Default language for API responses
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
# Supported languages (comma-separated)
I18N_SUPPORTED_LANGUAGES: list[str] = [
lang.strip()
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
if lang.strip()
]
# Core locales directory (community edition)
# Use absolute path to work from any working directory
I18N_CORE_LOCALES_DIR: str = os.getenv(
"I18N_CORE_LOCALES_DIR",
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
)
# Premium locales directory (enterprise edition, optional)
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
# Enable translation cache
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
# LRU cache size for hot translations
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
# Enable hot reload of translation files
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
# Fallback language when translation is missing
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
# Log missing translations
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
# Logging settings
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

View File

@@ -2,15 +2,37 @@ from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
def content_input_node(state: ReadState) -> ReadState:
"""开始节点 - 提取内容并保持状态信息"""
"""
Start node - Extract content and maintain state information
Extracts the content from the first message in the state and returns it
as the data field while preserving all other state information.
Args:
state: ReadState containing messages and other state data
Returns:
ReadState: Updated state with extracted content in data field
"""
content = state['messages'][0].content if state.get('messages') else ''
# 返回内容并保持所有状态信息
# Return content and maintain all state information
return {"data": content}
def content_input_write(state: WriteState) -> WriteState:
"""开始节点 - 提取内容并保持状态信息"""
"""
Start node - Extract content and maintain state information for write operations
Extracts the content from the first message in the state for write operations.
Args:
state: WriteState containing messages and other state data
Returns:
WriteState: Updated state with extracted content in data field
"""
content = state['messages'][0].content if state.get('messages') else ''
# 返回内容并保持所有状态信息
return {"data": content}
# Return content and maintain all state information
return {"data": content}

View File

@@ -19,19 +19,39 @@ logger = get_agent_logger(__name__)
class ProblemNodeService(LLMServiceMixin):
"""问题处理节点服务类"""
"""
Problem processing node service class
Handles problem decomposition and extension operations using LLM services.
Inherits from LLMServiceMixin to provide structured LLM calling capabilities.
Attributes:
template_service: Service for rendering Jinja2 templates
"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
# Create global service instance
problem_service = ProblemNodeService()
async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点"""
"""
Problem decomposition node
Breaks down complex user queries into smaller, more manageable sub-problems.
Uses LLM to analyze the input and generate structured problem decomposition
with question types and reasoning.
Args:
state: ReadState containing user input and configuration
Returns:
ReadState: Updated state with problem decomposition results
"""
# 从状态中获取数据
content = state.get('data', '')
end_user_id = state.get('end_user_id', '')
@@ -64,7 +84,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
# 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
# 验证结构化响应
# Validate structured response
if not structured or not hasattr(structured, 'root'):
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
split_result = json.dumps([], ensure_ascii=False)
@@ -106,7 +126,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
exc_info=True
)
# 提供更详细的错误信息
# Provide more detailed error information
error_details = {
"error_type": type(e).__name__,
"error_message": str(e),
@@ -116,7 +136,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
logger.error(f"Split_The_Problem error details: {error_details}")
# 创建默认的空结果
# Create default empty result
result = {
"context": json.dumps([], ensure_ascii=False),
"original": content,
@@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
}
}
# 返回更新后的状态,包含spit_context字段
# Return updated state including spit_context field
return {"spit_data": result}
async def Problem_Extension(state: ReadState) -> ReadState:
"""问题扩展节点"""
# 获取原始数据和分解结果
"""
Problem extension node
Extends the decomposed problems from Split_The_Problem node by generating
additional related questions and organizing them by original question.
Uses LLM to create comprehensive question extensions for better memory retrieval.
Args:
state: ReadState containing decomposed problems and configuration
Returns:
ReadState: Updated state with extended problem results
"""
# Get original data and decomposition results
start = time.time()
content = state.get('data', '')
data = state.get('spit_data', '')['context']
@@ -182,7 +214,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
# 验证结构化响应
# Validate structured response
if not response_content or not hasattr(response_content, 'root'):
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
aggregated_dict = {}
@@ -216,7 +248,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
exc_info=True
)
# 提供更详细的错误信息
# Provide more detailed error information
error_details = {
"error_type": type(e).__name__,
"error_message": str(e),

View File

@@ -29,6 +29,18 @@ logger = get_agent_logger(__name__)
async def rag_config(state):
"""
Configure RAG (Retrieval-Augmented Generation) settings
Creates configuration for knowledge base retrieval including similarity thresholds,
weights, and reranker settings.
Args:
state: Current state containing user_rag_memory_id
Returns:
dict: RAG configuration dictionary
"""
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
"knowledge_bases": [
@@ -48,6 +60,19 @@ async def rag_config(state):
async def rag_knowledge(state, question):
"""
Retrieve knowledge using RAG approach
Performs knowledge retrieval from configured knowledge bases using the
provided question and returns formatted results.
Args:
state: Current state containing configuration
question: Question to search for
Returns:
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
"""
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
@@ -68,12 +93,24 @@ async def rag_knowledge(state, question):
async def llm_infomation(state: ReadState) -> ReadState:
"""
Get LLM configuration information from state
Retrieves model configuration details including model ID and tenant ID
from the memory configuration in the current state.
Args:
state: ReadState containing memory configuration
Returns:
ReadState: Model configuration as Pydantic model
"""
memory_config = state.get('memory_config', None)
model_id = memory_config.llm_model_id
tenant_id = memory_config.tenant_id
# 使用现有的 memory_config 而不是重新查询数据库
# 或者使用线程安全的数据库访问
# Use existing memory_config instead of re-querying database
# or use thread-safe database access
with get_db_context() as db:
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
@@ -82,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState:
async def clean_databases(data) -> str:
"""
简化的数据库搜索结果清理函数
Simplified database search result cleaning function
Processes and cleans search results from various sources including
reranked results and time-based search results. Extracts text content
from structured data and returns as formatted string.
Args:
data: 搜索结果数据
data: Search result data (can be string, dict, or other types)
Returns:
清理后的内容字符串
str: Cleaned content string
"""
try:
# 解析JSON字符串
# Parse JSON string
if isinstance(data, str):
try:
data = json.loads(data)
@@ -101,24 +142,24 @@ async def clean_databases(data) -> str:
if not isinstance(data, dict):
return str(data)
# 获取结果数据
# Get result data
# with open("搜索结果.json","w",encoding='utf-8') as f:
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
results = data.get('results', data)
if not isinstance(results, dict):
return str(results)
# 收集所有内容
# Collect all content
content_list = []
# 处理重排序结果
# Process reranked results
reranked = results.get('reranked_results', {})
if reranked:
for category in ['summaries', 'statements', 'chunks', 'entities']:
items = reranked.get(category, [])
if isinstance(items, list):
content_list.extend(items)
# 处理时间搜索结果
# Process time search results
time_search = results.get('time_search', {})
if time_search:
if isinstance(time_search, dict):
@@ -128,7 +169,7 @@ async def clean_databases(data) -> str:
elif isinstance(time_search, list):
content_list.extend(time_search)
# 提取文本内容
# Extract text content
text_parts = []
for item in content_list:
if isinstance(item, dict):
@@ -146,10 +187,19 @@ async def clean_databases(data) -> str:
async def retrieve_nodes(state: ReadState) -> ReadState:
'''
模型信息
'''
"""
Retrieve information using simplified search approach
Processes extended problems from previous nodes and performs retrieval
using either RAG or hybrid search based on storage type. Handles concurrent
processing of multiple questions and deduplicates results.
Args:
state: ReadState containing problem extensions and configuration
Returns:
ReadState: Updated state with retrieval results and intermediate outputs
"""
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
@@ -163,7 +213,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# 创建异步任务处理单个问题
# Create async task to process individual questions
async def process_question_nodes(idx, question):
try:
# Prepare search parameters based on storage type
@@ -209,7 +259,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
}
}
# 并发处理所有问题
# Process all questions concurrently
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
databases_anser = await asyncio.gather(*tasks)
databases_data = {
@@ -257,7 +307,20 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
async def retrieve(state: ReadState) -> ReadState:
# 从state中获取end_user_id
"""
Advanced retrieve function using LangChain agents and tools
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
to perform sophisticated information retrieval. Supports both RAG and traditional
memory storage approaches with concurrent processing and result deduplication.
Args:
state: ReadState containing problem extensions and configuration
Returns:
ReadState: Updated state with retrieval results and intermediate outputs
"""
# Get end_user_id from state
import time
start = time.time()
problem_extension = state.get('problem_extension', '')['context']
@@ -299,21 +362,21 @@ async def retrieve(state: ReadState) -> ReadState:
system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
)
# 创建异步任务处理单个问题
# Create async task to process individual questions
import asyncio
# 在模块级别定义信号量,限制最大并发数
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
# Define semaphore at module level to limit maximum concurrency
SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
async def process_question(idx, question):
async with SEMAPHORE: # 限制并发
async with SEMAPHORE: # Limit concurrency
try:
if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
question)
else:
cleaned_query = question
# 使用 asyncio 在线程池中运行同步的 agent.invoke
# Use asyncio to run synchronous agent.invoke in thread pool
import asyncio
response = await asyncio.get_event_loop().run_in_executor(
None,
@@ -362,7 +425,7 @@ async def retrieve(state: ReadState) -> ReadState:
}
}
# 并发处理所有问题
# Process all questions concurrently
import asyncio
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
databases_anser = await asyncio.gather(*tasks)

View File

@@ -23,18 +23,39 @@ logger = get_agent_logger(__name__)
class SummaryNodeService(LLMServiceMixin):
"""总结节点服务类"""
"""
Summary node service class
Handles summary generation operations using LLM services. Inherits from
LLMServiceMixin to provide structured LLM calling capabilities for
generating summaries from retrieved information.
Attributes:
template_service: Service for rendering Jinja2 templates
"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
# Create global service instance
summary_service = SummaryNodeService()
async def rag_config(state):
"""
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
Creates configuration for knowledge base retrieval including similarity thresholds,
weights, and reranker settings specifically for summary generation.
Args:
state: Current state containing user_rag_memory_id
Returns:
dict: RAG configuration dictionary with knowledge base settings
"""
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
"knowledge_bases": [
@@ -54,6 +75,23 @@ async def rag_config(state):
async def rag_knowledge(state, question):
"""
Retrieve knowledge using RAG approach for summary generation
Performs knowledge retrieval from configured knowledge bases using the
provided question and returns formatted results for summary processing.
Args:
state: Current state containing configuration
question: Question to search for in knowledge base
Returns:
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
- retrieval_knowledge: List of retrieved knowledge chunks
- clean_content: Formatted content string
- cleaned_query: Processed query string
- raw_results: Raw retrieval results
"""
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
@@ -74,6 +112,18 @@ async def rag_knowledge(state, question):
async def summary_history(state: ReadState) -> ReadState:
"""
Retrieve conversation history for summary context
Gets the conversation history for the current user to provide context
for summary generation operations.
Args:
state: ReadState containing end_user_id
Returns:
ReadState: Conversation history data
"""
end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history
@@ -82,11 +132,26 @@ async def summary_history(state: ReadState) -> ReadState:
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
search_mode) -> str:
"""
增强的summary_llm函数,包含更好的错误处理和数据验证
Enhanced summary_llm function with better error handling and data validation
Generates summaries using LLM with structured output. Includes fallback mechanisms
for handling LLM failures and provides robust error recovery.
Args:
state: ReadState containing current context
history: Conversation history for context
retrieve_info: Retrieved information to summarize
template_name: Jinja2 template name for prompt generation
operation_name: Type of operation (summary, input_summary, retrieve_summary)
response_model: Pydantic model for structured output
search_mode: Search mode flag ("0" for simple, "1" for complex)
Returns:
str: Generated summary text or fallback message
"""
data = state.get("data", '')
# 构建系统提示词
# Build system prompt
if str(search_mode) == "0":
system_prompt = await summary_service.template_service.render_template(
template_name=template_name,
@@ -103,7 +168,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
retrieve_info=retrieve_info
)
try:
# 使用优化的LLM服务进行结构化输出
# Use optimized LLM service for structured output
with get_db_context() as db_session:
structured = await summary_service.call_llm_structured(
state=state,
@@ -112,23 +177,23 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
response_model=response_model,
fallback_value=None
)
# 验证结构化响应
# Validate structured response
if structured is None:
logger.warning("LLM返回None使用默认回答")
return "信息不足,无法回答"
# 根据操作类型提取答案
# Extract answer based on operation type
if operation_name == "summary":
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
else:
# 处理RetrieveSummaryResponse
# Handle RetrieveSummaryResponse
if hasattr(structured, 'data') and structured.data:
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
else:
logger.warning("结构化响应缺少data字段")
aimessages = "信息不足,无法回答"
# 验证答案不为空
# Validate answer is not empty
if not aimessages or aimessages.strip() == "":
aimessages = "信息不足,无法回答"
@@ -137,7 +202,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
except Exception as e:
logger.error(f"结构化输出失败: {e}", exc_info=True)
# 尝试非结构化输出作为fallback
# Try unstructured output as fallback
try:
logger.info("尝试非结构化输出作为fallback")
response = await summary_service.call_llm_simple(
@@ -148,9 +213,9 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
)
if response and response.strip():
# 简单清理响应
# Simple response cleaning
cleaned_response = response.strip()
# 移除可能的JSON标记
# Remove possible JSON markers
if cleaned_response.startswith('```'):
lines = cleaned_response.split('\n')
cleaned_response = '\n'.join(lines[1:-1])
@@ -165,6 +230,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
"""
Save summary results to Redis session storage
Stores the generated summary and user query in Redis for session management
and conversation history tracking.
Args:
state: ReadState containing user and query information
aimessages: Generated summary message to save
Returns:
ReadState: Updated state after saving to Redis
"""
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session(
@@ -179,6 +257,20 @@ async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
"""
Format summary results for different output types
Creates structured output formats for both input summary and retrieval summary
operations, including metadata and intermediate results for frontend display.
Args:
state: ReadState containing storage and user information
aimessages: Generated summary message
raw_results: Raw search/retrieval results
Returns:
tuple: (input_summary, retrieve_summary) formatted result dictionaries
"""
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
@@ -217,6 +309,19 @@ async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState
async def Input_Summary(state: ReadState) -> ReadState:
"""
Generate quick input summary from retrieved information
Performs fast retrieval and generates a quick summary response for user queries.
This function prioritizes speed by only searching summary nodes and provides
immediate feedback to users.
Args:
state: ReadState containing user query, storage configuration, and context
Returns:
ReadState: Dictionary containing summary results with status and metadata
"""
start = time.time()
storage_type = state.get("storage_type", '')
memory_config = state.get('memory_config', None)
@@ -266,6 +371,19 @@ async def Input_Summary(state: ReadState) -> ReadState:
async def Retrieve_Summary(state: ReadState) -> ReadState:
"""
Generate comprehensive summary from retrieved expansion issues
Processes retrieved expansion issues and generates a detailed summary using LLM.
This function handles complex retrieval results and provides comprehensive answers
based on expanded query results.
Args:
state: ReadState containing retrieve data with expansion issues
Returns:
ReadState: Dictionary containing comprehensive summary results
"""
retrieve = state.get("retrieve", '')
history = await summary_history(state)
import json
@@ -299,13 +417,26 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
duration = 0.0
log_time('Retrieval summary', duration)
# 修复协程调用 - await,然后访问返回值
# Fixed coroutine call - await first, then access return value
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary": summary}
async def Summary(state: ReadState) -> ReadState:
"""
Generate final comprehensive summary from verified data
Creates the final summary using verified expansion issues and conversation history.
This function processes verified data to generate the most comprehensive and
accurate response to user queries.
Args:
state: ReadState containing verified data and query information
Returns:
ReadState: Dictionary containing final summary results
"""
start = time.time()
query = state.get("data", '')
verify = state.get("verify", '')
@@ -336,13 +467,26 @@ async def Summary(state: ReadState) -> ReadState:
duration = 0.0
log_time('Retrieval summary', duration)
# 修复协程调用 - await,然后访问返回值
# Fixed coroutine call - await first, then access return value
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary": summary}
async def Summary_fails(state: ReadState) -> ReadState:
"""
Generate fallback summary when normal summary process fails
Provides a fallback summary generation mechanism when the standard summary
process encounters errors or fails to produce satisfactory results. Uses
a specialized failure template to handle edge cases.
Args:
state: ReadState containing verified data and failure context
Returns:
ReadState: Dictionary containing fallback summary results
"""
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
history = await summary_history(state)

View File

@@ -18,24 +18,46 @@ logger = get_agent_logger(__name__)
class VerificationNodeService(LLMServiceMixin):
"""验证节点服务类"""
"""
Verification node service class
Handles data verification operations using LLM services. Inherits from
LLMServiceMixin to provide structured LLM calling capabilities for
verifying and validating retrieved information.
Attributes:
template_service: Service for rendering Jinja2 templates
"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
# Create global service instance
verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""处理验证结果并生成输出格式"""
"""
Process verification results and generate output format
Transforms VerificationResult objects into structured output format suitable
for frontend consumption. Handles conversion of VerificationItem objects to
dictionary format and adds metadata for tracking.
Args:
state: ReadState containing storage and user configuration
messages_deal: VerificationResult containing verification outcomes
Returns:
dict: Formatted verification result with status and metadata
"""
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
data = state.get('data', '')
# VerificationItem 对象转换为字典列表
# Convert VerificationItem objects to dictionary list
verified_data = []
if messages_deal.expansion_issue:
for item in messages_deal.expansion_issue:
@@ -89,7 +111,7 @@ async def Verify(state: ReadState):
logger.info("Verify: 开始渲染模板")
# 生成 JSON schema 以指导 LLM 输出正确格式
# Generate JSON schema to guide LLM output format
json_schema = VerificationResult.model_json_schema()
system_prompt = await verification_service.template_service.render_template(
@@ -104,8 +126,8 @@ async def Verify(state: ReadState):
# 使用优化的LLM服务添加超时保护
logger.info("Verify: 开始调用 LLM")
try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
# Add asyncio.wait_for timeout wrapper to prevent infinite waiting
# Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
with get_db_context() as db_session:
structured = await asyncio.wait_for(
@@ -122,7 +144,7 @@ async def Verify(state: ReadState):
"reason": "验证失败或超时"
}
),
timeout=150.0 # 150秒超时
timeout=150.0 # 150 second timeout
)
logger.info(f"Verify: LLM 调用完成result={structured}")
except asyncio.TimeoutError:

View File

@@ -33,7 +33,19 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
@asynccontextmanager
async def make_read_graph():
"""创建并返回 LangGraph 工作流"""
"""
Create and return a LangGraph workflow for memory reading operations
Builds a state graph workflow that handles memory retrieval, problem analysis,
verification, and summarization. The workflow includes nodes for content input,
problem splitting, retrieval, verification, and various summary operations.
Yields:
StateGraph: Compiled LangGraph workflow for memory reading
Raises:
Exception: If workflow creation fails
"""
try:
# Build workflow graph
workflow = StateGraph(ReadState)
@@ -48,7 +60,7 @@ async def make_read_graph():
workflow.add_node("Summary", Summary)
workflow.add_node("Summary_fails", Summary_fails)
# 添加边
# Add edges to define workflow flow
workflow.add_edge(START, "content_input")
workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END)
@@ -63,7 +75,7 @@ async def make_read_graph():
'''-----'''
# workflow.add_edge("Retrieve", END)
# 编译工作流
# Compile workflow
graph = workflow.compile()
yield graph
@@ -72,108 +84,3 @@ async def make_read_graph():
raise
finally:
print("工作流创建完成")
async def main():
"""主函数 - 运行工作流"""
message = "昨天有什么好看的电影"
end_user_id = '88a459f5_text09' # 组ID
storage_type = 'neo4j' # 存储类型
search_switch = '1' # 搜索开关
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=17, # 改为整数
service_name="MemoryAgentService"
)
import time
start = time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息
_intermediate_outputs = []
summary = ''
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
print(f"处理节点: {node_name}")
# 处理不同Summary节点的返回结构
if 'Summary' in node_name:
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
summary = node_data['InputSummary']['summary_result']
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
summary = node_data['RetrieveSummary']['summary_result']
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
summary = node_data['summary']['summary_result']
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
summary = node_data['SummaryFails']['summary_result']
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
if spit_data and spit_data != [] and spit_data != {}:
_intermediate_outputs.append(spit_data)
# Problem_Extension 节点
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
if problem_extension and problem_extension != [] and problem_extension != {}:
_intermediate_outputs.append(problem_extension)
# Retrieve 节点
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node)
# Verify 节点
verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}:
_intermediate_outputs.append(verify_n)
# Summary 节点
summary_n = node_data.get('summary', {}).get('_intermediate', None)
if summary_n and summary_n != [] and summary_n != {}:
_intermediate_outputs.append(summary_n)
# # 过滤掉空值
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
#
# # 优化搜索结果
# print("=== 开始优化搜索结果 ===")
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
# result=reorder_output_results(optimized_outputs)
# # 保存优化后的结果到文件
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
# import json
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
#
print(f"=== 最终摘要 ===")
print(summary)
except Exception as e:
import traceback
traceback.print_exc()
finally:
db_session.close()
end = time.time()
print(100 * 'y')
print(f"总耗时: {end - start}s")
print(100 * 'y')
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -1,13 +1,13 @@
from typing import Literal
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
logger = get_agent_logger(__name__)
counter = COUNTState(limit=3)
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing based on search_switch value.
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
return 'Input_Summary'
return 'Split_The_Problem' # 默认情况
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing based on search_switch value.
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
elif search_switch == '1':
return 'Retrieve_Summary'
return 'Retrieve_Summary' # Default based on business logic
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
status=state.get('verify', '')['status']
status = state.get('verify', '')['status']
# loop_count = counter.get_total()
if "success" in status:
# counter.reset()
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
# if loop_count < 2: # Maximum loop count is 3
# return "content_input"
# else:
# counter.reset()
# counter.reset()
return "Summary_fails"
else:
# Add default return value to avoid returning None

View File

@@ -2,77 +2,104 @@ import json
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context, get_db
from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
"""
Write messages to RAG storage system
Combines user and AI messages into a single string format and stores them
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
Args:
end_user_id: User identifier for the conversation
user_message: User's input message content
ai_message: AI's response message content
user_rag_memory_id: RAG memory identifier for storage location
"""
# RAG mode: combine messages into string format (maintain original logic)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
actual_config_id, long_term_messages=[]):
async def write(
storage_type,
end_user_id,
user_message,
ai_message,
user_rag_memory_id,
actual_end_user_id,
actual_config_id,
long_term_messages=None
):
"""
写入记忆(支持结构化消息)
Write memory with structured message support
Handles memory writing operations for different storage types (Neo4j/RAG).
Supports both individual message pairs and batch long-term message processing.
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
storage_type: Storage type identifier ("neo4j" or "rag")
end_user_id: Terminal user identifier
user_message: User message content
ai_message: AI response content
user_rag_memory_id: RAG memory identifier
actual_end_user_id: Actual user identifier for storage
actual_config_id: Configuration identifier
long_term_messages: Optional list of structured messages for batch processing
逻辑说明:
- RAG 模式:组合 user_message ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
Logic explanation:
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
- Neo4j mode: Uses structured message lists
1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
3. Each message is converted to independent Chunk, preserving speaker field
"""
db = next(get_db())
try:
if long_term_messages is None:
long_term_messages = []
with get_db_context() as db:
actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j 模式:使用结构化消息列表
# Neo4j mode: Use structured message lists
structured_messages = []
# 始终添加用户消息(如果不为空)
# Always add user message (if not empty)
if isinstance(user_message, str) and user_message.strip() != "":
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
# Only add assistant message when AI reply is not empty
if isinstance(ai_message, str) and ai_message.strip() != "":
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果提供了 long_term_messages,使用它替代 structured_messages
# If long_term_messages provided, use it to replace structured_messages
if long_term_messages and isinstance(long_term_messages, list):
structured_messages = long_term_messages
elif long_term_messages and isinstance(long_term_messages, str):
# 如果是 JSON 字符串,先解析
# If it's a JSON string, parse it first
try:
structured_messages = json.loads(long_term_messages)
except json.JSONDecodeError:
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
# 如果没有消息,直接返回
# If no messages, return directly
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
@@ -80,29 +107,41 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: JSON 字符串格式的消息列表
str(actual_config_id), # config_id: 配置ID字符串
actual_end_user_id, # end_user_id: User ID
structured_messages, # message: JSON string format message list
str(actual_config_id), # config_id: Configuration ID string
storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
"""
Save long-term memory data to database
Handles the storage of long-term memory data based on different strategies
(chunk-based or aggregate-based) and manages the transition from short-term
to long-term memory storage.
Args:
long_term_messages: Long-term message data to be saved
actual_config_id: Configuration identifier for memory settings
end_user_id: User identifier for memory association
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
scope: Scope/window size for memory processing
"""
with get_db_context() as db_session:
repo = LongTermMemoryRepository(db_session)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
if len(chunk_data) == scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'---------写入短长期-----------')
else:
@@ -112,18 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
logger.info(f'写入短长期:')
"""Window-based dialogue processing"""
'''根据窗口'''
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
'''
根据窗口获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
langchain_messages原始数据LIST
scope窗口大小
'''
scope=scope
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
"""
Process dialogue based on window size and write to Neo4j
Manages conversation data based on a sliding window approach. When the window
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
Args:
end_user_id: Terminal user identifier
memory_config: Memory configuration object containing settings
langchain_messages: Original message data list
scope: Window size determining when to trigger long-term storage
"""
scope = scope
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
@@ -135,50 +179,72 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
elif int(is_end_user_id) == int(scope):
logger.info('写入长期记忆NEO4J')
formatted_messages = (redis_messages)
# 获取 config_id(如果 memory_config 是对象,提取 config_id否则直接使用
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id
else:
config_id = memory_config
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
config_id, formatted_messages)
await write(
AgentMemory_Long_Term.STORAGE_NEO4J,
end_user_id,
"",
"",
None,
end_user_id,
config_id,
formatted_messages
)
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""根据时间"""
async def memory_long_term_storage(end_user_id,memory_config,time):
'''
根据时间获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
'''
"""Time-based memory processing"""
async def memory_long_term_storage(end_user_id, memory_config, time):
"""
Process memory storage based on time intervals and write to Neo4j
Retrieves Redis data based on time intervals and writes it to Neo4j for
long-term storage. This function handles time-based memory consolidation.
Args:
end_user_id: Terminal user identifier
memory_config: Memory configuration object containing settings
time: Time interval for data retrieval
"""
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
format_messages = (long_time_data)
messages=[]
memory_config=memory_config.config_id
format_messages = long_time_data
messages = []
memory_config = memory_config.config_id
for i in format_messages:
message=json.loads(i['Query'])
messages+= message
if format_messages!=[]:
message = json.loads(i['Query'])
messages += message
if format_messages:
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
'''聚合判断'''
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
"""
聚合判断函数:判断输入句子和历史消息是否描述同一事件
Aggregation judgment function: determine if input sentence and historical messages describe the same event
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
historical data or stored as separate events. This helps optimize memory storage and retrieval.
Args:
end_user_id: 终端用户ID
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: 内存配置对象
"""
end_user_id: Terminal user identifier
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: Memory configuration object containing LLM settings
Returns:
dict: Aggregation judgment result containing is_same_event flag and processed output
"""
history = None
try:
# 1. 获取历史会话数据(使用新方法)
# 1. Get historical session data (using new method)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
if not result:
@@ -210,7 +276,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
output_value = structured.output
if isinstance(output_value, list):
output_value = [
{"role": msg.role, "content": msg.content}
{"role": msg.role, "content": msg.content}
for msg in output_value
]
@@ -223,16 +289,16 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
await write("neo4j", end_user_id, "", "", None, end_user_id,
memory_config.config_id, output_value)
return result_dict
except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}")
import traceback
traceback.print_exc()
return {
"is_same_event": False,
"output": ori_messages,
"messages": ori_messages,
"history": history if 'history' in locals() else [],
"error": str(e)
}
}

View File

@@ -2,41 +2,53 @@ import asyncio
import json
from datetime import datetime, timedelta
from langchain.tools import tool
from pydantic import BaseModel, Field
from app.core.memory.src.search import (
search_by_temporal,
search_by_keyword_temporal,
)
def extract_tool_message_content(response):
"""从agent响应中提取ToolMessage内容和工具名称"""
"""
Extract ToolMessage content and tool names from agent response
Parses agent response messages to extract tool execution results and metadata.
Handles JSON parsing and provides structured access to tool output data.
Args:
response: Agent response dictionary containing messages
Returns:
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
- tool_name: Name of the executed tool
- content: Parsed tool execution result (JSON or raw text)
"""
messages = response.get('messages', [])
for message in messages:
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
# 这是一个ToolMessage
# This is a ToolMessage
tool_content = message.content
tool_name = None
# 尝试获取工具名称
# Try to get tool name
if hasattr(message, 'name'):
tool_name = message.name
elif hasattr(message, 'tool_name'):
tool_name = message.tool_name
try:
# 解析JSON内容
# Parse JSON content
parsed_content = json.loads(tool_content)
return {
'tool_name': tool_name,
'content': parsed_content
}
except json.JSONDecodeError:
# 如果不是JSON格式直接返回内容
# If not JSON format, return content directly
return {
'tool_name': tool_name,
'content': tool_content
@@ -46,38 +58,61 @@ def extract_tool_message_content(response):
class TimeRetrievalInput(BaseModel):
"""时间检索工具的输入模式"""
"""
Input schema for time retrieval tool
Defines the expected input parameters for time-based retrieval operations.
Used for validation and documentation of tool parameters.
Attributes:
context: User input query content for search
end_user_id: Group ID for filtering search results, defaults to test user
"""
context: str = Field(description="用户输入的查询内容")
end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
def create_time_retrieval_tool(end_user_id: str):
"""
创建一个带有特定end_user_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements)
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
Creates a specialized time-based retrieval tool that searches for statements within
specified time ranges. Includes field cleaning functionality to remove unnecessary
metadata from search results.
Args:
end_user_id: User identifier for scoping search results
Returns:
function: Configured TimeRetrievalWithGroupId tool function
"""
def clean_temporal_result_fields(data):
"""
清理时间搜索结果中不需要的字段,并修改结构
Clean unnecessary fields from temporal search results and modify structure
Removes metadata fields that are not needed for end-user consumption and
restructures the response format for better usability.
Args:
data: 要清理的数据
data: Data to be cleaned (dict, list, or other types)
Returns:
清理后的数据
Cleaned data with unnecessary fields removed
"""
# 需要过滤的字段列表
# List of fields to filter out
fields_to_remove = {
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
'valid_at', 'invalid_at', 'statement_ids'
}
if isinstance(data, dict):
cleaned = {}
for key, value in data.items():
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
# statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
# Change statements: {"statements": [...]} to time_search: {"statements": [...]}
cleaned_value = clean_temporal_result_fields(value)
# 进一步将内部的 statements 改为 time_search
# Further change internal statements to time_search
if 'statements' in cleaned_value:
cleaned['results'] = {
'time_search': cleaned_value['statements']
@@ -91,26 +126,35 @@ def create_time_retrieval_tool(end_user_id: str):
return [clean_temporal_result_fields(item) for item in data]
else:
return data
@tool
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
end_user_id_param: str = None, clean_output: bool = True) -> str:
"""
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
显式接收参数:
- context: 查询上下文内容
- start_date: 开始时间可选格式YYYY-MM-DD
- end_date: 结束时间可选格式YYYY-MM-DD
- end_user_id_param: 组ID可选用于覆盖默认组ID
- clean_output: 是否清理输出中的元数据字段
-end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d")
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
Performs time-based search operations with automatic metadata filtering. Supports
flexible date range specification and provides clean, user-friendly output.
Explicit parameters:
- context: Query context content
- start_date: Start time (optional, format: YYYY-MM-DD)
- end_date: End time (optional, format: YYYY-MM-DD)
- end_user_id_param: Group ID (optional, overrides default group ID)
- clean_output: Whether to clean metadata fields from output
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
Returns:
str: JSON formatted search results with temporal data
"""
async def _async_search():
# 使用传入的参数或默认值
# Use passed parameters or default values
actual_end_user_id = end_user_id_param or end_user_id
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# 基本时间搜索
# Basic time search
results = await search_by_temporal(
end_user_id=actual_end_user_id,
start_date=actual_start_date,
@@ -118,33 +162,43 @@ def create_time_retrieval_tool(end_user_id: str):
limit=10
)
# 清理结果中不需要的字段
# Clean unnecessary fields from results
if clean_output:
cleaned_results = clean_temporal_result_fields(results)
else:
cleaned_results = results
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
return asyncio.run(_async_search())
@tool
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
clean_output: bool = True) -> str:
"""
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
显式接收参数:
- context: 查询内容
- days_back: 向前搜索的天数默认7天
- start_date: 开始时间可选格式YYYY-MM-DD
- end_date: 结束时间可选格式YYYY-MM-DD
- clean_output: 是否清理输出中的元数据字段
- end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d")
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
Performs combined keyword and temporal search operations with automatic metadata
filtering. Provides more targeted search results by combining content relevance
with time-based filtering.
Explicit parameters:
- context: Query content for keyword matching
- days_back: Number of days to search backwards, default 7 days
- start_date: Start time (optional, format: YYYY-MM-DD)
- end_date: End time (optional, format: YYYY-MM-DD)
- clean_output: Whether to clean metadata fields from output
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
Returns:
str: JSON formatted search results combining keyword and temporal data
"""
async def _async_search():
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
# 关键词时间搜索
# Keyword time search
results = await search_by_keyword_temporal(
query_text=context,
end_user_id=end_user_id,
@@ -153,7 +207,7 @@ def create_time_retrieval_tool(end_user_id: str):
limit=15
)
# 清理结果中不需要的字段
# Clean unnecessary fields from results
if clean_output:
cleaned_results = clean_temporal_result_fields(results)
else:
@@ -162,51 +216,60 @@ def create_time_retrieval_tool(end_user_id: str):
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
return asyncio.run(_async_search())
return TimeRetrievalWithGroupId
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"""
创建混合检索工具使用run_hybrid_search进行混合检索优化输出格式并过滤不需要的字段
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
Creates an advanced hybrid search tool that combines multiple search strategies
(keyword, vector, hybrid) with automatic result cleaning and formatting.
Args:
memory_config: 内存配置对象
**search_params: 搜索参数,包含end_user_id, limit, include
memory_config: Memory configuration object containing LLM and search settings
**search_params: Search parameters including end_user_id, limit, include, etc.
Returns:
function: Configured HybridSearch tool function with async capabilities
"""
def clean_result_fields(data):
"""
递归清理结果中不需要的字段
Recursively clean unnecessary fields from results
Removes metadata fields that are not needed for end-user consumption,
improving readability and reducing response size.
Args:
data: 要清理的数据(可能是字典、列表或其他类型)
data: Data to be cleaned (can be dict, list, or other types)
Returns:
清理后的数据
Cleaned data with unnecessary fields removed
"""
# 需要过滤的字段列表
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# List of fields to filter out
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
}
if isinstance(data, dict):
# 对字典进行清理
# Clean dictionary
cleaned = {}
for key, value in data.items():
if key not in fields_to_remove:
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据
cleaned[key] = clean_result_fields(value) # Recursively clean nested data
return cleaned
elif isinstance(data, list):
# 对列表中的每个元素进行清理
# Clean each element in list
return [clean_result_fields(item) for item in data]
else:
# 其他类型直接返回
# Return other types directly
return data
@tool
async def HybridSearch(
context: str,
@@ -216,57 +279,63 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
clean_output: bool = True # 新增:是否清理输出字段
clean_output: bool = True # New: whether to clean output fields
) -> str:
"""
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
Provides comprehensive search capabilities combining multiple search strategies
with intelligent result ranking and automatic metadata filtering for clean output.
Args:
context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制
end_user_id: 组ID用于过滤搜索结果
rerank_alpha: 重排序权重参数
use_forgetting_rerank: 是否使用遗忘重排序
use_llm_rerank: 是否使用LLM重排序
clean_output: 是否清理输出中的元数据字段
context: Query content for search
search_type: Search type ('keyword', 'embedding', 'hybrid')
limit: Result quantity limit
end_user_id: Group ID for filtering search results
rerank_alpha: Reranking weight parameter for result scoring
use_forgetting_rerank: Whether to use forgetting-based reranking
use_llm_rerank: Whether to use LLM-based reranking
clean_output: Whether to clean metadata fields from output
Returns:
str: JSON formatted comprehensive search results
"""
try:
# 导入run_hybrid_search函数
# Import run_hybrid_search function
from app.core.memory.src.search import run_hybrid_search
# 合并参数,优先使用传入的参数
# Merge parameters, prioritize passed parameters
final_params = {
"query_text": context,
"search_type": search_type,
"end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
"output_path": None, # 不保存到文件
"output_path": None, # Don't save to file
"memory_config": memory_config,
"rerank_alpha": rerank_alpha,
"use_forgetting_rerank": use_forgetting_rerank,
"use_llm_rerank": use_llm_rerank
}
# 执行混合检索
# Execute hybrid retrieval
raw_results = await run_hybrid_search(**final_params)
# 清理结果中不需要的字段
# Clean unnecessary fields from results
if clean_output:
cleaned_results = clean_result_fields(raw_results)
else:
cleaned_results = raw_results
# 格式化返回结果
# Format return results
formatted_results = {
"search_query": context,
"search_type": search_type,
"results": cleaned_results
}
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
except Exception as e:
error_result = {
"error": f"混合检索失败: {str(e)}",
@@ -275,38 +344,52 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"timestamp": datetime.now().isoformat()
}
return json.dumps(error_result, ensure_ascii=False, indent=2)
return HybridSearch
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"""
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
Creates a synchronous wrapper around the async hybrid search functionality,
making it compatible with synchronous tool execution environments.
Args:
memory_config: 内存配置对象
**search_params: 搜索参数
memory_config: Memory configuration object containing search settings
**search_params: Search parameters for configuration
Returns:
function: Configured HybridSearchSync tool function
"""
@tool
def HybridSearchSync(
context: str,
search_type: str = "hybrid",
limit: int = 10,
end_user_id: str = None,
clean_output: bool = True
context: str,
search_type: str = "hybrid",
limit: int = 10,
end_user_id: str = None,
clean_output: bool = True
) -> str:
"""
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
Provides the same hybrid search capabilities as the async version but in a
synchronous execution context. Automatically handles async-to-sync conversion.
Args:
context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制
end_user_id: 组ID用于过滤搜索结果
clean_output: 是否清理输出中的元数据字段
context: Query content for search
search_type: Search type ('keyword', 'embedding', 'hybrid')
limit: Result quantity limit
end_user_id: Group ID for filtering search results
clean_output: Whether to clean metadata fields from output
Returns:
str: JSON formatted search results
"""
async def _async_search():
# 创建异步工具并执行
# Create async tool and execute
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
return await async_tool.ainvoke({
"context": context,
@@ -315,7 +398,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"end_user_id": end_user_id,
"clean_output": clean_output
})
return asyncio.run(_async_search())
return HybridSearchSync
return HybridSearchSync

View File

@@ -1,20 +1,28 @@
import json
from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list,type:str='string'):
async def format_parsing(messages: list, type: str = 'string'):
"""
格式化解析消息列表
Format and parse message lists into different output types
Processes message lists from storage and converts them into either string format
or dictionary format based on the specified type parameter. Handles JSON parsing
and role-based message organization.
Args:
messages: 消息列表
type: 返回类型 ('string''dict')
messages: List of message objects from storage containing message data
type: Return type specification ('string' for text format, 'dict' for key-value pairs)
Returns:
格式化后的消息列表
list: Formatted message list in the specified format
- 'string': List of formatted text messages with role prefixes
- 'dict': List of dictionaries mapping user messages to AI responses
"""
result = []
user=[]
ai=[]
user = []
ai = []
for message in messages:
hstory_messages = message['messages']
@@ -24,25 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
role = content['role']
content = content['content']
if type == "string":
if role == 'human' or role=="user":
if role == 'human' or role == "user":
content = '用户:' + content
else:
content = 'AI:' + content
result.append(content)
if type == "dict" :
if role == 'human' or role=="user":
user.append( content)
if type == "dict":
if role == 'human' or role == "user":
user.append(content)
else:
ai.append(content)
if type == "dict":
for key,values in zip(user,ai):
result.append({key:values})
for key, values in zip(user, ai):
result.append({key: values})
return result
async def messages_parse(messages: list | dict):
user=[]
ai=[]
database=[]
"""
Parse messages from storage format into user-AI conversation pairs
Extracts and organizes conversation data from stored message format,
separating user and AI messages and pairing them for database storage.
Args:
messages: List or dictionary containing stored message data with Query fields
Returns:
list: List of dictionaries containing user-AI message pairs for database storage
"""
user = []
ai = []
database = []
for message in messages:
Query = message['Query']
Query = json.loads(Query)
@@ -54,10 +75,23 @@ async def messages_parse(messages: list | dict):
ai.append(data['content'])
for key, values in zip(user, ai):
database.append({key, values})
return database
return database
async def agent_chat_messages(user_content,ai_content):
async def agent_chat_messages(user_content, ai_content):
"""
Create structured chat message format for agent conversations
Formats user and AI content into a standardized message structure suitable
for agent processing and storage. Creates role-based message objects.
Args:
user_content: User's message content string
ai_content: AI's response content string
Returns:
list: List of structured message dictionaries with role and content fields
"""
messages = [
{
"role": "user",

View File

@@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
@@ -42,10 +41,26 @@ async def make_write_graph():
yield graph
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
"""
Handle long-term memory storage with different strategies
Supports multiple storage strategies including chunk-based, time-based,
and aggregate judgment approaches for long-term memory persistence.
Args:
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
langchain_messages: List of messages to store
memory_config: Memory configuration identifier
end_user_id: User group identifier
scope: Scope parameter for chunk-based storage (default: 6)
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, (langchain_messages))
write_store.save_session_write(end_user_id, langchain_messages)
# 获取数据库会话
with get_db_context() as db_session:
config_service = MemoryConfigService(db_session)
@@ -53,26 +68,39 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
)
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
"""方案三:聚合判断"""
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
"""Time-based strategy"""
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
"""Strategy 3: Aggregate judgment"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
"""
Write long-term memory with different storage types
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
Handles both RAG-based storage and traditional memory storage approaches.
For traditional storage, uses chunk-based strategy with paired user-AI messages.
Args:
storage_type: Type of storage (RAG or traditional)
end_user_id: User group identifier
message_chat: User message content
aimessages: AI response messages
user_rag_memory_id: RAG memory identifier
actual_config_id: Actual configuration ID
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages)
@@ -101,4 +129,4 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())
# asyncio.run(main())

View File

@@ -8,10 +8,11 @@ from langgraph.graph import add_messages
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
class WriteState(TypedDict):
'''
"""
Langgrapg Writing TypedDict
'''
"""
messages: Annotated[list[AnyMessage], add_messages]
end_user_id: str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
@@ -20,6 +21,7 @@ class WriteState(TypedDict):
data: str
language: str # 语言类型 ("zh" 中文, "en" 英文)
class ReadState(TypedDict):
"""
LangGraph 工作流状态定义
@@ -43,18 +45,20 @@ class ReadState(TypedDict):
config_id: str
data: str # 新增字段用于传递内容
spit_data: dict # 新增字段用于传递问题分解结果
problem_extension:dict
problem_extension: dict
storage_type: str
user_rag_memory_id: str
llm_id: str
embedding_id: str
memory_config: object # 新增字段用于传递内存配置对象
retrieve:dict
retrieve: dict
RetrieveSummary: dict
InputSummary: dict
verify: dict
SummaryFails: dict
summary: dict
class COUNTState:
"""
工作流对话检索内容计数器
@@ -99,6 +103,7 @@ class COUNTState:
self.total = 0
print("[COUNTState] 已重置为 0")
def deduplicate_entries(entries):
seen = set()
deduped = []
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
deduped.append(entry)
return deduped
def merge_to_key_value_pairs(data, query_key, result_key):
grouped = defaultdict(list)
for item in data:
@@ -142,4 +148,4 @@ def convert_extended_question_to_question(data):
return [convert_extended_question_to_question(item) for item in data]
else:
# 其他类型直接返回
return data
return data

View File

@@ -165,7 +165,9 @@ async def write(
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
connector=neo4j_connector,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
)
if success:
logger.info("Successfully saved all data to Neo4j")

View File

@@ -0,0 +1,3 @@
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
__all__ = ["LabelPropagationEngine"]

View File

@@ -0,0 +1,484 @@
"""标签传播聚类引擎
基于 ZEP 论文的动态标签传播算法,对 Neo4j 中的 ExtractedEntity 节点进行社区聚类。
支持两种模式:
- 全量初始化full_clustering首次运行对所有实体做完整 LPA 迭代
- 增量更新incremental_update新实体到达时只处理新实体及其邻居
"""
import logging
import uuid
from math import sqrt
from typing import Dict, List, Optional
from app.repositories.neo4j.community_repository import CommunityRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
# 全量迭代最大轮数,防止不收敛
MAX_ITERATIONS = 10
# 社区摘要核心实体数量
CORE_ENTITY_LIMIT = 5
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
"""计算两个向量的余弦相似度,任一为空则返回 0。"""
if not v1 or not v2 or len(v1) != len(v2):
return 0.0
dot = sum(a * b for a, b in zip(v1, v2))
norm1 = sqrt(sum(a * a for a in v1))
norm2 = sqrt(sum(b * b for b in v2))
if norm1 == 0 or norm2 == 0:
return 0.0
return dot / (norm1 * norm2)
def _weighted_vote(
neighbors: List[Dict],
self_embedding: Optional[List[float]],
) -> Optional[str]:
"""
加权多数投票,选出得票最高的社区。
权重 = 语义相似度name_embedding 余弦)* activation_value 加成
没有 community_id 的邻居不参与投票。
"""
votes: Dict[str, float] = {}
for nb in neighbors:
cid = nb.get("community_id")
if not cid:
continue
sem = _cosine_similarity(self_embedding, nb.get("name_embedding"))
act = nb.get("activation_value") or 0.5
# 语义相似度权重 0.6,激活值权重 0.4
weight = 0.6 * sem + 0.4 * act
votes[cid] = votes.get(cid, 0.0) + weight
if not votes:
return None
return max(votes, key=votes.__getitem__)
class LabelPropagationEngine:
"""标签传播聚类引擎"""
def __init__(
self,
connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
):
self.connector = connector
self.repo = CommunityRepository(connector)
self.config_id = config_id
self.llm_model_id = llm_model_id
# ──────────────────────────────────────────────────────────────────────────
# 公开接口
# ──────────────────────────────────────────────────────────────────────────
async def run(
self,
end_user_id: str,
new_entity_ids: Optional[List[str]] = None,
) -> None:
"""
统一入口:自动判断全量还是增量。
- 若该用户尚无 Community 节点 → 全量初始化
- 否则 → 增量更新(仅处理 new_entity_ids
"""
has_communities = await self.repo.has_communities(end_user_id)
if not has_communities:
logger.info(f"[Clustering] 用户 {end_user_id} 首次聚类,执行全量初始化")
await self.full_clustering(end_user_id)
else:
if new_entity_ids:
logger.info(
f"[Clustering] 增量更新,新实体数: {len(new_entity_ids)}"
)
await self.incremental_update(new_entity_ids, end_user_id)
async def full_clustering(self, end_user_id: str) -> None:
"""
全量标签传播初始化。
1. 拉取所有实体,初始化每个实体为独立社区
2. 迭代:每轮对所有实体做邻居投票,更新社区标签
3. 直到标签不再变化或达到 MAX_ITERATIONS
4. 将最终标签写入 Neo4j
"""
entities = await self.repo.get_all_entities(end_user_id)
if not entities:
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
return
# 初始化:每个实体持有自己 id 作为社区标签
labels: Dict[str, str] = {e["id"]: e["id"] for e in entities}
embeddings: Dict[str, Optional[List[float]]] = {
e["id"]: e.get("name_embedding") for e in entities
}
# 预加载所有实体的邻居,避免迭代内 O(iterations * |E|) 次 Neo4j 往返
logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...")
neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id)
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
for iteration in range(MAX_ITERATIONS):
changed = 0
# 随机顺序Python dict 在 3.7+ 保持插入顺序,这里直接遍历)
for entity in entities:
eid = entity["id"]
# 直接从缓存取邻居,不再发起 Neo4j 查询
neighbors = neighbors_cache.get(eid, [])
# 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值)
enriched = []
for nb in neighbors:
nb_copy = dict(nb)
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
enriched.append(nb_copy)
new_label = _weighted_vote(enriched, embeddings.get(eid))
if new_label and new_label != labels[eid]:
labels[eid] = new_label
changed += 1
logger.info(
f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS}"
f"标签变化数: {changed}"
)
if changed == 0:
logger.info("[Clustering] 标签已收敛,提前结束迭代")
break
# 将最终标签写入 Neo4j
await self._flush_labels(labels, end_user_id)
pre_merge_count = len(set(labels.values()))
logger.info(
f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体,开始后处理合并"
)
# 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度)
all_community_ids = list(set(labels.values()))
await self._evaluate_merge(all_community_ids, end_user_id)
logger.info(
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体"
)
# 为所有社区生成元数据
# 注意_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活的社区
# 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID
surviving_communities = await self.repo.get_all_entities(end_user_id)
surviving_community_ids = list({
e.get("community_id") for e in surviving_communities
if e.get("community_id")
})
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
for cid in surviving_community_ids:
await self._generate_community_metadata(cid, end_user_id)
async def incremental_update(
self, new_entity_ids: List[str], end_user_id: str
) -> None:
"""
增量更新:只处理新实体及其邻居,不重跑全图。
1. 对每个新实体查询邻居
2. 加权多数投票决定社区归属
3. 若邻居无社区 → 创建新社区
4. 若邻居分属多个社区 → 评估是否合并
"""
for entity_id in new_entity_ids:
await self._process_single_entity(entity_id, end_user_id)
# ──────────────────────────────────────────────────────────────────────────
# 内部方法
# ──────────────────────────────────────────────────────────────────────────
async def _process_single_entity(
self, entity_id: str, end_user_id: str
) -> None:
"""处理单个新实体的社区分配。"""
neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
# 查询自身 embedding从邻居查询结果中无法获取需单独查
self_embedding = await self._get_entity_embedding(entity_id, end_user_id)
if not neighbors:
# 孤立实体:创建单成员社区
new_cid = self._new_community_id()
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
return
# 统计邻居社区分布
community_ids_in_neighbors = set(
nb["community_id"] for nb in neighbors if nb.get("community_id")
)
target_cid = _weighted_vote(neighbors, self_embedding)
if target_cid is None:
# 邻居都没有社区,连同新实体一起创建新社区
new_cid = self._new_community_id()
await self.repo.upsert_community(new_cid, end_user_id)
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
for nb in neighbors:
await self.repo.assign_entity_to_community(
nb["id"], new_cid, end_user_id
)
await self.repo.refresh_member_count(new_cid, end_user_id)
logger.debug(
f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
)
await self._generate_community_metadata(new_cid, end_user_id)
else:
# 加入得票最多的社区
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
await self.repo.refresh_member_count(target_cid, end_user_id)
logger.debug(f"[Clustering] 新实体 {entity_id} → 社区 {target_cid}")
# 若邻居分属多个社区,评估合并
if len(community_ids_in_neighbors) > 1:
await self._evaluate_merge(
list(community_ids_in_neighbors), end_user_id
)
await self._generate_community_metadata(target_cid, end_user_id)
async def _evaluate_merge(
self, community_ids: List[str], end_user_id: str
) -> None:
"""
评估多个社区是否应合并。
策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。
合并时保留成员数最多的社区,其余成员迁移过来。
全量场景(社区数 > 20使用批量查询避免 N 次数据库往返。
"""
MERGE_THRESHOLD = 0.85
BATCH_THRESHOLD = 20 # 超过此数量走批量查询
community_embeddings: Dict[str, Optional[List[float]]] = {}
community_sizes: Dict[str, int] = {}
if len(community_ids) > BATCH_THRESHOLD:
# 批量查询:一次拉取所有社区成员
all_members = await self.repo.get_all_community_members_batch(
community_ids, end_user_id
)
for cid in community_ids:
members = all_members.get(cid, [])
community_sizes[cid] = len(members)
valid_embeddings = [
m["name_embedding"] for m in members if m.get("name_embedding")
]
if valid_embeddings:
dim = len(valid_embeddings[0])
community_embeddings[cid] = [
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
for i in range(dim)
]
else:
community_embeddings[cid] = None
else:
# 增量场景:逐个查询
for cid in community_ids:
members = await self.repo.get_community_members(cid, end_user_id)
community_sizes[cid] = len(members)
valid_embeddings = [
m["name_embedding"] for m in members if m.get("name_embedding")
]
if valid_embeddings:
dim = len(valid_embeddings[0])
community_embeddings[cid] = [
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
for i in range(dim)
]
else:
community_embeddings[cid] = None
# 找出应合并的社区对
to_merge: List[tuple] = []
cids = list(community_ids)
for i in range(len(cids)):
for j in range(i + 1, len(cids)):
sim = _cosine_similarity(
community_embeddings[cids[i]],
community_embeddings[cids[j]],
)
if sim > MERGE_THRESHOLD:
to_merge.append((cids[i], cids[j]))
logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")
# 执行合并:逐对处理,每次合并后重新计算合并社区的平均向量
# 避免 union-find 链式传递导致语义不相关的社区被间接合并
# A≈B、B≈C 不代表 A≈C不能因传递性把 A/B/C 全部合并)
merged_into: Dict[str, str] = {} # dissolve → keep 的最终映射
def get_root(x: str) -> str:
"""路径压缩,找到 x 当前所属的根社区。"""
while x in merged_into:
merged_into[x] = merged_into.get(merged_into[x], merged_into[x])
x = merged_into[x]
return x
for c1, c2 in to_merge:
root1, root2 = get_root(c1), get_root(c2)
if root1 == root2:
continue
# 用合并后的最新平均向量重新验证相似度
# 防止链式传递A≈B 合并后 B 的向量已更新C 必须和新 B 相似才能合并
current_sim = _cosine_similarity(
community_embeddings.get(root1),
community_embeddings.get(root2),
)
if current_sim <= MERGE_THRESHOLD:
# 合并后向量已漂移,不再满足阈值,跳过
logger.debug(
f"[Clustering] 跳过合并 {root1}{root2}"
f"当前相似度 {current_sim:.3f}{MERGE_THRESHOLD}"
)
continue
keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
dissolve = root2 if keep == root1 else root1
merged_into[dissolve] = keep
members = await self.repo.get_community_members(dissolve, end_user_id)
for m in members:
await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)
# 合并后重新计算 keep 的平均向量(加权平均)
keep_emb = community_embeddings.get(keep)
dissolve_emb = community_embeddings.get(dissolve)
keep_size = community_sizes.get(keep, 0)
dissolve_size = community_sizes.get(dissolve, 0)
total_size = keep_size + dissolve_size
if keep_emb and dissolve_emb and total_size > 0:
dim = len(keep_emb)
community_embeddings[keep] = [
(keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size
for i in range(dim)
]
community_embeddings[dissolve] = None
community_sizes[keep] = total_size
community_sizes[dissolve] = 0
await self.repo.refresh_member_count(keep, end_user_id)
logger.info(
f"[Clustering] 社区合并: {dissolve}{keep}"
f"相似度={current_sim:.3f},迁移 {len(members)} 个成员"
)
async def _flush_labels(
self, labels: Dict[str, str], end_user_id: str
) -> None:
"""将内存中的标签批量写入 Neo4j。"""
# 先创建所有唯一社区节点
unique_communities = set(labels.values())
for cid in unique_communities:
await self.repo.upsert_community(cid, end_user_id)
# 再批量分配实体
for entity_id, community_id in labels.items():
await self.repo.assign_entity_to_community(
entity_id, community_id, end_user_id
)
# 刷新成员数
for cid in unique_communities:
await self.repo.refresh_member_count(cid, end_user_id)
async def _get_entity_embedding(
self, entity_id: str, end_user_id: str
) -> Optional[List[float]]:
"""查询单个实体的 name_embedding。"""
try:
result = await self.connector.execute_query(
"MATCH (e:ExtractedEntity {id: $eid, end_user_id: $uid}) "
"RETURN e.name_embedding AS name_embedding",
eid=entity_id,
uid=end_user_id,
)
return result[0]["name_embedding"] if result else None
except Exception:
return None
async def _generate_community_metadata(
self, community_id: str, end_user_id: str
) -> None:
"""
为社区生成并写入元数据:名称、摘要、核心实体。
- core_entities按 activation_value 排序取 top-N 实体名称列表(无需 LLM
- name / summary若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底
"""
try:
members = await self.repo.get_community_members(community_id, end_user_id)
if not members:
return
# 核心实体:按 activation_value 降序取 top-N
sorted_members = sorted(
members,
key=lambda m: m.get("activation_value") or 0,
reverse=True,
)
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
all_names = [m["name"] for m in members if m.get("name")]
name = "".join(core_entities[:3]) if core_entities else community_id[:8]
summary = f"包含实体:{', '.join(all_names)}"
# 若有 LLM 配置,调用 LLM 生成更好的名称和摘要
if self.llm_model_id:
try:
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
entity_list_str = "".join(all_names)
prompt = (
f"以下是一组语义相关的实体:{entity_list_str}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要不超过50个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
except Exception as e:
logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
await self.repo.update_community_metadata(
community_id=community_id,
end_user_id=end_user_id,
name=name,
summary=summary,
core_entities=core_entities,
)
logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
except Exception as e:
logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}")
@staticmethod
def _new_community_id() -> str:
return str(uuid.uuid4())

View File

@@ -5,7 +5,7 @@ from typing import List, Dict, Optional
from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.models.message_models import DialogData, Statement
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
@@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger
logger = get_memory_logger(__name__)
class TripletExtractor:
"""Extracts knowledge triplets and entities from statements using LLM"""
def __init__(
self,
llm_client: OpenAIClient,
ontology_types: Optional[OntologyTypeList] = None,
language: str = "zh"):
self,
llm_client: OpenAIClient,
ontology_types: Optional[OntologyTypeList] = None,
language: str = "zh"
):
"""Initialize the TripletExtractor with an LLM client
Args:
@@ -65,7 +65,8 @@ class TripletExtractor:
# Create messages for LLM
messages = [
{"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
{"role": "system",
"content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
{"role": "user", "content": prompt_content}
]
@@ -116,7 +117,8 @@ class TripletExtractor:
logger.error(f"Error processing statement: {e}", exc_info=True)
return TripletExtractionResponse(triplets=[], entities=[])
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]:
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[
str, TripletExtractionResponse]:
"""Extract triplets and entities from statements
Args:

View File

@@ -1,11 +1,11 @@
"""
自我反思引擎实现
Self-Reflection Engine Implementation
该模块实现了记忆系统的自我反思功能,包括:
1. 基于时间的反思 - 根据时间周期触发反思
2. 基于事实的反思 - 检测记忆冲突并解决
3. 综合反思 - 整合多种反思策略
4. 反思结果应用 - 更新记忆库
This module implements the self-reflection functionality of the memory system, including:
1. Time-based reflection - Triggers reflection based on time cycles
2. Fact-based reflection - Detects and resolves memory conflicts
3. Comprehensive reflection - Integrates multiple reflection strategies
4. Reflection result application - Updates memory database
"""
import asyncio
@@ -38,7 +38,7 @@ from app.schemas.memory_storage_schema import (
)
from pydantic import BaseModel
# 配置日志
# Configure logging
_root_logger = logging.getLogger()
if not _root_logger.handlers:
logging.basicConfig(
@@ -49,35 +49,62 @@ else:
_root_logger.setLevel(logging.INFO)
class TranslationResponse(BaseModel):
"""翻译响应模型"""
"""Translation response model for language conversion"""
data: str
class ReflectionRange(str, Enum):
"""反思范围枚举"""
PARTIAL = "partial" # 从检索结果中反思
ALL = "all" # 从整个数据库中反思
"""
Reflection range enumeration
Defines the scope of data to be included in reflection operations.
"""
PARTIAL = "partial" # Reflect from retrieval results
ALL = "all" # Reflect from entire database
class ReflectionBaseline(str, Enum):
"""反思基线枚举"""
TIME = "TIME" # 基于时间的反思
FACT = "FACT" # 基于事实的反思
HYBRID = "HYBRID" # 混合反思
"""
Reflection baseline enumeration
Defines the strategy or approach used for reflection operations.
"""
TIME = "TIME" # Time-based reflection
FACT = "FACT" # Fact-based reflection
HYBRID = "HYBRID" # Hybrid reflection combining multiple strategies
class ReflectionConfig(BaseModel):
"""反思引擎配置"""
"""
Reflection engine configuration
Defines all configuration parameters for the reflection engine including
operation modes, model settings, and evaluation criteria.
Attributes:
enabled: Whether reflection engine is enabled
iteration_period: Reflection cycle period (e.g., "3" hours)
reflexion_range: Scope of reflection (PARTIAL or ALL)
baseline: Reflection strategy (TIME, FACT, or HYBRID)
model_id: LLM model identifier for reflection operations
end_user_id: User identifier for scoped operations
output_example: Example output format for guidance
memory_verify: Enable memory verification checks
quality_assessment: Enable quality assessment evaluation
violation_handling_strategy: Strategy for handling violations
language_type: Language type for output ("zh" or "en")
"""
enabled: bool = False
iteration_period: str = "3" # 反思周期
iteration_period: str = "3" # Reflection cycle period
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
baseline: ReflectionBaseline = ReflectionBaseline.TIME
model_id: Optional[str] = None # 模型ID
model_id: Optional[str] = None # Model ID
end_user_id: Optional[str] = None
output_example: Optional[str] = None # 输出示例
output_example: Optional[str] = None # Output example
# 评估相关字段
memory_verify: bool = True # 记忆验证
quality_assessment: bool = True # 质量评估
violation_handling_strategy: str = "warn" # 违规处理策略
# Evaluation related fields
memory_verify: bool = True # Memory verification
quality_assessment: bool = True # Quality assessment
violation_handling_strategy: str = "warn" # Violation handling strategy
language_type: str = "zh"
class Config:
@@ -85,7 +112,21 @@ class ReflectionConfig(BaseModel):
class ReflectionResult(BaseModel):
"""反思结果"""
"""
Reflection operation result
Contains comprehensive information about the outcome of a reflection operation
including success status, metrics, and execution details.
Attributes:
success: Whether the reflection operation succeeded
message: Descriptive message about the operation result
conflicts_found: Number of conflicts detected during reflection
conflicts_resolved: Number of conflicts successfully resolved
memories_updated: Number of memory entries updated in database
execution_time: Total time taken for the reflection operation
details: Additional details about the operation (optional)
"""
success: bool
message: str
conflicts_found: int = 0
@@ -97,9 +138,22 @@ class ReflectionResult(BaseModel):
class ReflectionEngine:
"""
自我反思引擎
负责执行记忆系统的自我反思,包括冲突检测、冲突解决和记忆更新。
Self-Reflection Engine
Responsible for executing memory system self-reflection operations including
conflict detection, conflict resolution, and memory updates. Supports multiple
reflection strategies and provides comprehensive result tracking.
The engine can operate in different modes:
- Time-based: Reflects on memories within specific time periods
- Fact-based: Detects and resolves factual conflicts in memories
- Hybrid: Combines multiple reflection strategies
Attributes:
config: Reflection engine configuration
neo4j_connector: Neo4j database connector
llm_client: Language model client for analysis
Various function handlers for data processing and prompt rendering
"""
def __init__(
@@ -115,18 +169,21 @@ class ReflectionEngine:
update_query: Optional[str] = None
):
"""
初始化反思引擎
Initialize reflection engine
Sets up the reflection engine with configuration and optional dependencies.
Uses lazy initialization to avoid circular imports and optimize startup time.
Args:
config: 反思引擎配置
neo4j_connector: Neo4j 连接器(可选)
llm_client: LLM 客户端(可选)
get_data_func: 获取数据的函数(可选)
render_evaluate_prompt_func: 渲染评估提示词的函数(可选)
render_reflexion_prompt_func: 渲染反思提示词的函数(可选)
conflict_schema: 冲突结果 Schema(可选)
reflexion_schema: 反思结果 Schema(可选)
update_query: 更新查询语句(可选)
config: Reflection engine configuration object
neo4j_connector: Neo4j connector instance (optional, will be created if not provided)
llm_client: LLM client instance (optional, will be created if not provided)
get_data_func: Function for retrieving data (optional, uses default if not provided)
render_evaluate_prompt_func: Function for rendering evaluation prompts (optional)
render_reflexion_prompt_func: Function for rendering reflection prompts (optional)
conflict_schema: Schema for conflict result validation (optional)
reflexion_schema: Schema for reflection result validation (optional)
update_query: Query string for database updates (optional)
"""
self.config = config
self.neo4j_connector = neo4j_connector
@@ -137,14 +194,20 @@ class ReflectionEngine:
self.conflict_schema = conflict_schema
self.reflexion_schema = reflexion_schema
self.update_query = update_query
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
self._semaphore = asyncio.Semaphore(5) # Default concurrency limit of 5
# 延迟导入以避免循环依赖
# Lazy import to avoid circular dependencies
self._lazy_init_done = False
def _lazy_init(self):
"""延迟初始化,避免循环导入"""
"""
Lazy initialization to avoid circular imports
Initializes dependencies only when needed, preventing circular import issues
and optimizing startup performance. Sets up default implementations for
any components not provided during construction.
"""
if self._lazy_init_done:
return
@@ -158,7 +221,7 @@ class ReflectionEngine:
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(self.config.model_id)
elif isinstance(self.llm_client, str):
# 如果 llm_client 是字符串model_id则用它初始化客户端
# If llm_client is a string (model_id), use it to initialize the client
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
@@ -172,10 +235,10 @@ class ReflectionEngine:
model_config = config_service.get_model_config(model_id)
extra_params={
"temperature": 0.2, # 降低温度提高响应速度和一致性
"max_tokens": 600, # 限制最大token数
"top_p": 0.8, # 优化采样参数
"stream": False, # 确保非流式输出以获得最快响应
"temperature": 0.2, # Lower temperature for faster response and consistency
"max_tokens": 600, # Limit maximum token count
"top_p": 0.8, # Optimize sampling parameters
"stream": False, # Ensure non-streaming output for fastest response
}
self.llm_client = OpenAIClient(RedBearModelConfig(
@@ -191,7 +254,7 @@ class ReflectionEngine:
if self.get_data_func is None:
self.get_data_func = get_data
# 导入get_data_statement函数
# Import get_data_statement function
if not hasattr(self, 'get_data_statement'):
self.get_data_statement = get_data_statement
@@ -223,13 +286,20 @@ class ReflectionEngine:
async def execute_reflection(self, host_id) -> ReflectionResult:
"""
执行完整的反思流程
Execute complete reflection workflow
Performs the full reflection process including data retrieval, conflict detection,
conflict resolution, and memory updates. This is the main entry point for
reflection operations.
Args:
host_id: 主机ID
host_id: Host identifier for scoping reflection operations
Returns:
ReflectionResult: 反思结果
ReflectionResult: Comprehensive result of the reflection operation including
success status, conflict metrics, and execution time
"""
# 延迟初始化
# Lazy initialization
self._lazy_init()
if not self.config.enabled:
@@ -243,7 +313,7 @@ class ReflectionEngine:
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
try:
# 1. 获取反思数据
# 1. Get reflection data
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
if not reflexion_data:
return ReflectionResult(
@@ -252,7 +322,7 @@ class ReflectionEngine:
execution_time=asyncio.get_event_loop().time() - start_time
)
# 2. 检测冲突(基于事实的反思)
# 2. Detect conflicts (fact-based reflection)
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
conflict_list=[]
for i in conflict_data:
@@ -261,7 +331,7 @@ class ReflectionEngine:
conflicts_found=0
# 3. 解决冲突
# 3. Resolve conflicts
solved_data = await self._resolve_conflicts(conflict_list, statement_databasets)
if not solved_data:
@@ -276,7 +346,7 @@ class ReflectionEngine:
logging.info(f"解决了 {conflicts_resolved} 个冲突")
# 4. 应用反思结果(更新记忆库)
# 4. Apply reflection results (update memory database)
memories_updated=await self._apply_reflection_results(solved_data)
execution_time = asyncio.get_event_loop().time() - start_time
@@ -302,7 +372,19 @@ class ReflectionEngine:
)
async def Translate(self, text):
# 翻译中文为英文
"""
Translate Chinese text to English
Uses the configured LLM to translate Chinese text to English with structured output.
Provides consistent translation format for reflection results.
Args:
text: Chinese text to be translated
Returns:
str: Translated English text
"""
# Translate Chinese to English
translation_messages = [
{
"role": "user",
@@ -316,6 +398,19 @@ class ReflectionEngine:
)
return response.data
async def extract_translation(self,data):
"""
Extract and translate reflection data to English
Processes reflection data structure and translates all Chinese content to English.
Handles nested data structures including memory verifications, quality assessments,
and reflection data while preserving the original structure.
Args:
data: Dictionary containing reflection data with Chinese content
Returns:
dict: Translated data structure with English content
"""
end_datas={}
end_datas['source_data']=await self.Translate(data['source_data'])
quality_assessments = []
@@ -350,6 +445,18 @@ class ReflectionEngine:
return end_datas
async def reflection_run(self):
"""
Execute reflection workflow with comprehensive data processing
Performs a complete reflection operation including conflict detection, resolution,
and result formatting. Supports both Chinese and English output based on
configuration settings.
Returns:
dict: Comprehensive reflection results including source data, memory verifications,
quality assessments, and reflection data. Results are translated to English
if language_type is set to 'en'.
"""
self._lazy_init()
start_time = time.time()
memory_verifies_flag = self.config.memory_verify
@@ -367,7 +474,7 @@ class ReflectionEngine:
result_data['source_data'] = "我是 2023 年春天去北京工作的后来基本一直都在北京上班也没怎么换过城市。不过后来公司调整2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X银行卡是 6222023847595898这些一直没变。对了其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
# 2. 检测冲突(基于事实的反思)
conflict_data = await self._detect_conflicts(databasets, source_data)
# 遍历数据提取字段
# Traverse data to extract fields
quality_assessments = []
memory_verifies = []
for item in conflict_data:
@@ -375,9 +482,9 @@ class ReflectionEngine:
memory_verifies.append(item['memory_verify'])
result_data['memory_verifies'] = memory_verifies
result_data['quality_assessments'] = quality_assessments
conflicts_found = 0 # 初始化为整数0而不是空字符串
conflicts_found = 0 # Initialize as integer 0 instead of empty string
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
# Clearn conflict_dataAnd memory_verifyquality_assessment
# Clean conflict_data, and memory_verify and quality_assessment
cleaned_conflict_data = []
for item in conflict_data:
cleaned_item = {
@@ -389,7 +496,7 @@ class ReflectionEngine:
for item in conflict_data:
cleaned_data = []
for row in item.get("data", []):
# 删除 created_at / expired_at
# Remove created_at / expired_at
cleaned_row = {
k: v
for k, v in row.items()
@@ -402,7 +509,7 @@ class ReflectionEngine:
}
cleaned_conflict_data_.append(cleaned_item)
print(cleaned_conflict_data_)
# 3. 解决冲突
# 3. Resolve conflicts
solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data)
if not solved_data:
return ReflectionResult(
@@ -413,7 +520,7 @@ class ReflectionEngine:
)
reflexion_data = []
# 遍历数据提取reflexion字段
# Traverse data to extract reflexion fields
for item in solved_data:
if 'results' in item:
for result in item['results']:
@@ -431,15 +538,24 @@ class ReflectionEngine:
async def extract_fields_from_json(self):
"""从example.json中提取source_data和databasets字段"""
"""
Extract source_data and databasets fields from example.json
Reads reflection example data from the example.json file and extracts
the source data and database statements for testing and demonstration purposes.
Returns:
tuple: (source_data, databasets) extracted from the example file
Returns empty lists if file reading fails
"""
prompt_dir = os.path.join(os.path.dirname(__file__), "example")
try:
# 读取JSON文件
# Read JSON file
with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f:
data = json.loads(f.read())
# 提取memory_verify下的字段
# Extract fields under memory_verify
memory_verify = data.get("memory_verify", {})
source_data = memory_verify.get("source_data", [])
databasets = memory_verify.get("databasets", [])
@@ -451,15 +567,17 @@ class ReflectionEngine:
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
"""
获取反思数据
根据配置的反思范围获取需要反思的记忆数据。
Get reflection data from database
Retrieves memory data for reflection based on the configured reflection range.
Supports both partial (from retrieval results) and full (entire database) modes.
Args:
host_id: 主机ID
host_id: Host UUID identifier for scoping data retrieval
Returns:
List[Any]: 反思数据列表
tuple: (reflexion_data, statement_data) containing memory data for reflection
Returns empty lists if query fails
"""
print("=== 获取反思数据 ===")
@@ -484,26 +602,29 @@ class ReflectionEngine:
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
"""
检测冲突(基于事实的反思)
使用 LLM 分析记忆数据,检测其中的冲突。
Detect conflicts (fact-based reflection)
Uses LLM to analyze memory data and detect conflicts within the memories.
Performs comprehensive conflict detection including memory verification and
quality assessment based on configuration settings.
Args:
data: 待检测的记忆数据
data: Memory data to be analyzed for conflicts
statement_databasets: Statement database records for context
Returns:
List[Any]: 冲突记忆列表
List[Any]: List of detected conflicts with detailed analysis
"""
if not data:
return []
# 数据预处理:如果数据量太少,直接返回无冲突
# Data preprocessing: if data is too small, return no conflicts directly
if len(data) < 2:
logging.info("数据量不足,无需检测冲突")
return []
# 使用转换后的数据
# print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
# Use converted data
# print("Converted data:", data[:2] if len(data) > 2 else data) # Only print first 2 to avoid long logs
memory_verify = self.config.memory_verify
logging.info("====== 冲突检测开始 ======")
@@ -512,7 +633,7 @@ class ReflectionEngine:
language_type=self.config.language_type
try:
# 渲染冲突检测提示词
# Render conflict detection prompt
rendered_prompt = await self.render_evaluate_prompt_func(
data,
self.conflict_schema,
@@ -526,7 +647,7 @@ class ReflectionEngine:
messages = [{"role": "user", "content": rendered_prompt}]
logging.info(f"提示词长度: {len(rendered_prompt)}")
# 调用 LLM 进行冲突检测
# Call LLM for conflict detection
response = await self.llm_client.response_structured(
messages,
self.conflict_schema
@@ -539,7 +660,7 @@ class ReflectionEngine:
logging.error("LLM 冲突检测输出解析失败")
return []
# 标准化返回格式
# Standardize return format
if isinstance(response, BaseModel):
return [response.model_dump()]
elif hasattr(response, 'dict'):
@@ -553,15 +674,17 @@ class ReflectionEngine:
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
"""
解决冲突
使用 LLM 对检测到的冲突进行反思和解决。
Resolve detected conflicts
Uses LLM to perform reflection and resolution on detected conflicts.
Processes conflicts in parallel for efficiency while respecting concurrency limits.
Args:
conflicts: 冲突列表
conflicts: List of conflicts to be resolved
statement_databasets: Statement database records for context
Returns:
List[Any]: 解决方案列表
List[Any]: List of resolution solutions with reflection results
"""
if not conflicts:
return []
@@ -570,12 +693,12 @@ class ReflectionEngine:
baseline = self.config.baseline
memory_verify = self.config.memory_verify
# 并行处理每个冲突
# Process each conflict in parallel
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
"""解决单个冲突"""
"""Resolve a single conflict"""
async with self._semaphore:
try:
# 渲染反思提示词
# Render reflection prompt
rendered_prompt = await self.render_reflexion_prompt_func(
[conflict],
self.reflexion_schema,
@@ -587,7 +710,7 @@ class ReflectionEngine:
messages = [{"role": "user", "content": rendered_prompt}]
# 调用 LLM 进行反思
# Call LLM for reflection
response = await self.llm_client.response_structured(
messages,
self.reflexion_schema
@@ -596,7 +719,7 @@ class ReflectionEngine:
if not response:
return None
# 标准化返回格式
# Standardize return format
if isinstance(response, BaseModel):
return response.model_dump()
elif hasattr(response, 'dict'):
@@ -610,11 +733,11 @@ class ReflectionEngine:
logging.warning(f"解决单个冲突失败: {e}")
return None
# 并发执行所有冲突解决任务
# Execute all conflict resolution tasks concurrently
tasks = [_resolve_one(conflict) for conflict in conflicts]
results = await asyncio.gather(*tasks, return_exceptions=False)
# 过滤掉失败的结果
# Filter out failed results
solved = [r for r in results if r is not None]
logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突")
@@ -626,15 +749,16 @@ class ReflectionEngine:
solved_data: List[Dict[str, Any]]
) -> int:
"""
应用反思结果(更新记忆库)
将解决冲突后的记忆更新到 Neo4j 数据库中。
Apply reflection results (update memory database)
Updates the Neo4j database with resolved conflicts and reflection results.
Processes the solved data and applies changes to the memory storage system.
Args:
solved_data: 解决方案列表
solved_data: List of resolved conflict solutions with reflection data
Returns:
int: 成功更新的记忆数量
int: Number of successfully updated memory entries
"""
changes = extract_and_process_changes(solved_data)
success_count = await neo4j_data(changes)
@@ -642,80 +766,86 @@ class ReflectionEngine:
# 基于时间的反思方法
# Time-based reflection methods
async def time_based_reflection(
self,
host_id: uuid.UUID,
time_period: Optional[str] = None
) -> ReflectionResult:
"""
基于时间的反思
根据时间周期触发反思,检查在指定时间段内的记忆。
Time-based reflection
Triggers reflection based on time cycles, checking memories within
specified time periods. Uses the configured iteration period if
no specific time period is provided.
Args:
host_id: 主机ID
time_period: 时间周期(如"三小时"),如果不提供则使用配置中的值
host_id: Host UUID identifier for scoping reflection
time_period: Time period (e.g., "three hours"), uses config value if not provided
Returns:
ReflectionResult: 反思结果
ReflectionResult: Comprehensive reflection operation result
"""
period = time_period or self.config.iteration_period
logging.info(f"执行基于时间的反思,周期: {period}")
# 使用标准反思流程
# Use standard reflection workflow
return await self.execute_reflection(host_id)
# 基于事实的反思方法
# Fact-based reflection methods
async def fact_based_reflection(
self,
host_id: uuid.UUID
) -> ReflectionResult:
"""
基于事实的反思
检测记忆中的事实冲突并解决。
Fact-based reflection
Detects and resolves factual conflicts within memories. Analyzes
memory data for inconsistencies and contradictions that need resolution.
Args:
host_id: 主机ID
host_id: Host UUID identifier for scoping reflection
Returns:
ReflectionResult: 反思结果
ReflectionResult: Comprehensive reflection operation result
"""
logging.info("执行基于事实的反思")
# 使用标准反思流程
# Use standard reflection workflow
return await self.execute_reflection(host_id)
# 综合反思方法
# Comprehensive reflection methods
async def comprehensive_reflection(
self,
host_id: uuid.UUID
) -> ReflectionResult:
"""
综合反思
整合基于时间和基于事实的反思策略。
Comprehensive reflection
Integrates time-based and fact-based reflection strategies based on
the configured baseline. Supports hybrid approaches that combine
multiple reflection methodologies.
Args:
host_id: 主机ID
host_id: Host UUID identifier for scoping reflection
Returns:
ReflectionResult: 反思结果
ReflectionResult: Comprehensive reflection operation result combining
multiple strategies if using hybrid baseline
"""
logging.info("执行综合反思")
# 根据配置的基线选择反思策略
# Choose reflection strategy based on configured baseline
if self.config.baseline == ReflectionBaseline.TIME:
return await self.time_based_reflection(host_id)
elif self.config.baseline == ReflectionBaseline.FACT:
return await self.fact_based_reflection(host_id)
elif self.config.baseline == ReflectionBaseline.HYBRID:
# 混合策略:先执行基于时间的反思,再执行基于事实的反思
# Hybrid strategy: execute time-based reflection first, then fact-based reflection
time_result = await self.time_based_reflection(host_id)
fact_result = await self.fact_based_reflection(host_id)
# 合并结果
# Merge results
return ReflectionResult(
success=time_result.success and fact_result.success,
message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}",

View File

@@ -2,9 +2,17 @@ import json
def escape_lucene_query(query: str) -> str:
"""Escape Lucene special characters in a free-text query.
This prevents ParseException when using Neo4j full-text procedures.
"""
Escape special characters in Lucene queries
Prevents ParseException when using Neo4j full-text search procedures.
Escapes all Lucene reserved special characters and operators.
Args:
query: Original query string
Returns:
str: Escaped query string safe for Lucene search
"""
if query is None:
return ""
@@ -22,11 +30,21 @@ def escape_lucene_query(query: str) -> str:
return s
def extract_plain_query(query_input: str) -> str:
"""Extract clean, plain-text query from various input forms.
"""
Extract clean plain-text query from various input forms
Handles the following cases:
- Strips surrounding quotes and whitespace
- If input looks like JSON, prefers the 'original' field
- Fallbacks to the raw string when parsing fails
- Falls back to raw string when parsing fails
- Handles dictionary-type input
- Best-effort unescape common escape characters
Args:
query_input: Query input in various forms (string, dict, etc.)
Returns:
str: Extracted plain-text query string
"""
if query_input is None:
return ""

View File

@@ -4,7 +4,13 @@ from datetime import datetime
def validate_date_format(date_str: str) -> bool:
"""
Validate if the date string is in the format YYYY-MM-DD.
Validate if date string conforms to YYYY-MM-DD format
Args:
date_str: Date string to validate
Returns:
bool: True if format is correct, False otherwise
"""
pattern = r"^\d{4}-\d{1,2}-\d{1,2}$"
return bool(re.match(pattern, date_str))
@@ -41,7 +47,20 @@ def normalize_date(date_str: str) -> str:
def preprocess_date_string(date_str: str) -> str:
"""预处理日期字符串,处理特殊格式"""
"""
预处理日期字符串,处理特殊格式
处理以下特殊格式:
- 年份后直接跟月份没有分隔符的格式(如 "20259/28"
- 无分隔符的纯数字格式(如 "20251028", "251028"
- 混合分隔符,统一为 "-"
Args:
date_str: 原始日期字符串
Returns:
str: 预处理后的日期字符串,格式为 "YYYY-MM-DD""YYYY-MM"
"""
# 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔)
match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str)
@@ -78,7 +97,23 @@ def preprocess_date_string(date_str: str) -> str:
def fallback_parse(date_str: str) -> str:
"""备选解析方案"""
"""
备选日期解析方案
当智能解析失败时,尝试使用预定义的日期格式进行解析。
支持多种常见的日期格式,包括:
- YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD
- YYYYMMDD, YYMMDD
- MM-DD-YYYY, MM/DD/YYYY, MM.DD.YYYY
- DD-MM-YYYY, DD/MM/YYYY, DD.MM.YYYY
- YYYY-MM, YYYY/MM, YYYY.MM
Args:
date_str: 待解析的日期字符串
Returns:
str: 标准化后的日期字符串YYYY-MM-DD格式解析失败时返回原字符串
"""
# 尝试常见的日期格式[citation:4][citation:5]
formats_to_try = [

View File

@@ -2,15 +2,15 @@ import os
from jinja2 import Environment, FileSystemLoader
from typing import List, Dict, Any
# Setup Jinja2 environment
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
baseline: str = "TIME",
memory_verify: bool = False,quality_assessment:bool = False,
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
memory_verify: bool = False, quality_assessment: bool = False,
statement_databasets=None, language_type: str = "zh") -> str:
"""
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
@@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
Returns:
Rendered prompt content as string
"""
if statement_databasets is None:
statement_databasets = []
template = prompt_env.get_template("evaluate.jinja2")
# Convert Pydantic model to JSON schema if needed
@@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
statement_databasets=None, language_type: str = "zh") -> str:
"""
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
@@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
Returns:
Rendered prompt content as a string.
"""
if statement_databasets is None:
statement_databasets = []
template = prompt_env.get_template("reflexion.jinja2")
# Convert Pydantic model to JSON schema if needed
@@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
json_schema = schema
rendered_prompt = template.render(data=data, json_schema=json_schema,
baseline=baseline,memory_verify=memory_verify,
statement_databasets=statement_databasets,language_type=language_type)
baseline=baseline, memory_verify=memory_verify,
statement_databasets=statement_databasets, language_type=language_type)
return rendered_prompt

View File

@@ -1,23 +1,19 @@
from __future__ import annotations
import asyncio
import os
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, Dict, Optional, TypeVar
from langchain_aws import ChatBedrock
from langchain_community.chat_models import ChatTongyi
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLLM
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI, OpenAI
from pydantic import BaseModel, Field
import httpx
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.models.models_model import ModelProvider, ModelType
from langchain_community.document_compressors import JinaRerank
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel, BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnableSerializable
from pydantic import BaseModel, Field
T = TypeVar("T")
@@ -163,25 +159,17 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
from langchain_openai import ChatOpenAI
return ChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
if type == ModelType.LLM:
from langchain_openai import OpenAI
return OpenAI
elif type == ModelType.CHAT:
from langchain_openai import ChatOpenAI
return ChatOpenAI
elif provider == ModelProvider.DASHSCOPE:
from langchain_community.chat_models import ChatTongyi
return ChatTongyi
elif provider == ModelProvider.OLLAMA:
from langchain_ollama import OllamaLLM
return OllamaLLM
elif provider == ModelProvider.BEDROCK:
from langchain_aws import ChatBedrock, ChatBedrockConverse
return ChatBedrock
else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)

View File

@@ -16,6 +16,7 @@ from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
from app.schemas import FileInput
from app.schemas.model_schema import ModelInfo
from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__)
@@ -620,11 +621,12 @@ class BaseNode(ABC):
@staticmethod
async def process_message(
provider: str,
is_omni: bool,
api_config: ModelInfo,
content: str | dict | FileObject,
end_user_id: str,
enable_file=False
) -> list | str | None:
provider = api_config.provider
if isinstance(content, dict):
content = FileObject(
type=content.get("type"),
@@ -643,7 +645,7 @@ class BaseNode(ABC):
if content.content_cache.get(provider):
return content.content_cache[provider]
with get_db_read() as db:
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
multimodel_service = MultimodalService(db, api_config=api_config)
file_obj = FileInput(
type=content.type,
url=content.url,
@@ -653,7 +655,8 @@ class BaseNode(ABC):
)
file_obj.set_content(content.get_content())
message = await multimodel_service.process_files(
[file_obj]
end_user_id,
[file_obj],
)
content.set_content(file_obj.get_content())
if message:

View File

@@ -5,7 +5,7 @@ from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.variable.base_variable import VariableType
@@ -23,6 +23,26 @@ class IfElseNode(BaseNode):
"output": VariableType.STRING
}
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
result = []
for case in self.typed_config.cases:
expressions = []
for expression in case.expressions:
expressions.append({
"left": self.get_variable(expression.left, variable_pool, strict=False),
"right": expression.right
if expression.input_type == ValueInputType.CONSTANT
else self.get_variable(expression.right, variable_pool, strict=False),
"operator": expression.operator,
})
result.append({
"expressions": expressions,
"logical_operator": case.logical_operator,
})
return {
"cases": result
}
@staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
match operator:

View File

@@ -30,6 +30,12 @@ class KnowledgeRetrievalNode(BaseNode):
"output": VariableType.ARRAY_STRING
}
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {
"query": self._render_template(self.typed_config.query, variable_pool),
"knowledge_bases": [kb_config.model_dump(mode="json") for kb_config in self.typed_config.knowledge_bases],
}
@staticmethod
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
"""

View File

@@ -20,6 +20,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_context
from app.models import ModelType
from app.schemas.model_schema import ModelInfo
from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__)
@@ -113,12 +114,15 @@ class LLMNode(BaseNode):
# 在 Session 关闭前提取所有需要的数据
api_config = self.model_balance(config)
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
is_omni = api_config.is_omni
model_type = config.type
model_info = ModelInfo(
model_name=api_config.model_name,
model_type=ModelType(config.type),
api_key=api_config.api_key,
api_base=api_config.api_base,
provider=api_config.provider,
is_omni=api_config.is_omni,
capability=api_config.capability
)
# 4. 创建 LLM 实例(使用已提取的数据)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
@@ -126,17 +130,18 @@ class LLMNode(BaseNode):
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
model_name=model_info.model_name,
provider=model_info.provider,
api_key=model_info.api_key,
base_url=model_info.api_base,
extra_params=extra_params,
is_omni=is_omni
is_omni=model_info.is_omni
),
type=ModelType(model_type)
type=model_info.model_type
)
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
logger.debug(
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
messages_config = self.typed_config.messages
@@ -148,35 +153,40 @@ class LLMNode(BaseNode):
content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool)
user_id = self.get_variable("sys.user_id", variable_pool)
# 根据角色创建对应的消息对象
if role == "system":
messages.append({
"role": "system",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
"content": await self.process_message(
model_info,
content,
user_id,
self.typed_config.vision,
)
})
elif role in ["user", "human"]:
messages.append({
"role": "user",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
})
elif role in ["ai", "assistant"]:
messages.append({
"role": "assistant",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
})
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({
"role": "user",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
})
if self.typed_config.vision_input and self.typed_config.vision:
file_content = []
files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value:
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
if content:
file_content.extend(content)
if messages and messages[-1]["role"] == 'user':
@@ -190,14 +200,19 @@ class LLMNode(BaseNode):
if isinstance(message["content"], list):
file_content = []
for file in message["content"]:
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
if content:
file_content.extend(content)
history_message.append(
{"role": message["role"], "content": file_content}
)
else:
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
message["content"] = await self.process_message(
model_info,
message["content"],
user_id,
self.typed_config.vision
)
history_message.append(message)
messages = messages[:-1] + history_message + messages[-1:]
self.messages = messages
@@ -293,7 +308,7 @@ class LLMNode(BaseNode):
# 调用 LLM流式支持字符串或消息列表
last_meta_data = {}
async for chunk in llm.astream(self.messages, stream_usage=True):
async for chunk in llm.astream(self.messages):
# 提取内容
if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content)

View File

@@ -37,6 +37,14 @@ class ParameterExtractorNode(BaseNode):
}
return None
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {
"text": self._render_template(self.typed_config.text, variable_pool),
"prompt": self._render_template(self.typed_config.prompt, variable_pool),
"params": [param.model_dump(mode="json") for param in self.typed_config.params],
"model_id": str(self.typed_config.model_id),
}
def _output_types(self) -> dict[str, VariableType]:
outputs = {}
for param in self.typed_config.params: