Merge branch 'refs/heads/develop' into feature/agent-tool_xjn

# Conflicts:
#	api/app/core/agent/langchain_agent.py
#	api/app/core/tools/mcp/client.py
This commit is contained in:
Timebomb2018
2026-04-01 15:27:34 +08:00
219 changed files with 4861 additions and 2599 deletions

View File

@@ -11,18 +11,14 @@ LangChain Agent 封装
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType, ModelProvider
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
logger = get_business_logger()
@@ -226,10 +222,9 @@ class LangChainAgent:
Returns:
List[BaseMessage]: 消息列表
"""
messages = []
messages:list = [SystemMessage(content=self.system_prompt)]
# 添加系统提示词
messages.append(SystemMessage(content=self.system_prompt))
# 添加历史消息
if history:
@@ -320,12 +315,7 @@ class LangChainAgent:
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None, # 添加这个参数
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
files: Optional[List[Dict[str, Any]]] = None
) -> Dict[str, Any]:
"""执行对话
@@ -333,32 +323,12 @@ class LangChainAgent:
message: 用户消息
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
context: 上下文信息(如知识库检索结果)
files: 多模态文件
Returns:
Dict: 包含 content 和元数据的字典
"""
message_chat = message
start_time = time.time()
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
try:
# 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files)
@@ -445,9 +415,6 @@ class LangChainAgent:
logger.info(f"最终提取的内容长度: {len(content)}")
elapsed_time = time.time() - start_time
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
actual_config_id)
response = {
"content": content,
"model": self.model_name,
@@ -478,12 +445,7 @@ class LangChainAgent:
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
files: Optional[List[Dict[str, Any]]] = None
) -> AsyncGenerator[str | int, None]:
"""执行流式对话
@@ -491,6 +453,7 @@ class LangChainAgent:
message: 用户消息
history: 历史消息列表
context: 上下文信息
files: 多模态文件
Yields:
str: 消息内容块
@@ -501,23 +464,6 @@ class LangChainAgent:
logger.info(f" Has tools: {bool(self.tools)}")
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
logger.info("=" * 80)
message_chat = message
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files)
@@ -527,17 +473,18 @@ class LangChainAgent:
)
chunk_count = 0
yielded_content = False
# 统一使用 agent 的 astream_events 实现流式输出
logger.debug("使用 Agent astream_events 实现流式输出")
full_content = ''
try:
last_event = {}
async for event in self.agent.astream_events(
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
):
last_event = event
chunk_count += 1
kind = event.get("event")
@@ -551,7 +498,6 @@ class LangChainAgent:
if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content
yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分
for item in chunk_content:
@@ -562,18 +508,15 @@ class LangChainAgent:
if text:
full_content += text
yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text":
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
elif isinstance(item, str):
full_content += item
yield item
yielded_content = True
elif kind == "on_llm_stream":
# 另一种 LLM 流式事件
@@ -584,7 +527,6 @@ class LangChainAgent:
if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content
yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分
for item in chunk_content:
@@ -595,22 +537,18 @@ class LangChainAgent:
if text:
full_content += text
yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text":
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
elif isinstance(item, str):
full_content += item
yield item
yielded_content = True
elif isinstance(chunk, str):
full_content += chunk
yield chunk
yielded_content = True
# 记录工具调用(可选)
elif kind == "on_tool_start":
@@ -620,17 +558,14 @@ class LangChainAgent:
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗
# 统计 token 消耗:优先使用流式过程中捕获的值,回退到最后 event 的 messages
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
stream_total_tokens = self._extract_tokens_from_message(msg)
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
yield stream_total_tokens
break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
actual_config_id)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise

View File

@@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
@@ -21,25 +20,6 @@ logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
"""
Write messages to RAG storage system
Combines user and AI messages into a single string format and stores them
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
Args:
end_user_id: User identifier for the conversation
user_message: User's input message content
ai_message: AI's response message content
user_rag_memory_id: RAG memory identifier for storage location
"""
# RAG mode: combine messages into string format (maintain original logic)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(
storage_type,
end_user_id,
@@ -118,7 +98,7 @@ async def write(
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
async def term_memory_save(end_user_id, strategy_type, scope):
"""
Save long-term memory data to database
@@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
to long-term memory storage.
Args:
long_term_messages: Long-term message data to be saved
actual_config_id: Configuration identifier for memory settings
end_user_id: User identifier for memory association
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
scope: Scope/window size for memory processing
"""
with get_db_context() as db_session:
@@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
if not result:
logger.warning(f"No write data found for user {end_user_id}")
return
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data) == scope:
@@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
logger.info(f'写入短长期:')
"""Window-based dialogue processing"""
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
"""
Process dialogue based on window size and write to Neo4j
@@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
langchain_messages: Original message data list
scope: Window size determining when to trigger long-term storage
"""
scope = scope
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope):
is_end_user_id += 1
langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope):
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
if is_end_user_has_history:
end_user_visit_count, redis_messages = is_end_user_has_history
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
return
end_user_visit_count += 1
if end_user_visit_count < scope:
redis_messages.extend(langchain_messages)
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
else:
logger.info('写入长期记忆NEO4J')
formatted_messages = redis_messages
redis_messages.extend(langchain_messages)
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id
else:
config_id = memory_config
await write(
AgentMemory_Long_Term.STORAGE_NEO4J,
end_user_id,
"",
"",
None,
end_user_id,
config_id,
formatted_messages
write_message_task.delay(
end_user_id, # end_user_id: User ID
redis_messages, # message: JSON string format message list
config_id, # config_id: Configuration ID string
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
)
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""Time-based memory processing"""
count_store.update_sessions_count(end_user_id, 0, [])
async def memory_long_term_storage(end_user_id, memory_config, time):
@@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
return result_dict
except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}")
import traceback
traceback.print_exc()
logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
return {
"is_same_event": False,

View File

@@ -1,49 +1,25 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
from app.db import get_db_context
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import write_rag
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@asynccontextmanager
async def make_write_graph():
"""
Create a write graph workflow for memory operations.
Args:
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
workflow.add_edge("save_neo4j", END)
graph = workflow.compile()
yield graph
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
async def long_term_storage(
long_term_type: str,
langchain_messages: list,
memory_config_id: str,
end_user_id: str,
scope: int = 6
):
"""
Handle long-term memory storage with different strategies
@@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l
Args:
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
langchain_messages: List of messages to store
memory_config: Memory configuration identifier
memory_config_id: Memory configuration identifier
end_user_id: User group identifier
scope: Scope parameter for chunk-based storage (default: 6)
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
if langchain_messages is None:
langchain_messages = []
write_store.save_session_write(end_user_id, langchain_messages)
# 获取数据库会话
with get_db_context() as db_session:
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
config_id=memory_config_id, # 改为整数
service_name="MemoryAgentService"
)
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
# Dialogue window with 6 rounds of conversation
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
"""Time-based strategy"""
# Time-based strategy
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
"""Strategy 3: Aggregate judgment"""
# Aggregate judgment
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
async def write_long_term(
storage_type: str,
end_user_id: str,
messages: list[dict],
user_rag_memory_id: str,
actual_config_id: str
):
"""
Write long-term memory with different storage types
@@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u
Args:
storage_type: Type of storage (RAG or traditional)
end_user_id: User group identifier
message_chat: User message content
aimessages: AI response messages
messages: message list
user_rag_memory_id: RAG memory identifier
actual_config_id: Actual configuration ID
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
message_content = []
for message in messages:
message_content.append(f'{message.get("role")}:{message.get("content")}')
messages_string = "\n".join(message_content)
await write_rag(end_user_id, messages_string, user_rag_memory_id)
else:
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages)
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
# async def main():
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五去爬山"
# },
# {
# "role": "assistant",
# "content": "好耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())
await long_term_storage(long_term_type=CHUNK,
langchain_messages=messages,
memory_config_id=actual_config_id,
end_user_id=end_user_id,
scope=SCOPE)
await term_memory_save(end_user_id, CHUNK, scope=SCOPE)

View File

@@ -3,8 +3,9 @@ import uuid
from app.core.config import settings
from typing import List, Dict, Any, Optional, Union
from app.core.logging_config import get_logger
from app.core.memory.agent.utils.redis_base import (
serialize_messages,
serialize_messages,
deserialize_messages,
fix_encoding,
format_session_data,
@@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import (
get_current_timestamp
)
logger = get_logger(__name__)
class RedisWriteStore:
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
@@ -66,10 +67,10 @@ class RedisWriteStore:
})
result = pipe.execute()
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
print(f"[save_session_write] 保存会话失败: {e}")
logger.error(f"[save_session_write] 保存会话失败: {e}")
raise e
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
@@ -99,7 +100,7 @@ class RedisWriteStore:
for key, data in zip(keys, all_data):
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid:
# 从 key 中提取 session_id: session:write:{session_id}
@@ -108,16 +109,16 @@ class RedisWriteStore:
"sessionid": session_id,
"messages": fix_encoding(data.get('messages', ''))
})
if not results:
return False
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
return results
except Exception as e:
print(f"[get_session_by_userid] 查询失败: {e}")
logger.error(f"[get_session_by_userid] 查询失败: {e}")
return False
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
"""
通过 end_user_id 获取所有 write 类型的会话数据
@@ -144,7 +145,7 @@ class RedisWriteStore:
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
return False
# 批量获取数据
@@ -158,12 +159,12 @@ class RedisWriteStore:
for key, data in zip(keys, all_data):
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == end_user_id:
# 从 key 中提取 session_id: session:write:{session_id}
session_id = key.split(':')[-1]
# 构建完整的会话信息
session_info = {
"session_id": session_id,
@@ -173,23 +174,21 @@ class RedisWriteStore:
"starttime": data.get('starttime', '')
}
results.append(session_info)
if not results:
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
return False
# 按时间排序(最新的在前)
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
return results
except Exception as e:
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
import traceback
traceback.print_exc()
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
return False
def find_user_recent_sessions(self, userid: str,
def find_user_recent_sessions(self, userid: str,
minutes: int = 5) -> List[Dict[str, str]]:
"""
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
@@ -203,11 +202,11 @@ class RedisWriteStore:
"""
import time
start_time = time.time()
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
@@ -221,7 +220,7 @@ class RedisWriteStore:
for data in all_data:
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid and data.get('starttime'):
# write 类型没有 aimessages所以 Answer 为空
@@ -230,15 +229,14 @@ class RedisWriteStore:
"Answer": "",
"starttime": data.get('starttime', '')
})
# 根据时间范围过滤
filtered_items = filter_by_time_range(matched_items, minutes)
# 排序并移除时间字段
result_items = sort_and_limit_results(filtered_items, limit=None)
print(result_items)
result_items = sort_and_limit_results(filtered_items)
elapsed_time = time.time() - start_time
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
@@ -258,7 +256,7 @@ class RedisWriteStore:
class RedisCountStore:
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
@@ -278,7 +276,7 @@ class RedisCountStore:
decode_responses=True,
encoding='utf-8'
)
self.uudi = session_id
self.uuid = session_id
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
"""
@@ -295,26 +293,26 @@ class RedisCountStore:
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="count")
index_key = f'session:count:index:{end_user_id}' # 索引键
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"id": self.uuid,
"end_user_id": end_user_id,
"count": int(count),
"messages": serialize_messages(messages),
"starttime": get_current_timestamp()
})
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
# 创建索引end_user_id -> session_id 映射
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
result = pipe.execute()
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
return session_id
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
"""
通过 end_user_id 查询访问次数统计
@@ -327,7 +325,7 @@ class RedisCountStore:
try:
# 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误
try:
key_type = self.r.type(index_key)
@@ -335,35 +333,40 @@ class RedisCountStore:
self.r.delete(index_key)
return False
except Exception as type_error:
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
return False
# 直接获取数据
key = generate_session_key(session_id, key_type="count")
data = self.r.hgetall(key)
if not data:
# 索引存在但数据不存在,清理索引
self.r.delete(index_key)
return False
count = data.get('count')
messages_str = data.get('messages')
if count is not None:
messages = deserialize_messages(messages_str)
return [int(count), messages]
messages: list[dict] = deserialize_messages(messages_str)
return int(count), messages
return False
except Exception as e:
print(f"[get_sessions_count] 查询失败: {e}")
logger.error(f"[get_sessions_count] 查询失败: {e}")
return False
def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool:
def update_sessions_count(
self,
end_user_id: str,
new_count: int,
messages: Any
) -> bool:
"""
通过 end_user_id 修改访问次数统计(优化版:使用索引)
@@ -378,39 +381,39 @@ class RedisCountStore:
try:
# 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误
try:
key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none':
# 索引键类型错误,删除并返回 False
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
self.r.delete(index_key)
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as type_error:
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
# 直接更新数据
key = generate_session_key(session_id, key_type="count")
messages_str = serialize_messages(messages)
pipe = self.r.pipeline()
pipe.hset(key, 'count', int(new_count))
pipe.hset(key, 'count', str(new_count))
pipe.hset(key, 'messages', messages_str)
result = pipe.execute()
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True
except Exception as e:
print(f"[update_sessions_count] 更新失败: {e}")
logger.debug(f"[update_sessions_count] 更新失败: {e}")
return False
def delete_all_count_sessions(self) -> int:
@@ -428,7 +431,7 @@ class RedisCountStore:
class RedisSessionStore:
"""Redis 会话存储类,用于管理会话数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
@@ -451,9 +454,9 @@ class RedisSessionStore:
self.uudi = session_id
# ==================== 写入操作 ====================
def save_session(self, userid: str, messages: str, aimessages: str,
apply_id: str, end_user_id: str) -> str:
def save_session(self, userid: str, messages: str, aimessages: str,
apply_id: str, end_user_id: str) -> str:
"""
写入一条会话数据,返回 session_id
@@ -483,14 +486,14 @@ class RedisSessionStore:
})
result = pipe.execute()
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
print(f"[save_session] 保存会话失败: {e}")
logger.error(f"[save_session] 保存会话失败: {e}")
raise e
# ==================== 读取操作 ====================
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""
读取一条会话数据
@@ -520,8 +523,8 @@ class RedisSessionStore:
sessions[sid] = self.get_session(sid)
return sessions
def find_user_apply_group(self, sessionid: str, apply_id: str,
end_user_id: str) -> List[Dict[str, str]]:
def find_user_apply_group(self, sessionid: str, apply_id: str,
end_user_id: str) -> List[Dict[str, str]]:
"""
根据 sessionid、apply_id 和 end_user_id 查询会话数据返回最新的6条
@@ -535,10 +538,10 @@ class RedisSessionStore:
"""
import time
start_time = time.time()
keys = self.r.keys('session:*')
if not keys:
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
@@ -556,21 +559,21 @@ class RedisSessionStore:
continue
if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
data.get('end_user_id') == end_user_id):
# 支持模糊匹配或完全匹配 sessionid
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append(format_session_data(data, include_time=True))
# 排序、限制数量并移除时间字段
result_items = sort_and_limit_results(matched_items, limit=6)
elapsed_time = time.time() - start_time
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
# ==================== 更新操作 ====================
def update_session(self, session_id: str, field: str, value: Any) -> bool:
"""
更新单个字段
@@ -591,7 +594,7 @@ class RedisSessionStore:
return bool(results[0])
# ==================== 删除操作 ====================
def delete_session(self, session_id: str) -> int:
"""
删除单条会话
@@ -632,7 +635,7 @@ class RedisSessionStore:
keys = self.r.keys('session:*')
if not keys:
print("[delete_duplicate_sessions] 没有会话数据")
logger.debug("[delete_duplicate_sessions] 没有会话数据")
return 0
# 批量获取所有数据
@@ -678,7 +681,7 @@ class RedisSessionStore:
deleted_count += len(batch)
elapsed_time = time.time() - start_time
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
return deleted_count

View File

@@ -151,11 +151,6 @@ async def write(
# Step 3: Save all data to Neo4j database
step_start = time.time()
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
try:
await create_fulltext_indexes()
except Exception as e:
logger.error(f"Error creating indexes: {e}", exc_info=True)
# 添加死锁重试机制
max_retries = 3
@@ -279,5 +274,21 @@ async def write(
except Exception as cache_err:
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
# Close LLM/Embedder underlying httpx clients to prevent
# 'RuntimeError: Event loop is closed' during garbage collection
for client_obj in (llm_client, embedder_client):
try:
underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None)
if underlying is None:
continue
# Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model
inner = getattr(underlying, '_model', underlying)
# LangChain OpenAI models expose async_client (httpx.AsyncClient)
http_client = getattr(inner, 'async_client', None)
if http_client is not None and hasattr(http_client, 'aclose'):
await http_client.aclose()
except Exception:
pass
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")

View File

@@ -56,7 +56,7 @@ class LLMClient(ABC):
self.max_retries = self.config.max_retries
self.timeout = self.config.timeout
logger.info(
logger.debug(
f"初始化 LLM 客户端: provider={self.provider}, "
f"model={self.model_name}, max_retries={self.max_retries}"
)

View File

@@ -65,7 +65,7 @@ class OpenAIClient(LLMClient):
type=type_
)
logger.info(f"OpenAI 客户端初始化完成: type={type_}")
logger.debug(f"OpenAI 客户端初始化完成: type={type_}")
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
"""

View File

@@ -30,6 +30,18 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
logger = logging.getLogger(__name__)
def message_has_files(message: "ConversationMessage") -> bool:
"""检查消息是否包含文件。
Args:
message: 待检查的消息对象
Returns:
bool: 如果消息包含文件则返回 True否则返回 False
"""
return message.files and len(message.files) > 0
class DialogExtractionResponse(BaseModel):
"""对话级一次性抽取的结构化返回,用于加速剪枝。
@@ -128,7 +140,7 @@ class SemanticPruner:
1. 空消息
2. 场景特定填充词库精确匹配
3. 常见寒暄精确匹配
4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢""同学你好""明白了"
4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢""同学你好""明白了"
5. 纯表情/标点
"""
t = message.msg.strip()
@@ -482,6 +494,11 @@ class SemanticPruner:
"""
to_delete_ids: set = set()
for m in msgs:
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
if message_has_files(m):
self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}")
continue
# 填充检测优先:先判断是否为填充,再看 LLM 保护
if self._is_filler_message(m):
to_delete_ids.add(id(m))
@@ -549,6 +566,11 @@ class SemanticPruner:
to_delete_ids: set = set()
for m in msgs:
msg_text = m.msg.strip()
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
if message_has_files(m):
self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}")
continue
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
if self._is_filler_message(m):
@@ -801,6 +823,12 @@ class SemanticPruner:
for idx, m in enumerate(msgs):
msg_text = m.msg.strip()
# 最高优先级保护:带有文件的消息一律保留,不参与分类
if message_has_files(m):
self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}")
llm_protected_msgs.append((idx, m)) # 放入保护列表
continue
if self._msg_matches_tokens(m, preserve_tokens):
llm_protected_msgs.append((idx, m))

View File

@@ -182,7 +182,7 @@ class ExtractionOrchestrator:
list[StatementEntityEdge],
list[EntityEntityEdge],
list[PerceptualEdge],
dict
list[DialogData]
]:
"""
运行完整的知识提取流水线(优化版:并行执行)
@@ -295,6 +295,7 @@ class ExtractionOrchestrator:
statement_entity_edges,
entity_entity_edges,
dialog_data_list,
dedup_details,
) = await self._run_dedup_and_write_summary(
dialogue_nodes,
chunk_nodes,
@@ -306,6 +307,11 @@ class ExtractionOrchestrator:
dialog_data_list,
)
# 步骤 7: 同步用户别名到数据库表(仅正式模式)
if not is_pilot_run:
logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表")
await self._update_end_user_other_name(entity_nodes, dialog_data_list)
logger.info(f"知识提取流水线运行完成({mode_str}")
return (
dialogue_nodes,
@@ -1399,7 +1405,8 @@ class ExtractionOrchestrator:
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
else:
first_alias = current_aliases[0].strip() if current_aliases else ""
if first_alias:
# 确保 first_alias 不是占位名称
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
db.add(EndUserInfo(
end_user_id=end_user_uuid,
other_name=first_alias,
@@ -1415,29 +1422,33 @@ class ExtractionOrchestrator:
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
USER_PLACEHOLDER_NAMES = {'用户', '', 'User', 'I'}
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
"""从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序)
这个方法直接返回 LLM 提取的别名列表,不做任何修改
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户""""User""I"
第一个别名将被用作 other_name。
Args:
entity_nodes: 实体节点列表
Returns:
别名列表(保持 LLM 提取的原始顺序)
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称
"""
USER_NAMES = {'用户', '', 'User', 'I'}
for entity in entity_nodes:
if getattr(entity, 'name', '').strip() in USER_NAMES:
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
aliases = getattr(entity, 'aliases', []) or []
logger.debug(f"提取到用户别名(原始顺序): {aliases}")
return aliases
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
return filtered
return []
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
"""从 Neo4j 查询用户实体的完整 aliases 列表"""
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
cypher = """
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '', 'User', 'I']
@@ -1451,7 +1462,10 @@ class ExtractionOrchestrator:
aliases = result[0].get('aliases') or []
if not aliases:
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
return aliases
return []
# 过滤掉占位名称,防止历史脏数据传播
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
return filtered
def _resolve_other_name(
self,
@@ -1463,14 +1477,25 @@ class ExtractionOrchestrator:
决定 other_name 是否需要更新,返回新值;无需更新返回 None。
决策规则:
- 为空 → 用本次对话第一个别名
- 为空或为占位名称 → 用本次对话第一个别名
- 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除)
- 否则 → 保持不变(返回 None
注意:返回值不允许是占位名称("用户""""User""I"
"""
if not current or not current.strip():
return current_aliases[0].strip() if current_aliases else None
# 当前值为空或为占位名称时,需要更新
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
candidate = current_aliases[0].strip() if current_aliases else None
# 确保候选值不是占位名称
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
return None
return candidate
if current not in neo4j_aliases:
return neo4j_aliases[0].strip() if neo4j_aliases else None
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
# 确保候选值不是占位名称
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
return None
return candidate
return None
@@ -1492,6 +1517,7 @@ class ExtractionOrchestrator:
list[StatementChunkEdge],
list[StatementEntityEdge],
list[EntityEntityEdge],
list[DialogData],
dict
]:
"""
@@ -1555,6 +1581,8 @@ class ExtractionOrchestrator:
statement_chunk_edges,
dedup_statement_entity_edges,
dedup_entity_entity_edges,
dialog_data_list,
dedup_details,
)
final_entity_nodes = dedup_entity_nodes
@@ -1562,7 +1590,16 @@ class ExtractionOrchestrator:
final_entity_entity_edges = dedup_entity_entity_edges
else:
# 正式模式:执行完整的两阶段去重
result_tuple = await dedup_layers_and_merge_and_return(
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
final_entity_nodes,
statement_chunk_edges,
final_statement_entity_edges,
final_entity_entity_edges,
dedup_details,
) = await dedup_layers_and_merge_and_return(
dialogue_nodes,
chunk_nodes,
statement_nodes,
@@ -1576,21 +1613,21 @@ class ExtractionOrchestrator:
llm_client=self.llm_client,
)
# 解包返回值
(
_,
_,
_,
final_entity_nodes,
_,
final_statement_entity_edges,
final_entity_entity_edges,
dedup_details,
) = result_tuple
# 保存去重消歧的详细记录到实例变量
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
result_tuple = (
dialogue_nodes,
chunk_nodes,
statement_nodes,
final_entity_nodes,
statement_chunk_edges,
final_statement_entity_edges,
final_entity_entity_edges,
dialog_data_list,
dedup_details,
)
logger.info(
f"去重后: {len(final_entity_nodes)} 个实体节点, "
f"{len(final_statement_entity_edges)} 条陈述句-实体边, "

View File

@@ -105,13 +105,19 @@ Extract entities and knowledge triplets from the given statement.
{% if language == "zh" %}
- 用户实体的 name 字段:使用 "用户" 或 "我"
- 用户的真实姓名:放入 aliases
- **🚨 禁止将 "用户"、"我" 放入 aliases 中aliases 只能包含用户的真实姓名、昵称等**
- 示例:
* "我叫李明" → name="用户", aliases=["李明"]
* ❌ 错误aliases=["用户", "李明"]"用户"不是真实姓名,禁止放入 aliases
* ❌ 错误aliases=["我", "李明"]"我"不是真实姓名,禁止放入 aliases
{% else %}
- User entity name field: use "User" or "I"
- User's real name: put in aliases
- **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.**
- Examples:
* "I'm John" → name="User", aliases=["John"]
* ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases)
* ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases)
{% endif %}

View File

@@ -44,6 +44,8 @@ class OSSStorage(StorageBackend):
access_key_id: str,
access_key_secret: str,
bucket_name: str,
connect_timeout: int = 30,
multipart_threshold: int = 10 * 1024 * 1024, # 10MB
):
"""
Initialize the OSSStorage backend.
@@ -53,6 +55,8 @@ class OSSStorage(StorageBackend):
access_key_id: The Aliyun access key ID.
access_key_secret: The Aliyun access key secret.
bucket_name: The name of the OSS bucket.
connect_timeout: Connection timeout in seconds (default: 30).
multipart_threshold: File size threshold for multipart upload (default: 10MB).
Raises:
StorageConfigError: If any required configuration is missing.
@@ -69,10 +73,17 @@ class OSSStorage(StorageBackend):
self.endpoint = endpoint
self.bucket_name = bucket_name
self.multipart_threshold = multipart_threshold
try:
auth = oss2.Auth(access_key_id, access_key_secret)
self.bucket = oss2.Bucket(auth, endpoint, bucket_name)
# 设置超时和重试
self.bucket = oss2.Bucket(
auth,
endpoint,
bucket_name,
connect_timeout=connect_timeout
)
logger.info(
f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}"
)
@@ -108,21 +119,38 @@ class OSSStorage(StorageBackend):
if content_type:
headers["Content-Type"] = content_type
self.bucket.put_object(file_key, content, headers=headers if headers else None)
# 大文件使用分片上传
if len(content) > self.multipart_threshold:
logger.info(f"Using multipart upload for large file: {file_key} ({len(content)} bytes)")
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers if headers else None).upload_id
parts = []
part_size = 5 * 1024 * 1024 # 5MB per part
part_num = 1
for offset in range(0, len(content), part_size):
chunk = content[offset:offset + part_size]
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
parts.append(oss2.models.PartInfo(part_num, result.etag))
part_num += 1
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
else:
self.bucket.put_object(file_key, content, headers=headers if headers else None)
logger.info(f"File uploaded to OSS successfully: {file_key}")
return file_key
except OssError as e:
logger.error(f"OSS error uploading file {file_key}: {e}")
raise StorageUploadError(
message=f"Failed to upload file to OSS: {e.message}",
message=f"Failed to upload file to OSS: {str(e)}",
file_key=file_key,
cause=e,
)
except Exception as e:
logger.error(f"Failed to upload file to OSS {file_key}: {e}")
raise StorageUploadError(
message=f"Failed to upload file to OSS: {e}",
message=f"Failed to upload file to OSS: {str(e)}",
file_key=file_key,
cause=e,
)
@@ -135,28 +163,73 @@ class OSSStorage(StorageBackend):
) -> int:
"""Upload from async stream to OSS. Returns total bytes written."""
buf = io.BytesIO()
headers = {"Content-Type": content_type} if content_type else None
upload_id = None
try:
# 收集流数据
total_size = 0
async for chunk in stream:
if not chunk:
continue
buf.write(chunk)
total_size += len(chunk)
content = buf.getvalue()
headers = {"Content-Type": content_type} if content_type else None
self.bucket.put_object(file_key, content, headers=headers)
logger.info(f"File stream uploaded to OSS successfully: {file_key}")
return len(content)
if not content:
raise StorageUploadError(
message="Empty stream content",
file_key=file_key,
)
# 大文件使用分片上传
if len(content) > self.multipart_threshold:
logger.info(f"Using multipart upload for stream: {file_key} ({len(content)} bytes)")
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers).upload_id
parts = []
part_size = 5 * 1024 * 1024 # 5MB
part_num = 1
for offset in range(0, len(content), part_size):
chunk = content[offset:offset + part_size]
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
parts.append(oss2.models.PartInfo(part_num, result.etag))
part_num += 1
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
else:
self.bucket.put_object(file_key, content, headers=headers)
logger.info(f"File stream uploaded to OSS successfully: {file_key} ({total_size} bytes)")
return total_size
except OssError as e:
if upload_id:
try:
self.bucket.abort_multipart_upload(file_key, upload_id)
except:
pass
logger.error(f"OSS error stream uploading file {file_key}: {e}")
raise StorageUploadError(
message=f"Failed to stream upload file to OSS: {e.message}",
message=f"Failed to stream upload file to OSS: {str(e)}",
file_key=file_key,
cause=e,
)
except Exception as e:
if upload_id:
try:
self.bucket.abort_multipart_upload(file_key, upload_id)
except:
pass
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
raise StorageUploadError(
message=f"Failed to stream upload file to OSS: {e}",
message=f"Failed to stream upload file to OSS: {str(e)}",
file_key=file_key,
cause=e,
)
finally:
buf.close()
async def download(self, file_key: str) -> bytes:
"""
@@ -182,14 +255,14 @@ class OSSStorage(StorageBackend):
except OssError as e:
logger.error(f"OSS error downloading file {file_key}: {e}")
raise StorageDownloadError(
message=f"Failed to download file from OSS: {e.message}",
message=f"Failed to download file from OSS: {str(e)}",
file_key=file_key,
cause=e,
)
except Exception as e:
logger.error(f"Failed to download file from OSS {file_key}: {e}")
raise StorageDownloadError(
message=f"Failed to download file from OSS: {e}",
message=f"Failed to download file from OSS: {str(e)}",
file_key=file_key,
cause=e,
)
@@ -215,14 +288,14 @@ class OSSStorage(StorageBackend):
except OssError as e:
logger.error(f"OSS error deleting file {file_key}: {e}")
raise StorageDeleteError(
message=f"Failed to delete file from OSS: {e.message}",
message=f"Failed to delete file from OSS: {str(e)}",
file_key=file_key,
cause=e,
)
except Exception as e:
logger.error(f"Failed to delete file from OSS {file_key}: {e}")
raise StorageDeleteError(
message=f"Failed to delete file from OSS: {e}",
message=f"Failed to delete file from OSS: {str(e)}",
file_key=file_key,
cause=e,
)

View File

@@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType
def merge_activate_state(x, y):
return {
k: x.get(k, False) or y.get(k, False)
for k in set(x) | set(y)
}
merged = dict(x)
for k, v in y.items():
merged[k] = merged.get(k, False) or v
return merged
def merge_looping_state(x, y):

View File

@@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta
logger = logging.getLogger(__name__)
VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}")
class LazyVariableDict:
def __init__(self, source, literal):
self._source: dict[str, VariableStruct[Any]] = source
self._literal: bool = literal
self._cache = {}
def keys(self):
return self._source.keys()
def _resolve(self, key):
if key in self._cache:
return self._cache[key]
var_struct = self._source.get(key)
if var_struct is None:
raise KeyError(key)
value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value()
self._cache[key] = value
return value
def get(self, key, default=None):
try:
return self._resolve(key)
except KeyError:
return default
def __getitem__(self, key):
return self._resolve(key)
def __getattr__(self, key):
if key.startswith('_'):
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
return self._resolve(key)
def __contains__(self, key):
return key in self._source
def __iter__(self):
return iter(self._source)
def __len__(self):
return len(self._source)
class VariableSelector:
"""变量选择器
@@ -117,8 +162,7 @@ class VariablePool:
@staticmethod
def transform_selector(selector):
pattern = r"\{\{\s*(.*?)\s*\}\}"
variable_literal = re.sub(pattern, r"\1", selector).strip()
variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if len(selector) != 2:
raise ValueError(f"Selector not valid - {selector}")
@@ -303,6 +347,16 @@ class VariablePool:
"""
return self._get_variable_struct(selector) is not None
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
return LazyVariableDict(self.variables.get(namespace, {}), literal)
def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]:
return {
ns: LazyVariableDict(vars_dict, literal)
for ns, vars_dict in self.variables.items()
if ns not in ("sys", "conv")
}
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
"""获取所有系统变量
@@ -479,5 +533,3 @@ class VariablePoolInitializer:
var_type=var_type,
mut=False
)

View File

@@ -552,9 +552,9 @@ class BaseNode(ABC):
return render_template(
template=template,
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
node_outputs=variable_pool.get_all_node_outputs(literal=True),
system_vars=variable_pool.get_all_system_vars(literal=True),
conv_vars=variable_pool.lazy_namespace("conv", literal=True),
node_outputs=variable_pool.lazy_all_node_outputs(literal=True),
system_vars=variable_pool.lazy_namespace("sys", literal=True),
strict=strict
)
@@ -579,9 +579,9 @@ class BaseNode(ABC):
return evaluate_condition(
expression=expression,
conv_var=variable_pool.get_all_conversation_vars(),
node_outputs=variable_pool.get_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars()
conv_var=variable_pool.lazy_namespace("conv"),
node_outputs=variable_pool.lazy_all_node_outputs(),
system_vars=variable_pool.lazy_namespace("sys")
)
@staticmethod

View File

@@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.utils.expression_evaluator import evaluate_expression
logger = logging.getLogger(__name__)
@@ -85,12 +84,7 @@ class LoopRuntime:
for variable in self.typed_config.cycle_vars:
if variable.input_type == ValueInputType.VARIABLE:
value = evaluate_expression(
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
value = self.variable_pool.get_value(variable.value)
else:
value = TypeTransformer.transform(variable.value, variable.type)
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
@@ -98,12 +92,7 @@ class LoopRuntime:
**self.state
)
loopstate["node_outputs"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
variable.name: self.variable_pool.get_value(variable.value)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars

View File

@@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode):
# Reuse cached bytes if already fetched
if f.get_content():
file_input.set_content(f.get_content())
text = await svc._extract_document_text(file_input)
text = await svc.extract_document_text(file_input)
chunks.append(text)
except Exception as e:
logger.error(

View File

@@ -1,19 +1,23 @@
import asyncio
import logging
import uuid
from typing import Any
from langchain_core.documents import Document
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model, ModelType
from app.repositories import knowledge_repository, knowledgeshare_repository
from app.models import knowledge_model, ModelType
from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType
from app.services.model_service import ModelConfigService
@@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
@@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode):
unique.append(doc)
return unique
def _get_existing_kb_ids(self, db, kb_ids):
def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
"""
Resolve all accessible and valid knowledge base IDs for retrieval.
This includes:
- Private knowledge bases owned by the user
- Shared knowledge bases
- Source knowledge bases mapped via knowledge sharing relationships
Reorder the list of document blocks and return the top_k results most relevant to the query
Args:
db: Database session.
kb_ids (list[UUID]): Knowledge base IDs from node configuration.
query: query string
docs: List of document chunk to be rearranged
top_k: The number of top-level documents returned
Returns:
list[UUID]: Final list of valid knowledge base IDs.
Rearranged document chunk list (sorted in descending order of relevance)
Raises:
ValueError: If the input document list is empty or top_k is invalid
"""
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private)
existing_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
share_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
reranker = self.get_reranker_model()
# parameter validation
if not docs:
raise ValueError("retrieval chunks be empty")
if top_k <= 0:
raise ValueError("top_k must be a positive integer")
try:
# Convert to LangChain Document object
documents = [
Document(
page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
metadata=doc.metadata or {} # Deal with possible None metadata
)
for doc in docs
]
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters
# Perform reordering (compress_documents will automatically handle relevance scores and indexing)
reranked_docs = list(reranker.compress_documents(documents, query))
# Sort in descending order based on relevance score
reranked_docs.sort(
key=lambda x: x.metadata.get("relevance_score", 0),
reverse=True
)
existing_ids.extend(items)
return existing_ids
# Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
result = []
for item in reranked_docs[:top_k]:
for doc in docs:
if doc.page_content == item.page_content:
doc.metadata["score"] = item.metadata["relevance_score"]
result.append(doc)
return result
except Exception as e:
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
def get_reranker_model(self) -> RedBearRerank:
"""
@@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode):
)
return reranker
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
rs = []
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
tasks = []
for child in children:
if not (child and child.chunk_num > 0 and child.status == 1):
continue
kb_config.kb_id = child.id
self.knowledge_retrieval(db, query, rs, child, kb_config)
return
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
child_kb_config = kb_config.model_copy()
child_kb_config.kb_id = child.id
tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
if tasks:
result = await asyncio.gather(*tasks)
for _ in result:
rs.extend(_)
return rs
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold))
rs.extend(
await asyncio.to_thread(
vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
)
case RetrieveType.SEMANTIC:
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight))
rs.extend(
await asyncio.to_thread(
vector_service.search_by_vector, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.vector_similarity_weight
}
)
)
case RetrieveType.HYBRID:
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight)
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold)
rs1_task = asyncio.to_thread(
vector_service.search_by_vector, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.vector_similarity_weight
}
)
rs2_task = asyncio.to_thread(
vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
# Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs:
return
return []
if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model()
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
rs.extend(
await asyncio.to_thread(
self.rerank,
**{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
)
)
else:
rs.extend(sorted(
unique_rs,
@@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode):
)[:kb_config.top_k])
case _:
raise RuntimeError("Unknown retrieval type")
return rs
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
@@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode):
knowledge_bases = self.typed_config.knowledge_bases
rs = []
tasks = []
for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
if tasks:
result = await asyncio.gather(*tasks)
for _ in result:
rs.extend(_)
if not rs:
return []
if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model()
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
final_rs = await asyncio.to_thread(
self.rerank,
**{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
)
else:
final_rs = sorted(
rs,

View File

@@ -4,32 +4,33 @@ from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN
logger = logging.getLogger(__name__)
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
class ExpressionEvaluator:
"""Safe expression evaluator for workflow variables and node outputs."""
# Reserved namespaces
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@classmethod
def normalize_template(cls, template: str) -> str:
pattern = re.compile(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
return _NORMALIZE_PATTERN.sub(
r'{{ node["\1"].\2 }}',
template
)
@classmethod
def evaluate(
cls,
expression: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
cls,
expression: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""
Safely evaluate an expression using workflow variables.
@@ -49,48 +50,47 @@ class ExpressionEvaluator:
# Remove Jinja2-style brackets if present
expression = expression.strip()
expression = cls.normalize_template(expression)
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", expression).strip()
expression = VARIABLE_PATTERN.sub(r"\1", expression).strip()
# Build context for evaluation
context = {
"conv": conv_vars, # conversation variables
"node": node_outputs, # node outputs
"sys": system_vars or {}, # system variables
"conv": conv_vars, # conversation variables
"node": node_outputs, # node outputs
"sys": system_vars or {}, # system variables
}
context.update(conv_vars)
context["nodes"] = node_outputs
# context.update(conv_vars)
# context["nodes"] = node_outputs
context.update(node_outputs)
try:
# simpleeval supports safe operations:
# arithmetic, comparisons, logical ops, attribute/dict/list access
result = simple_eval(expression, names=context)
return result
except NameNotDefined as e:
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
raise ValueError(f"Undefined variable: {e}")
except InvalidExpression as e:
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
raise ValueError(f"Invalid expression syntax: {e}")
except SyntaxError as e:
logger.error(f"Syntax error in expression: {expression}, error: {e}")
raise ValueError(f"Syntax error: {e}")
except Exception as e:
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
raise ValueError(f"Expression evaluation failed: {e}")
@staticmethod
def evaluate_bool(
expression: str,
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
expression: str,
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""
Evaluate a boolean expression (for conditions).
@@ -108,7 +108,7 @@ class ExpressionEvaluator:
expression, conv_var, node_outputs, system_vars
)
return bool(result)
@staticmethod
def validate_variable_names(variables: list[dict]) -> list[str]:
"""
@@ -121,7 +121,7 @@ class ExpressionEvaluator:
list[str]: List of error messages. Empty if all names are valid.
"""
errors = []
for var in variables:
var_name = var.get("name", "")
@@ -134,16 +134,16 @@ class ExpressionEvaluator:
errors.append(
f"Variable name '{var_name}' is not a valid Python identifier"
)
return errors
# 便捷函数
def evaluate_expression(
expression: str,
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any]
expression: str,
conv_var: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict
) -> Any:
"""Evaluate an expression (convenience function)."""
return ExpressionEvaluator.evaluate(
@@ -152,11 +152,11 @@ def evaluate_expression(
def evaluate_condition(
expression: str,
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
expression: str,
conv_var: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict
) -> Any:
"""Evaluate a boolean condition expression (convenience function)."""
return ExpressionEvaluator.evaluate_bool(
expression, conv_var, node_outputs, system_vars

View File

@@ -1,7 +1,8 @@
"""
模板渲染器
Template Renderer
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
Provides safe template rendering using Jinja2, supporting variable references
and expressions.
"""
import logging
@@ -10,11 +11,15 @@ from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
from app.core.workflow.engine.variable_pool import LazyVariableDict
logger = logging.getLogger(__name__)
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
class SafeUndefined(Undefined):
"""访问未定义属性不会报错,返回空字符串"""
"""Return empty string instead of raising error when accessing undefined variables"""
__slots__ = ()
def _fail_with_undefined_error(self, *args, **kwargs):
@@ -26,26 +31,22 @@ class SafeUndefined(Undefined):
class TemplateRenderer:
"""模板渲染器"""
def __init__(self, strict: bool = True):
"""初始化渲染器
"""Initialize renderer
Args:
strict: 是否使用严格模式(未定义变量会抛出异常)
strict: Whether to enable strict mode (raise error on undefined variables)
"""
self.strict = strict
self.env = Environment(
undefined=StrictUndefined if strict else SafeUndefined,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
autoescape=False # Disable auto-escaping since we handle plain text instead of HTML
)
@staticmethod
def normalize_template(template: str) -> str:
pattern = re.compile(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
"""Normalize template syntax (convert numeric node reference to dict access)"""
return _NORMALIZE_PATTERN.sub(
r'{{ node["\1"].\2 }}',
template
)
@@ -53,24 +54,24 @@ class TemplateRenderer:
def render(
self,
template: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
conv_vars: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict | None = None
) -> str:
"""渲染模板
"""Render template
Args:
template: 模板字符串
conv_vars: 会话变量
node_outputs: 节点输出结果
system_vars: 系统变量
template: Template string
conv_vars: Conversation variables
node_outputs: Node outputs
system_vars: System variables
Returns:
渲染后的字符串
Rendered string
Raises:
ValueError: 模板语法错误或变量未定义
ValueError: If template syntax is invalid or variables are undefined
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.render(
@@ -80,122 +81,119 @@ class TemplateRenderer:
... {}
... )
'Hello World!'
>>> renderer.render(
... "分析结果: {{node.analyze.output}}",
... "Analysis result: {{node.analyze.output}}",
... {},
... {"analyze": {"output": "正面情绪"}},
... {"analyze": {"output": "positive sentiment"}},
... {}
... )
'分析结果: 正面情绪'
'Analysis result: positive sentiment'
"""
# 构建命名空间上下文
# Build namespace context
context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": system_vars, # 系统变量:{{sys.execution_id}}
"conv": conv_vars, # Conversation variables: {{conv.user_name}}
"node": node_outputs, # Node outputs: {{node.node_1.output}}
"sys": system_vars, # System variables: {{sys.execution_id}}
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文
# Allow direct access to node outputs by node ID: {{llm_qa.output}}
if node_outputs:
context.update(node_outputs)
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
if conv_vars:
context.update(conv_vars)
context["nodes"] = node_outputs or {} # 旧语法兼容
# # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
# if conv_vars:
# context.update(conv_vars)
#
# context["nodes"] = node_outputs or {} # 旧语法兼容
template = self.normalize_template(template)
try:
tmpl = self.env.from_string(template)
return tmpl.render(**context)
except TemplateSyntaxError as e:
logger.error(f"模板语法错误: {template}, 错误: {e}")
raise ValueError(f"模板语法错误: {e}")
logger.error(f"Template syntax error: {template}, error: {e}")
raise ValueError(f"Template syntax error: {e}")
except UndefinedError as e:
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
logger.error(f"Undefined variable in template: {template}, error: {e}")
raise ValueError(f"Undefined variable: {e}")
except Exception as e:
logger.error(f"模板渲染异常: {template}, 错误: {e}")
raise ValueError(f"模板渲染失败: {e}")
logger.error(f"Template rendering error: {template}, error: {e}")
raise ValueError(f"Template rendering failed: {e}")
def validate(self, template: str) -> list[str]:
"""验证模板语法
"""Validate template syntax
Args:
template: 模板字符串
template: Template string
Returns:
错误列表,如果为空则验证通过
List of errors (empty if valid)
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.validate("Hello {{var.name}}!")
[]
>>> renderer.validate("Hello {{var.name") # 缺少结束标记
['模板语法错误: ...']
>>> renderer.validate("Hello {{var.name") # Missing closing tag
['Template syntax error: ...']
"""
errors = []
try:
self.env.from_string(template)
except TemplateSyntaxError as e:
errors.append(f"模板语法错误: {e}")
errors.append(f"Template syntax error: {e}")
except Exception as e:
errors.append(f"模板验证失败: {e}")
errors.append(f"Template validation failed: {e}")
return errors
# 全局渲染器实例(严格模式)
# Global renderer instances (strict / lenient)
_strict_renderer = TemplateRenderer(strict=True)
_lenient_renderer = TemplateRenderer(strict=False)
def render_template(
template: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any],
conv_vars: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict,
strict: bool = True
) -> str:
"""渲染模板(便捷函数)
"""Render template (convenience function)
Args:
strict: 严格模式
template: 模板字符串
conv_vars: 会话变量
node_outputs: 节点输出
system_vars: 系统变量
strict: Whether to use strict mode
template: Template string
conv_vars: Conversation variables
node_outputs: Node outputs
system_vars: System variables
Returns:
渲染后的字符串
Rendered string
Examples:
>>> render_template(
... "请分析: {{var.text}}",
... {"text": "这是一段文本"},
... "Analyze: {{var.text}}",
... {"text": "This is a text"},
... {},
... {}
... )
'请分析: 这是一段文本'
'Analyze: This is a text'
"""
renderer = _strict_renderer if strict else _lenient_renderer
return renderer.render(template, conv_vars, node_outputs, system_vars)
def validate_template(template: str) -> list[str]:
"""验证模板语法(便捷函数)
"""Validate template syntax (convenience function)
Args:
template: 模板字符串
template: Template string
Returns:
错误列表
List of errors
"""
return _strict_renderer.validate(template)